import csv
import random
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import math

# Lowest possible Pearson correlation. Used as a placeholder for matrix
# entries (e.g. the diagonal) that must be ignored when searching for a maximum.
MIN_CORRELATION = -1
# Value substituted for empty cells when reading the CSV (see load_data).
DEFAULT = 3
# CSV file that main() loads.
FILE_NAME = 'music.csv'
# When True, main() reorders the matrix so correlated categories sit together.
SORT_MATRIX = True

def main():
	"""Find the correlation between every pair of categories in FILE_NAME.

	The correlations are put into a correlation matrix. When SORT_MATRIX is
	set, rows and columns are reordered so that similar categories are close
	to one another; the matrix is then displayed as a figure.
	"""
	# load data
	categories, category_map = load_data(FILE_NAME)
	n = len(categories)
	# Look each column up once instead of once per (i, j) pair.
	columns = [category_map[name] for name in categories]

	# Pearson correlation is symmetric, so compute each pair once (upper
	# triangle including the diagonal) and mirror it into the lower triangle.
	correlation_matrix = np.zeros((n, n))
	for i in range(n):
		for j in range(i, n):
			corr = calculate_correlation(columns[i], columns[j])
			correlation_matrix[i][j] = corr
			correlation_matrix[j][i] = corr

	# sort and display
	if SORT_MATRIX:
		sorted_indices, categories = sort_categories(correlation_matrix, categories)
		correlation_matrix = sort_matrix(sorted_indices, correlation_matrix)
	normalize_diagonal(correlation_matrix)
	make_figure(correlation_matrix, categories)

def calculate_correlation(x, y):
	"""Return the Pearson correlation coefficient between two samples.

	``x`` and ``y`` are paired observations of two random variables, every
	pair weighted equally. Only the coefficient is returned; the p-value that
	``scipy.stats.pearsonr`` also produces is discarded.
	"""
	return stats.pearsonr(x, y)[0]

def sort_matrix(sorted_indices, correlation_matrix):
	"""Return a new matrix with rows AND columns permuted by ``sorted_indices``.

	Entry (i, j) of the result is
	``correlation_matrix[sorted_indices[i]][sorted_indices[j]]``, so the same
	ordering is applied to both axes. The result is always a freshly
	allocated float array; the input is not modified.
	"""
	order = list(sorted_indices)
	source = np.asarray(correlation_matrix, dtype=float)
	# np.ix_ builds an open mesh, so a single fancy index selects every
	# (row, column) combination of the requested order at once.
	return source[np.ix_(order, order)]

def normalize_diagonal(correlation_matrix):
	"""Overwrite the diagonal, in place, with the largest off-diagonal value.

	The diagonal otherwise holds self-correlations, which would dominate the
	colour scale of the visualization. If there is no usable off-diagonal
	value (a 1x1 matrix, or all off-diagonal entries NaN), the diagonal ends
	up as MIN_CORRELATION.

	Args:
		correlation_matrix: square numpy array; modified in place.
	"""
	if correlation_matrix.shape[0] == 0:
		# Nothing to do, and nanmax would reject an empty array.
		return

	# Sink the diagonal to the floor first so it cannot win the max below.
	np.fill_diagonal(correlation_matrix, MIN_CORRELATION)
	# Compute the max once (not once per diagonal cell); NaN entries are
	# ignored so they cannot poison the result.
	peak = np.nanmax(correlation_matrix)
	np.fill_diagonal(correlation_matrix, peak)
		
def sort_categories(correlation_matrix, categories):
	"""Order categories so that strongly correlated ones end up adjacent.

	The ordering is built greedily: it starts with the pair of distinct
	categories that has the largest correlation, then repeatedly appends the
	not-yet-placed category most correlated with the one placed last.
	NaN correlations are ranked below every real correlation.

	Args:
		correlation_matrix: n x n matrix of pairwise correlations, indexed in
			the same order as ``categories``. It is not modified.
		categories: the n category names.

	Returns:
		A tuple ``(sorted_indices, sorted_names)``: a permutation of
		``range(n)`` as plain ints, and the names in that order.
	"""
	n = len(categories)
	if n < 2:
		# No pair to start from; the identity is the only ordering.
		return list(range(n)), list(categories)

	# Work on a float copy. Mapping NaN to -inf keeps argmax from choosing a
	# NaN cell first.
	scores = np.array(correlation_matrix, dtype=float)
	scores[np.isnan(scores)] = -np.inf

	# Start with the strongest pair of distinct categories. np.nonzero walks
	# the off-diagonal cells in row-major order, so ties resolve to the
	# smallest flat index, just as a plain argmax would.
	rows, cols = np.nonzero(~np.eye(n, dtype=bool))
	best = np.argmax(scores[rows, cols])
	sorted_indices = [int(rows[best]), int(cols[best])]

	# Greedily extend from the most recently placed category, choosing only
	# among categories that have not been placed yet (so no duplicates, even
	# when real correlations tie with any sentinel value).
	placed = np.zeros(n, dtype=bool)
	placed[sorted_indices] = True
	current = sorted_indices[-1]
	for _ in range(n - 2):
		candidates = np.flatnonzero(~placed)
		current = int(candidates[np.argmax(scores[current, candidates])])
		placed[current] = True
		sorted_indices.append(current)

	# Also sort the textual names
	sorted_names = [categories[i] for i in sorted_indices]
	return sorted_indices, sorted_names

def load_data(fileName, default=None):
	"""Read a CSV of integer columns into a mapping of column name -> values.

	The first row is the header; header names are stripped of surrounding
	whitespace. Every following row contributes one value to each column.
	Missing cells (empty or whitespace-only, or absent because the row is
	shorter than the header) are filled with ``default``. Completely blank
	lines are skipped; cells beyond the header width are ignored.

	Args:
		fileName: path of the CSV file to read.
		default: value used for missing cells; ``None`` means the module-level
			DEFAULT.

	Returns:
		``(headers, category_map)``: the list of stripped header names, and a
		dict mapping each header to the list of its column's values.

	Raises:
		ValueError: if the file has no header row, or a non-missing cell is
			not an integer.
	"""
	fill = DEFAULT if default is None else default
	# newline='' is what the csv module requires for correct line handling;
	# the with-block guarantees the file is closed.
	with open(fileName, newline='') as handle:
		reader = csv.reader(handle)
		header_row = next(reader, None)
		if header_row is None:
			raise ValueError('{} is empty; expected a header row'.format(fileName))
		headers = [name.strip() for name in header_row]
		category_map = {name: [] for name in headers}

		for row in reader:
			if not row:
				# csv yields [] for blank lines; they carry no data.
				continue
			for i, name in enumerate(headers):
				# There are a small number of missing values; assume they are
				# missing at random and substitute the fill value.
				cell = row[i] if i < len(row) else ''
				value = int(cell) if cell.strip() else fill
				category_map[name].append(value)
	return headers, category_map
	
def make_figure(data, labels):
	"""Display ``data`` as a colour-mapped matrix with a colour bar.

	Each row and column is labelled with the matching entry of ``labels``
	(x-axis labels rotated 90 degrees), then the figure is shown with
	``plt.show()``.

	Args:
		data: 2-D matrix to plot; presumably one row/column per label.
		labels: sequence of names, one per row/column of ``data``.
	"""
	labels = list(labels)
	positions = list(range(len(labels)))

	fig = plt.figure()
	ax = fig.add_subplot(111)
	image = ax.matshow(data)
	fig.colorbar(image)

	# Pin exactly one tick per cell *before* assigning labels, so every label
	# is bound to its own row/column. Labelling ticks chosen by a default
	# locator leaves the label-to-cell mapping to chance (and matplotlib warns
	# about a FixedFormatter used without a FixedLocator).
	ax.set_xticks(positions)
	ax.set_yticks(positions)
	ax.set_xticklabels(labels, rotation=90)
	ax.set_yticklabels(labels)
	plt.show()

# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
	main()