from pylab import *             # library for making plots

# cdf for the uniform(a, b) random variable
def uniformcdf(a, b):
    """Return the cumulative distribution function of a uniform(a, b) variable.

    The returned function maps x to P(X <= x): 0 for x < a, 1 for x > b and
    (x - a) / (b - a) on the interval [a, b].
    """
    def cdf(x):
        if x < a:
            return 0
        elif x > b:
            return 1
        else:
            # Measure the distance from the lower bound so that the cdf rises
            # linearly from 0 at x == a to 1 at x == b.
            return 1.0 * (x - a) / (b - a)
    return cdf

# density (pdf) for the uniform(a, b) random variable
def uniformpmf(a, b):
    """Return the probability density function of a uniform(a, b) variable.

    The density is the constant 1 / (b - a) on the closed interval [a, b]
    and 0 everywhere else.
    """
    def density(x):
        outside = x < a or x > b
        return 0 if outside else 1.0 / (b - a)
    return density

# cdf for the exponential(l) random variable
def expcdf(l):
    """Return the cumulative distribution function of an exponential(l) variable.

    The returned function is 0 for negative x and 1 - exp(-l * x) otherwise,
    where l is the rate parameter.
    """
    def distribution(x):
        below_support = x < 0
        return 0 if below_support else 1 - exp(-l * x)
    return distribution

# density (pdf) for the exponential(l) random variable
def exppmf(l):
    """Return the probability density function of an exponential(l) variable.

    The returned function is 0 for negative x and l * exp(-l * x) otherwise.
    """
    def density(x):
        below_support = x < 0
        return 0 if below_support else l * exp(-l * x)
    return density

# make a bar plot of the probability mass function y(x)
def pmfplot(x, y):
    """Draw the probabilities y at the points x as a bar chart and show it.

    The x-tick labels are thinned with ticklabels(x, 10), and the y-axis
    runs from 0 to 10% above the largest probability.
    """
    labels = ticklabels(x, 10)
    bar(x, y, width=0.5, align='center')
    xticks(x, labels)
    headroom = 1.1
    ylim(0, headroom * max(y))
    show()

# make a plot of the distribution function df on the interval [left, right]
def cont_plot(left, right, df, n=100):
    """Plot the one-argument function df over [left, right] and show it.

    Parameters:
        left, right: interval endpoints; both are included in the plot.
        df: a function of one number, e.g. a cdf or density produced by
            uniformcdf / expcdf in this module.
        n: number of sub-intervals; df is evaluated at n + 1 evenly spaced
           points (default 100, the same step size as before).
    """
    x = linspace(left, right, n + 1)
    # Build a real list: on Python 3 map() is a one-shot iterator, so plot()
    # would consume it and max(y) below would then fail.
    y = [df(t) for t in x]
    plot(x, y)
    xlim(min(x), max(x))
    top = max(y)
    # Avoid a zero-height y-axis when df is identically 0 on the range.
    ylim(0, 1.1 * top if top > 0 else 1)
    show()
    
# make a plot of the cumulative distribution function for a discrete random variable with pmf y(x)
def disc_cdfplot(x, y):
    """Draw the step-shaped cdf of a discrete variable and show it.

    x holds the support points (presumably sorted ascending; not checked)
    and y[i] is the probability of x[i].  Each jump is drawn as a steep
    segment from x[i] - eps to x[i] + eps.  A filled marker marks the cdf
    value just after the jump, and an open marker the value just before it.
    """
    eps = 0.01
    step_x, step_y = [], []
    before_jump, after_jump = [], []
    total = 0.0
    for i, point in enumerate(x):
        lower = total
        upper = total + y[i]
        step_x.extend([point - eps, point + eps])
        step_y.extend([lower, upper])
        before_jump.append(lower)
        after_jump.append(upper)
        total = upper
    plot(step_x, step_y)
    plot(x, after_jump, 'o', mfc='blue', ms=6.0, mew=0)
    plot(x, before_jump, 'o', mfc='white', mec='blue', ms=5.0, mew=1.0)
    xticks(x, ticklabels(x, 10))
    show()
    
# make tick labels so there are at most l of them
def ticklabels(seq, l):
    """Return tick labels for seq, with at most l of them non-blank.

    If seq has at most l items, it is returned unchanged.  Otherwise every
    k-th item is kept, where k = ceil(len(seq) / l), and every other
    position gets the empty string ''.
    """
    # Round the stride up: rounding down (e.g. 25 // 10 == 2) would keep 13
    # of 25 labels.  Integer ceiling division behaves the same on Python 2
    # and 3.
    k = (len(seq) + l - 1) // l
    if k <= 1:
        return seq
    return [item if i % k == 0 else '' for i, item in enumerate(seq)]
