import csv
import operator
import math

# Fallback probability for words a model has never seen, so a single unseen
# word does not force a document's product of probabilities to zero.
EPSILON = 1e-6

def main():
	"""Compare how likely 'unknown.txt' is under Hamilton's vs. Madison's word model.

	Builds a unigram probability map from 'hamilton.txt' and 'madison.txt',
	counts the words of 'unknown.txt', and prints each author's natural-log
	likelihood of that document followed by the log likelihood ratio
	log(P(doc | Hamilton) / P(doc | Madison)); a positive value favours
	Hamilton, a negative one Madison.
	"""
	# e.g. hamilton_word_prob['congress'] = 0.005
	hamilton_word_prob = make_word_prob_map('hamilton.txt')
	madison_word_prob = make_word_prob_map('madison.txt')

	# e.g. unknown_doc_count['congress'] = 5
	unknown_doc_count, _ = make_word_count_map('unknown.txt')

	def log_prob_doc(word_prob_map):
		# Sum count * log(p) instead of multiplying p ** count: the raw product
		# of thousands of small probabilities underflows to 0.0, which made
		# the original ratio a 0/0 division.
		return math.fsum(
			count * math.log(get_word_prob(word_prob_map, word))
			for word, count in unknown_doc_count.items()
		)

	log_hamilton = log_prob_doc(hamilton_word_prob)
	log_madison = log_prob_doc(madison_word_prob)
	print('hamilton: \t\t', log_hamilton)
	print('madison: \t\t', log_madison)
	# Log of the likelihood ratio; exponentiating it could overflow, so keep it in log space.
	print(log_hamilton - log_madison)
	# print('diff:\t\t', term_hamilton - term_madison)

def calc_prob_doc(word_prob_map, count_map, *, log=False):
	"""Return the probability of a document under a unigram word model.

	Words are treated as independent, so the probability is the product of
	p(word) ** count over every distinct word in the document. Per-word
	probabilities come from get_word_prob, which supplies a fallback for
	words missing from the model.

	Args:
		word_prob_map: dict mapping word -> probability.
		count_map: dict mapping word -> number of occurrences in the document.
		log: keyword-only; when True return the natural-log probability
			instead. For realistically sized documents the plain product
			underflows to 0.0, so comparisons should use log=True.

	Returns:
		The document probability, or its natural log when ``log`` is True.
	"""
	if log:
		# Summing logs avoids the floating-point underflow of the raw product.
		return math.fsum(
			count * math.log(get_word_prob(word_prob_map, word))
			for word, count in count_map.items()
		)
	prob = 1
	for word, count in count_map.items():
		prob *= get_word_prob(word_prob_map, word) ** count
	return prob

def get_word_prob(word_prob_map, word):
	"""Return the model probability of ``word``, or EPSILON if the model never saw it."""
	return word_prob_map[word] if word in word_prob_map else EPSILON

def make_word_prob_map(fileName):
	"""Estimate each word's probability from its relative frequency in a file.

	Every word is assumed to be produced independently, regardless of order,
	so p(word) is simply count(word) / total number of words counted.

	Args:
		fileName: path of the text file to read.

	Returns:
		dict mapping each word (as counted by make_word_count_map) to its
		relative frequency in the file.
	"""
	counts, total = make_word_count_map(fileName)
	return {word: float(n) / total for word, n in counts.items()}

def make_word_count_map(fileName):
	"""Count how many times each standardized word occurs in a file.

	Each line is split on runs of whitespace and every token is passed
	through standardize(). Tokens that standardize to the empty string
	(e.g. pure punctuation such as "--") are skipped, so they are neither
	stored as a word nor included in the total.

	Args:
		fileName: path of the text file to read.

	Returns:
		(wordMap, nWords): a dict mapping word -> occurrence count, and the
		total number of words counted (equal to the sum of wordMap's values).
	"""
	wordMap = {}
	nWords = 0
	with open(fileName) as f:
		for line in f:
			# split() with no argument splits on any whitespace run, so double
			# spaces, tabs and the trailing newline never yield empty or merged tokens.
			for token in line.split():
				word = standardize(token)
				if not word:
					continue
				add_word_to_count_map(wordMap, word)
				nWords += 1
	return wordMap, nWords

def add_word_to_count_map(wordMap, word):
	"""Increment ``word``'s count in ``wordMap`` in place, starting from 0 for a new word."""
	wordMap[word] = wordMap.get(word, 0) + 1

def standardize(word):
	"""Normalize a raw token for counting.

	Lower-cases the token and keeps only alphabetic characters, dropping
	punctuation, digits and whitespace (so "Congress," becomes "congress").
	A token containing no letters becomes the empty string.
	"""
	return ''.join(filter(str.isalpha, word.lower()))

# Run the authorship comparison only when executed as a script, not on import.
if __name__ == '__main__':
	main()