from pylab import *             # library for making plots
from decimal import *           # library for exact arithmetic with large numbers

def pmfplot(x, y):
    """Draw a bar chart of the probability mass function y(x) and display it.

    x -- the support values; one bar is drawn per value and each value is
         also used as an x tick
    y -- the probability for each value in x; must be non-empty (max(y) is
         used to size the y axis)
    """
    # Materialize both inputs: y is read twice (by bar and by max), so a
    # one-shot iterator such as a Python 3 map object would be exhausted.
    x = list(x)
    y = list(y)
    bar(x, y, width=0.5, align='center')
    xticks(x)
    ylim(0, 1.1 * max(y))      # 10% headroom above the tallest bar
    show()
	
def binom(n):
    """Return the binomial coefficients [C(n, 0), C(n, 1), ..., C(n, n)].

    Uses the multiplicative recurrence C(n, k) = C(n, k - 1) * (n - k + 1) / k.
    The product on the right is always divisible by k, so floor division (//)
    keeps every entry an exact integer. True division (/) under Python 3
    would produce floats and lose precision for large n.
    """
    B = [0] * (n + 1)
    B[0] = 1
    for k in range(1, n + 1):
        B[k] = B[k - 1] * (n - k + 1) // k
    return B

def coupons_prob(n, d):
    """Return the probability of having collected all n coupon types after d days.

    Each day one coupon is drawn uniformly at random from the n types. By
    inclusion-exclusion, the number of length-d draw sequences that contain
    every type is

        sum_{i=0}^{n} (-1)**i * C(n, i) * (n - i)**d

    and the probability is that count divided by n**d.

    The count is accumulated with exact integer arithmetic and divided only
    once. This avoids the cancellation of computing 1 - P(some type missing)
    in floating point, which rounds small probabilities to 0. The i = n term
    is kept deliberately: it is 0**d, which Python evaluates as 1 when d == 0,
    so the result for d == 0 is correctly 0 for n >= 1.
    """
    count = 0
    coeff = 1                  # C(n, i), updated incrementally (exact ints)
    sign = 1                   # (-1)**i
    for i in range(n + 1):
        count += sign * coeff * (n - i)**d
        coeff = coeff * (n - i) // (i + 1)
        sign = -sign
    return float(Decimal(count) / Decimal(n**d))

def plot_coupons_prob(n, d):
    """Plot the probability of having collected all n coupons after 0, 1, ..., d days.

    Each day count is drawn as an 'x' marker, with the y axis fixed to [0, 1].
    """
    days = list(range(0, d + 1))
    # Build a real list: under Python 3 map() is lazy, and plot() needs a
    # sequence whose length matches `days`.
    probs = [coupons_prob(n, day) for day in days]
    plot(days, probs, 'x')
    ylim([0.0, 1.0])           # sets the boundary values of the y axis
    show()                     # displays the plot

def day_of_coupon_collection(n, p):
    """Return the first day d on which coupons_prob(n, d) >= p.

    Does an exponential search (doubling) for an upper bound starting at
    d = n, then a binary search below it. This relies on coupons_prob being
    non-decreasing in d.

    Raises ValueError unless 0 < p <= 1. For p > 1 the target is never
    reached and the doubling loop would run forever. For p <= 0 the
    starting lower bound d = n - 1 would not satisfy the search invariant.
    Note: for n >= 2 the exact probability never reaches 1, so p == 1 only
    terminates through floating-point rounding, at a large d.
    """
    if not 0 < p <= 1:
        raise ValueError("p must be in (0, 1], got %r" % (p,))
    # Invariant: coupons_prob(n, d_l) < p <= coupons_prob(n, d_r).
    # d_l = n - 1 satisfies the left half for p > 0, because fewer than n
    # draws cannot cover all n types (probability 0).
    d_l = n - 1
    d_r = n
    # Exponential search: double d_r until it reaches probability p.
    while coupons_prob(n, d_r) < p:
        d_l = d_r
        d_r = 2 * d_r
    # Binary search in (d_l, d_r]; // keeps the midpoint an integer day
    # (true division would yield fractional days under Python 3).
    while d_r - d_l > 1:
        d_m = (d_l + d_r) // 2
        if coupons_prob(n, d_m) < p:
            d_l = d_m
        else:
            d_r = d_m
    return d_r

def plot_coupons_day(n, p):
    """Plot, for 1, 2, ..., n coupon types, the first day on which all are collected with probability >= p.

    Each point is computed by day_of_coupon_collection(c, p) and drawn as an
    'x' marker.
    """
    coupons = list(range(1, n + 1))
    # A list, not a lazy map object, so plot() receives a sized sequence.
    days = [day_of_coupon_collection(nc, p) for nc in coupons]
    plot(coupons, days, 'x')
    show()                     # displays the plot

def plot_coupons_day_appx(n, p):
    """Plot the approximation c * log(c / log(1 / p)) of the collection day for c = 1, ..., n.

    After d days the number of still-missing coupon types is approximately
    Poisson with mean c * exp(-d / c), so P(all collected) ~= exp(-c * exp(-d / c)).
    Solving exp(-c * exp(-d / c)) = p for d gives d = c * log(c / log(1 / p)).
    This is presumably intended as an approximation to the exact days plotted
    by plot_coupons_day(n, p).

    p must satisfy 0 < p < 1; at p == 1, log(1 / p) is 0 and the division fails.
    """
    import math  # imported locally: `from pylab import *` does not guarantee `math`
    coupons = list(range(1, n + 1))
    # log(1 / p), not log(1 / (1 - p)): the latter makes the curve fall as p
    # rises, although a higher target probability needs more days.
    log_inv_p = math.log(1.0 / p)
    plot(coupons, [c * math.log(c / log_inv_p) for c in coupons])
    show()                     # displays the plot

