from scipy.stats import bernoulli, binom
from tqdm import tqdm
import random

"""
Program: Joint Sample
---------------------
we can answer any probability question
with multivariate samples from the joint,
where conditioned variables match
"""

# Number of joint samples drawn by sample_a_ton().
# For reference: the answer with 10M samples is 0.159
N_SAMPLES = 10 ** 6

# Bernoulli probability used when sampling each grandparent and spouse
PRIOR = 1e-4

def main():
    """
    Estimate Pr(cousin_2 = 1 | cousin_1 = 1) from joint samples
    and print the observation together with the estimate.
    """
    observed = {'cousin_1': 1}
    query = {'cousin_2': 1}

    # the estimate is computed before anything is printed
    estimate = prob_event(sample_a_ton(), observed, query)

    print('Observation = ', observed)
    print('Pr(Event | Obs) = ', estimate)


def sample_a_ton():
    """
    Sample A Ton
    --------------------
    Draw N_SAMPLES independent samples from the joint
    distribution (one make_sample() call each), showing a
    progress bar, and return them as a list.
    """
    return [make_sample() for _ in tqdm(range(N_SAMPLES))]

def make_sample():
    """
    Make Sample
    -------------------
    Draw one joint sample over a three-generation family.

    Each grandparent and each spouse carries the gene independently
    with probability PRIOR. A child carries it if a carrier parent
    passes it on, which is decided by a coin flip per carrier parent.

    Returns a dict mapping each family member's name to 0 or 1.
    """
    def draw_founder():
        # members without modeled parents are drawn from the prior
        return bernoulli.rvs(PRIOR)

    def inherit(parent_1, parent_2):
        # parents are tried in order; a coin is flipped only for a carrier
        for parent in (parent_1, parent_2):
            if parent == 1 and random.choice([True, False]):
                return 1
        return 0

    # members are filled in generation by generation; every parent is
    # drawn before the child that depends on it
    family = {}
    family["grand_parent_1"] = draw_founder()
    family["grand_parent_2"] = draw_founder()

    family["child_1"] = inherit(
        family["grand_parent_1"], family["grand_parent_2"])
    family["spouse_1"] = draw_founder()

    family["child_2"] = inherit(
        family["grand_parent_1"], family["grand_parent_2"])
    family["spouse_2"] = draw_founder()

    family["cousin_1"] = inherit(family["child_1"], family["spouse_1"])
    family["cousin_2"] = inherit(family["child_2"], family["spouse_2"])

    return family



def prob_event(samples, obs, event):
    """
    Estimate Pr(event | obs) from samples of the joint distribution.

    Samples that do not agree with the observation are rejected;
    among the remaining ones, the fraction that also agrees with
    the event is returned.

    samples: iterable of samples (dicts of variable name -> value)
    obs:     dict of variable name -> observed value to condition on
    event:   dict of variable name -> value whose probability is asked

    Returns the estimated conditional probability as a float.
    Raises ZeroDivisionError if no sample agrees with the observation,
    since the conditional probability cannot be estimated then.
    """
    kept_count = 0
    event_count = 0

    # single pass: count instead of storing a copy of every kept sample
    for sample in samples:
        # reject all samples which don't align with the condition
        if not check_obs_match(sample, obs):
            continue
        kept_count += 1
        if check_obs_match(sample, event):
            event_count += 1

    if kept_count == 0:
        # same exception type as the bare division, with a useful message
        raise ZeroDivisionError(
            'no samples match the observation; '
            'cannot estimate Pr(Event | Obs)'
        )

    # counting can be so sweet...
    return float(event_count) / kept_count

def check_obs_match(sample, obs):
    """
    Check Observation Match
    -------------------------------
    returns True if and only if every observed random var in
    obs has the same value in the sample. A value of None in
    obs means "not observed" and always matches.
    for example:
    sample = {'a': 1, 'b': 0, 'c': 1}
    obs = {'a': None, 'b': 0}
    check_obs_match(sample, obs) will return True
    since the only observed var ('b') matches

    Raises KeyError if obs names a variable missing from sample.
    """

    # loop over all random variables named in the observation
    for key, var_obs in obs.items():
        var_sam = sample[key]

        # if this random var is observed, make sure it matches
        if var_obs is not None and var_obs != var_sam:
            return False
    return True




if __name__ == '__main__':
    main()