import population
import numpy as np
import random
import util
import math
import matplotlib.pyplot as plt
from scipy import stats
from tqdm import tqdm

# Number of bootstrap resamples drawn in main().
NBOOTSTRAPS = 10000
# Size of the original sample, and of every bootstrap resample taken from it.
NSAMPLE = 20
# Variance of the normal distribution the original sample is drawn from
# (main() passes its square root to norm.rvs as the scale).
TRUE_VARIANCE = 10

def main():
	"""Estimate the spread of the sample variance S^2 by bootstrapping.

	Draws NSAMPLE values from a normal distribution with mean 0 and
	variance TRUE_VARIANCE, computes their sample variance S^2, then
	resamples the observed data with replacement NBOOTSTRAPS times and
	recomputes S^2 for every resample. The standard deviation of those
	resampled variances is reported as an estimate of Std(S^2).

	Prints a short summary and passes the resampled variances to
	util.plotHistogram. Returns None.
	"""
	# fixed seed for numpy's global RNG so every run sees the same draws
	np.random.seed(seed=1)
	print('Testing Sample Variance\n')

	# true distribution is a normal with mean = 0, var = TRUE_VARIANCE
	# (norm.rvs takes the standard deviation as its scale, hence the sqrt)
	trueStd = math.sqrt(TRUE_VARIANCE)

	# lets draw NSAMPLE samples from it, one at a time
	samples = [stats.norm.rvs(0, trueStd) for _ in range(NSAMPLE)]
	sSquared = calcSampleVariance(samples)

	# now lets bootstrap!!!
	resampledVars = []
	for _ in tqdm(range(NBOOTSTRAPS)):
		# resample the observed data with replacement, same size as the original
		newSample = np.random.choice(samples, NSAMPLE, replace=True)
		resampledVars.append(calcSampleVariance(newSample))

	# what results did we get?
	print('True Var  ', TRUE_VARIANCE)
	print('What could we figure out from n={} samples?'.format(NSAMPLE))
	print('S^2       ', sSquared)
	print('Std(S^2)  ', math.sqrt(calcSampleVariance(resampledVars)))
	util.plotHistogram(resampledVars)
	

# This is the definition of the (unbiased) variance of a
# sample drawn from a population: it divides by n - 1.
def calcSampleVariance(data):
	"""Return the unbiased sample variance S^2 of data.

	Computes sum((x - mean)^2) / (n - 1), i.e. the estimator with
	Bessel's correction, as opposed to the population variance which
	divides by n.

	data must be a sized, iterable collection of numbers (a list or a
	1-D numpy array both work). The result is a plain Python float.

	Raises ValueError if data has fewer than two elements, because the
	n - 1 denominator would be zero or negative.
	"""
	n = len(data)
	if n < 2:
		raise ValueError(
			'sample variance requires at least two data points, got {}'.format(n))
	sampleMean = np.mean(data)
	total = 0.0
	for d in data:
		total += math.pow(d - sampleMean, 2)
	return total / (n - 1)

# This is the definition of the variance of an
# entire population: it divides by n.
def calcVariance(data):
	"""Return the population variance of data.

	Computes sum((x - mean)^2) / n, treating data as the complete
	population rather than a sample from it.

	data must be a sized, iterable collection of numbers (a list or a
	1-D numpy array both work). The result is a plain Python float.

	Raises ValueError if data is empty, since the mean and the variance
	are undefined for zero elements.
	"""
	n = len(data)
	if n == 0:
		raise ValueError('variance requires at least one data point')
	populationMean = np.mean(data)
	total = 0.0
	for d in data:
		total += math.pow(d - populationMean, 2)
	return total / n

def printRow(row):
	"""Forward row to util.printRow.

	Thin module-level wrapper; whatever util.printRow returns is
	discarded, so this function always returns None.
	"""
	util.printRow(row)

# Script entry point: run the bootstrap experiment only when executed
# directly, not when this module is imported.
if __name__ == '__main__':
	main()
	# Display any open matplotlib figures; presumably util.plotHistogram
	# created one without showing it -- confirm in util.
	plt.show()