import numbers

from util import *

"""
Program: Joint Sample
---------------------
we can answer any probability question
with multivariate samples from the joint,
where conditioned variables match
"""
# How many samples sample_a_ton() draws from the joint distribution.
N_SAMPLES = 10 ** 5


def main():
	"""
	Estimate Pr(Flu | Obs) for the observation from get_observation()
	using samples drawn by sample_a_ton(), then print both.
	"""
	observation = get_observation()
	joint_samples = sample_a_ton()
	estimate = prob_flu_given_obs(joint_samples, observation)
	print('Observation = ', observation)
	print('Pr(Flu | Obs) = ', estimate)

def get_observation():
	"""
	Return the evidence to condition on: a dict mapping each random
	variable name to its observed value.

	Edit the values below to ask a different conditional probability
	question. For the binary variables:
	  None -> the variable is not observed
	  1    -> observed, and the state is true
	  0    -> observed, and the state is false
	fever is not binary; make_sample draws it with norm(...), so it
	takes real values.
	"""
	# NOTE(review): because fever is sampled as a real number, whether a
	# sample "matches" 101 depends on how check_obs_match compares numeric
	# values -- confirm the intended matching rule.
	return dict(
		flu=None,
		undergrad=1,
		fever=101,
		tired=1,
	)

def prob_flu_given_obs(samples, obs):
	"""
	Estimate Pr(flu = 1 | obs) by rejection sampling.

	Given many samples from the joint distribution and a set of
	observations to condition on, discard every sample that disagrees
	with the observations (via check_obs_match) and return the fraction
	of the remaining samples in which flu == 1.

	samples: iterable of sample dicts (variable name -> sampled value)
	obs: dict of variable name -> observed value (None = unobserved)
	returns: float in [0, 1]
	raises: ValueError if no sample is consistent with obs -- the
	        estimate is then undefined (draw more samples or relax
	        the observation)
	"""
	# rejection step: keep only samples that align with the condition
	keep_samples = [sample for sample in samples if check_obs_match(sample, obs)]

	# dividing by zero survivors would give an opaque ZeroDivisionError;
	# report the actual problem instead
	if not keep_samples:
		raise ValueError(
			'no samples match the observation %r; cannot estimate '
			'Pr(Flu | Obs) (draw more samples or relax the observation)'
			% (obs,))

	# from the remaining samples, simply count the ones with flu
	flu_count = sum(1 for sample in keep_samples if sample['flu'] == 1)
	return float(flu_count) / len(keep_samples)

def check_obs_match(sample, obs, tolerance=0.5):
	"""
	Check Observation Match
	-------------------------------
	Return True if and only if every observed random variable in obs
	agrees with the same variable in sample. A variable whose observed
	value is None is unobserved and never rules a sample out.

	Non-numeric values must match exactly. A real-valued (continuous)
	variable such as fever essentially never equals an observed number
	exactly, so a numeric sample value also matches when it lies within
	`tolerance` of the observed value. With the default tolerance of 0.5,
	0/1 variables still require an exact match (they differ by 1).

	sample: dict of variable name -> sampled value
	obs: dict of variable name -> observed value, or None if unobserved
	tolerance: max absolute difference accepted between numeric values
	           (pass 0 for exact matching only)

	for example:
	sample = {'flu': 1, 'undergrad': 0, 'fever': 100.7, 'tired': 1}
	obs = {'flu': None, 'undergrad': 0, 'fever': 101, 'tired': None}
	check_obs_match(sample, obs) returns True, since undergrad matches
	exactly and fever is within 0.5 of 101
	"""
	# loop over all observed random variables
	for key, var_obs in obs.items():
		# unobserved variables always agree
		if var_obs is None:
			continue

		var_sam = sample[key]
		if var_obs == var_sam:
			continue

		# numeric values: accept anything close enough to the observation
		both_numeric = (isinstance(var_obs, numbers.Real)
				and isinstance(var_sam, numbers.Real))
		if both_numeric and abs(var_obs - var_sam) <= tolerance:
			continue

		return False
	return True


def sample_a_ton(n_samples=None):
	"""
	Sample A Ton
	--------------------
	Draw many independent samples from the joint distribution, each
	chosen with likelihood proportional to its joint probability (one
	make_sample() call per sample).

	n_samples: how many samples to draw; None (the default) means use
	           the module-level N_SAMPLES, read at call time
	returns: list of sample dicts as produced by make_sample()
	"""
	if n_samples is None:
		n_samples = N_SAMPLES

	# No per-sample print: echoing every sample flooded stdout with
	# N_SAMPLES lines and dominated the run time.
	return [make_sample() for _ in range(n_samples)]

def make_sample():
	"""
	Make Sample
	-------------------
	Draw one sample from the joint distribution of the medical
	"Probabilistic Graphical Model" by ancestral sampling: each variable
	is drawn after its parents, conditioned on their sampled values.
	Returns a dict assigning a value to every random variable.
	"""
	# root nodes (no parents): draw from their priors
	has_flu = bern(0.1)
	is_undergrad = bern(0.8)

	# fever depends only on flu; pick norm's parameters accordingly
	fever_params = (100.0, 1.81) if has_flu == 1 else (98.25, 0.73)
	temperature = norm(*fever_params)

	# tired depends on (undergrad, flu); any other combination uses 0.1
	tired_prob = {
		(1, 1): 1.0,
		(1, 0): 0.8,
		(0, 1): 0.9,
	}.get((is_undergrad, has_flu), 0.1)
	is_tired = bern(tired_prob)

	# a sample from the joint assigns a value to *all* random variables
	return {
		'flu': has_flu,
		'undergrad': is_undergrad,
		'fever': temperature,
		'tired': is_tired,
	}

# Run the estimate only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
	main()