Skip to article frontmatterSkip to article content
from __init__ import install_dependencies

await install_dependencies()
import random
from ipywidgets import interact
import matplotlib.pyplot as plt
import numpy as np
%matplotlib widget
%reload_ext divewidgets

The Monty-Hall Game

The Monty Hall Problem

Is it better to change the initial pick? What is the chance of winning if we change?

Let’s use the following program to play the game a couple of times.

import random


def play_monty_hall(num_of_doors=3):
    """Play one interactive round of the Monty Hall game and record the result.

    The doors are labelled "0", "1", ..., and the treasure is hidden behind
    one of them at random. After the player's initial pick, every door except
    the picked one and one other door is opened; the opened doors never hide
    the treasure. The player may then switch to the other unopened door.

    The outcome is tallied in the module-level ``mh_stats`` dictionary.

    Parameters:
        num_of_doors: int, optional
            Number of doors in the game; must be at least 2. Defaults to 3.

    Raises:
        ValueError: if ``num_of_doors`` is less than 2, because the host then
            cannot leave another door closed for the player to switch to.
    """
    if num_of_doors < 2:
        raise ValueError(f"num_of_doors must be at least 2, got {num_of_doors}")

    # Labels are generated in numeric order, so "10" is listed after "9".
    door_labels = [str(i) for i in range(num_of_doors)]
    door_with_treasure = random.choice(door_labels)

    # Input the initial pick of the door.
    while True:
        initial_pick = input(f'Pick a door from {", ".join(door_labels)}: ').strip()
        if initial_pick in door_labels:
            break

    # Open all but one other door. The opened doors must have nothing.
    if initial_pick != door_with_treasure:
        # The host has to keep the treasure door closed.
        other_door = door_with_treasure
    else:
        # Any other door may stay closed; the host picks one at random.
        other_door = random.choice([d for d in door_labels if d != initial_pick])
    doors_to_open = [d for d in door_labels if d not in (initial_pick, other_door)]
    print("Door(s) with nothing behind:", *doors_to_open)

    # Allow the player to change the initial pick to the other (unopened) door.
    answer = input(f"Would you like to change your choice to {other_door}? [y/N] ")
    change_pick = answer.strip().lower() == "y"

    # Check and record winning.
    if change_pick:
        mh_stats["# change"] += 1
        won = door_with_treasure == other_door
        if won:
            mh_stats["# win by changing"] += 1
    else:
        mh_stats["# no change"] += 1
        won = door_with_treasure == initial_pick
        if won:
            mh_stats["# win without changing"] += 1

    if won:
        print("You won!")
    else:
        print(f"You lost. The door with treasure is {door_with_treasure}.")


mh_stats = dict.fromkeys(
    ("# win by changing", "# win without changing", "# change", "# no change"), 0
)


def monty_hall_statistics():
    """Print the winning percentages recorded in ``mh_stats``.

    Each percentage is printed only if at least one game of that kind
    (changing or not changing the initial pick) has been played, which also
    avoids dividing by zero.
    """
    print("-" * 30, "Statistics", "-" * 30)
    if mh_stats["# change"]:
        # Implicit string concatenation (instead of a backslash inside the
        # literal) keeps the source indentation out of the printed text.
        print(
            "% win by changing: "
            f"{mh_stats['# win by changing'] / mh_stats['# change']:.0%}"
        )
    if mh_stats["# no change"]:
        print(
            "% win without changing: "
            f"{mh_stats['# win without changing'] / mh_stats['# no change']:.0%}"
        )
# Play one interactive round (the player's choices are read with input()),
# then print the winning percentages accumulated in mh_stats so far.
play_monty_hall()
monty_hall_statistics()

To get a good estimate of the chance of winning, we need to play the game many times.
Instead of playing by hand, we can write a Monte Carlo simulation.

# Do not change any variables defined here, or some tests may fail.
import numpy as np

np.random.seed(0)  # for reproducible result
# Note: 10e7 == 1e8, i.e., 100 million simulated games.
num_of_games = int(10e7)
# Door numbers in {1, 2, 3} (randint's upper bound 4 is exclusive). uint8 keeps
# each array at one byte per game, i.e., about 100 MB per array.
door_with_treasure = np.random.randint(1, 4, num_of_games, dtype=np.uint8)
initial_pick = np.random.randint(1, 4, num_of_games, dtype=np.uint8)

# Preview the first 10 games; labels are right-aligned to a width of 19.
print(f"{'Door with treasure:':>19}", *door_with_treasure[:10], "...")
print(f"{'Initial pick:':>19}", *initial_pick[:10], "...")
  • door_with_treasure stores, as 8-bit unsigned integers (uint8), the door numbers randomly chosen from $\{1, 2, 3\}$ for the doors with the treasure behind, one for each of the num_of_games Monty-Hall games.
  • initial_pick stores the initial picks, also chosen from $\{1, 2, 3\}$, for the same games.

If players do not change their initial pick, the chance of winning can be estimated as follows:

def estimate_chance_of_winning_without_change(door_with_treasure, initial_pick):
    """Estimate the chance of winning the Monty Hall game without changing
    the initial pick using the Monte Carlo simulation of door_with_treasure
    and initial_pick.

    This is the straightforward loop-based version: it counts the games in
    which the initial pick is the door with the treasure.

    Parameters:
        door_with_treasure: iterable of door numbers
            The door hiding the treasure in each simulated game.
        initial_pick: iterable of door numbers
            The player's initial pick in each simulated game.

    Returns:
        float
            The fraction of the given games won by keeping the initial pick.

    Raises:
        ValueError: if no games are given.
    """
    count_of_win = 0
    count_of_games = 0
    for x, y in zip(door_with_treasure, initial_pick):
        count_of_games += 1
        if x == y:
            count_of_win += 1
    if count_of_games == 0:
        raise ValueError("at least one game is required")
    # Divide by the number of games actually passed in, not by a global.
    return count_of_win / count_of_games


# The loop-based estimate is slow, so run it on only the first 1% of the games.
n = num_of_games // 100
estimate_chance_of_winning_without_change(door_with_treasure[:n], initial_pick[:n])

However, the above code is inefficient and takes a long time to run. You may try running it on the entire sequences of door_with_treasure and initial_pick but DO NOT put the code in your notebook, as the server refuses to auto-grade notebooks that take too much time or memory to run.

A simpler and more efficient solution, with a speed-up of well over 100 times, is as follows:

def estimate_chance_of_winning_without_change(door_with_treasure, initial_pick):
    """Vectorized Monte Carlo estimate of the chance of winning when the
    player keeps the initial pick.

    door_with_treasure and initial_pick are compared element-wise; the
    estimate is the fraction of games in which they agree, i.e., the mean of
    the resulting boolean array.
    """
    keeping_wins = door_with_treasure == initial_pick  # True where keeping wins
    return keeping_wins.mean()


estimate_chance_of_winning_without_change(door_with_treasure, initial_pick)

The code uses the `mean` method of `ndarray`, which computes the mean of the elements of a NumPy array.
In computing the mean, True and False are regarded as 1 and 0, respectively, as illustrated below.

# Print every sum of two truth values to show they add like 1 and 0.
for lhs in (True, False):
    for rhs in (True, False):
        print(f"{lhs} + {rhs} == {lhs + rhs}")
def estimate_chance_of_winning_by_change(door_with_treasure, initial_pick):
    """Vectorized Monte Carlo estimate of the chance of winning when the
    player switches to the other unopened door.

    The host leaves the treasure door closed whenever the initial pick is
    wrong, so switching wins exactly in the games where door_with_treasure
    and initial_pick differ; the estimate is the fraction of such games.
    """
    ### BEGIN SOLUTION
    initial_pick_missed = door_with_treasure != initial_pick
    return initial_pick_missed.mean()
    ### END SOLUTION
Source
# tests
# With the seeded simulation above, switching wins in 7 of the first 10 games.
assert np.isclose(
    estimate_chance_of_winning_by_change(door_with_treasure[:10], initial_pick[:10]),
    0.7)

Solving non-linear equations

Suppose we want to solve:

$$f(x) = 0$$

for some possibly non-linear real-valued function $f(x)$ in one real-valued variable $x$. A quadratic equation with an $x^2$ term is an example. The following is another example.

# Plot the example cubic f(x) = x(x - 1)(x - 2) on [-0.5, 2.5] so its roots
# (where the curve crosses the dotted horizontal reference line) are visible.
f = lambda x: x * (x - 1) * (x - 2)
x = np.linspace(-0.5, 2.5)
plt.figure(1, clear=True)  # reuse figure number 1, clearing its old contents
plt.plot(x, f(x))
plt.axhline(color="gray", linestyle=":")  # horizontal line at the default height y = 0
plt.xlabel(r"$x$")
plt.title(r"Plot of $f(x)\coloneq x(x-1)(x-2)$")
plt.show()

While it is clear that the above function has three roots, namely, $x = 0, 1, 2$, can we write a program to compute a root of any given continuous function $f$?

Bisection Method

The following function bisection implements the bisection method as a recursion.

def bisection(f, a, b, *, n=10):
    """
    Narrow down a root of f(x) between a and b by recursive halving.

    Parameters:
        f: function
            The function whose root is to be found.
        a, b: float
            The endpoints of the initial interval; either order is accepted.
        n: int, optional
            How many more times the interval may be halved. Defaults to 10.

    Returns:
        list
            [xstart, xstop] with xstart <= xstop, an interval on which f
            changes sign or vanishes at an endpoint, or [] if f(a) and f(b)
            have the same sign (their product is positive).
    """
    # A positive product means no sign change is detected on this interval.
    if f(a) * f(b) > 0:
        return []
    # Base case: no bisections left, so report the interval in increasing order.
    if n <= 0:
        if a <= b:
            return [a, b]
        return [b, a]
    midpoint = (a + b) / 2
    remaining = n - 1
    # Recursion: search the half next to a first, then the half next to b.
    first_half = bisection(f, a, midpoint, n=remaining)
    if first_half:
        return first_half
    return bisection(f, midpoint, b, n=remaining)


# Bisect the example cubic on [-0.5, 0.5], an interval containing the root x = 0.
f = lambda x: x * (x - 1) * (x - 2)
bisection(f, -0.5, 0.5)

In the following widget, try changing the values of $a$ and $b$ as follows, and change $n$ to see how the interval changes step by step:

  • $[a, b] \approx [-0.5, 0.5]$
  • $[a, b] \approx [1.5, 0.5]$
  • $[a, b] \approx [-0.1, 2.5]$
# bisection solver
def bisection_solver(f, a, b, n=10, *args, **kwargs):
    """Interactively visualize the bisection method on f over [a, b].

    Plots f on [a, b] and attaches sliders for the endpoints a and b (step
    (b - a) / 100) and for the number of bisections (0 to n). On every slider
    change, f(a) and f(b) are drawn as a stem plot, and the interval returned
    by ``bisection`` is drawn in red on the x-axis and printed.

    Parameters:
        f: function
            The function to bisect; it is evaluated on a numpy array for the
            plot, so it should accept array input.
        a, b: float
            The plotting range and the slider limits for both endpoints.
        n: int, optional
            Maximum value of the bisection-count slider. Defaults to 10.
        *args, **kwargs:
            Forwarded to ``plt.figure`` (e.g. ``num=1, clear=True``).
    """
    # Use the caller-supplied f as-is; it must not be overwritten here.
    x = np.linspace(a, b)
    fig = plt.figure(*args, **kwargs)
    plt.plot(x, f(x), "b-")
    plt.xlabel(r"$x$")
    plt.title(r"Bisection on $f(x)$")
    stem_plot = interval_plot = ()
    s = (b - a) / 100

    @interact(a=(a, b, s), b=(a, b, s), n=(0, n, 1))
    def plot(a=a, b=b, n=0):
        nonlocal interval_plot, stem_plot
        plt.figure(fig.number)
        # Remove the artists drawn by the previous slider update, if any.
        if stem_plot:
            stem_plot.remove()
        stem_plot = plt.stem([a, b], [f(a), f(b)], linefmt="g:", basefmt="g:")
        if interval_plot:
            interval_plot[0].remove()
        interval = bisection(f, a, b, n=n)
        # Only draw the interval when bisection found one.
        interval_plot = plt.plot(interval, [0, 0], "r|-") if interval else ()
        plt.show()
        print("Interval: ", interval)

bisection_solver(f, -0.5, 2.5, num=1, clear=True)
def bisection(f, a, b, *, tol=1e-9, n=None):
    """
    Find the root of the function f(x) within the interval [a, b] using the bisection method.

    Parameters:
        f: function
            The function whose root is to be found.
        a, b: float
            The endpoints of the interval [a, b]; either order is accepted.
        tol: float, optional
            The absolute tolerance level for the root approximation. Defaults to 1e-9.
        n: int, optional
            The maximum number of bisections to perform. Defaults to None (unlimited).

    Returns:
        list
            The interval [xstart, xstop] (xstart <= xstop) containing the root of f(x),
            or an empty list if f(a) and f(b) have the same sign. The half next to b
            is searched before the half next to a. If tol is finer than the
            floating-point spacing near the root, the search stops once the midpoint
            can no longer be distinguished from an endpoint.
    """
    ### BEGIN SOLUTION
    if f(a) * f(b) > 0:
        return []  # no sign change detected between a and b
    if abs(b - a) <= tol or (n is not None and n <= 0):
        return [a, b] if a <= b else [b, a]
    c = (a + b) / 2
    if c == a or c == b:
        # a and b are adjacent floats: halving cannot shrink the interval, so
        # stop here instead of recursing on the same interval forever.
        return [a, b] if a <= b else [b, a]
    n_next = None if n is None else n - 1
    return bisection(f, c, b, tol=tol, n=n_next) or bisection(f, a, c, tol=tol, n=n_next)
    ### END SOLUTION
Source
# tests
import numpy as np

f = lambda x: x * (x - 1) * (x - 2)
# Default tol=1e-9 on [-0.5, 0.5]: the interval closes in on the root x = 0.
assert np.isclose(bisection(f, -0.5, 0.5), [-9.313225746154785e-10, 0.0]).all()
# Reversed endpoints with tol=1e-2: an interval on either side of the root
# x = 1 is accepted.
_ = bisection(f, 1.5, 0.5, tol=1e-2)
assert np.isclose(_, [1.0, 1.0078125]).all() or np.isclose(_, [0.9921875, 1.0]).all()
# On [-0.1, 2.5] with tol=1e-3 the expected interval brackets the root x = 2.
assert np.isclose(
    bisection(f, -0.1, 2.5, tol=1e-3), [1.9998046875000002, 2.0004394531250003]).all()
# n=2 caps the search at two bisections, before tol is reached.
assert np.isclose(bisection(f, -0.1, 2.5, tol=1e-3, n=2), [1.85, 2.5]).all()