# Inference
# ---------
# This file takes data of pre/post from different
# countries and tries to jointly infer the country
# ability and the task parameters. It uses the Adam
# optimizer to choose params which minimize MSE. The
# implementation is in pytorch, but so far we don't
# use any neural networks...

# this is the bayesian program model
from firstSimulator import sampleStudent

import torch.nn as nn
from torch.optim import Adam
from torch.optim import SGD
import torch
import torch.nn.functional as F
import math
import pickle
import numpy as np
import csv

# Number of country slots: the length of the theta vector in
# DeepLearningModel and the column count of the one-hot matrix
# built by make_one_hot. Each record's countryId is used as a
# column index into that matrix.
N_COUNTRIES = 329

def main():
  """Load the NWEA records, build the model, and run the fit."""
  # Step 1: fetch the student records (this entry point loads
  # them from the csv rather than simulating).
  students = load_NWEA_data()
  # Step 2: construct the pytorch module whose parameters we fit.
  net = DeepLearningModel()
  # Step 3: hand both to the optimizer loop.
  optimize(net, students)

# Sometimes it's more fun to load data...
def loadData(path='data.pkl'):
  """Unpickle and return a previously saved dataset.

  Args:
    path: file to read; defaults to 'data.pkl' in the working dir.

  Returns:
    Whatever object was pickled into the file.
  """
  # NOTE(security): pickle.load can execute arbitrary code. Only
  # point this at files you produced yourself.
  # The with-block closes the handle even if unpickling raises.
  with open(path, 'rb') as pickled:
    return pickle.load(pickled)

# The model being optimized. It is expressed as a pytorch module
# so autograd can differentiate the prediction w.r.t. every
# parameter.
class DeepLearningModel(nn.Module):
  """Parametric pre -> post score model with one theta per country."""

  def __init__(self):
    super().__init__() # required by nn.Module

    # One learnable value per country (the country's learning
    # rate), all starting at 1.
    self.theta = nn.Parameter(torch.ones(N_COUNTRIES))

    # Global parameters shared by every student. The writeup
    # calls these phis; here they get descriptive names instead.
    # NOTE: floor_param is registered but forward() never reads it.
    self.offset_param = nn.Parameter(torch.full((1,), 150.0))
    self.scale_param = nn.Parameter(torch.full((1,), 1500.0))
    self.amplitude_param = nn.Parameter(torch.full((1,), 9.0))
    self.floor_param = nn.Parameter(torch.full((1,), -1.0))

  def forward(self, alpha, country_one_hot):
    """Predict beta for every student in the batch.

    alpha is the n x 1 matrix of starting scores and
    country_one_hot is the n x N_COUNTRIES selector matrix.
    Computes
      beta = alpha + theta * amplitude * exp(-(alpha - offset)^2 / scale)
    and returns it as an n x 1 matrix.
    """
    # Multiplying the one-hot rows by theta picks out each
    # student's country theta differentiably; unsqueeze turns the
    # length-n result into an n x 1 column.
    theta_per_student = (country_one_hot @ self.theta).unsqueeze(1)

    # Gaussian-shaped bump centred on offset_param.
    squared_distance = (alpha - self.offset_param).pow(2)
    bump = torch.exp(-squared_distance / self.scale_param)

    # Predicted gain, added on top of the starting score.
    gain = theta_per_student * self.amplitude_param * bump
    return alpha + gain

# Actually perform optimization
def optimize(model, data, max_iters=None):
  """Fit the model to the student records by minimizing MSE with Adam.

  Args:
    model: module called as model(alpha, country_one_hot) that
      exposes a `theta` parameter.
    data: sequence of records with 'alpha', 'beta' and 'countryId'.
    max_iters: optional cap on the number of gradient steps. The
      default (None) keeps looping until the loss stops being
      finite, as before.

  Returns:
    A copy of theta taken at the lowest loss seen, or None if no
    finite loss was ever observed.
  """
  # I heart adam... learning rate is arbitrary but seems to work
  optimizer = Adam(model.parameters(), lr=0.050)

  # turn the data into matrices (students are rows)
  alpha = make_tensor(data, 'alpha')
  beta = make_tensor(data, 'beta')
  country_one_hot = make_one_hot(data, 'countryId')

  # the argmin found so far
  bestParams = None
  minLoss = float("inf")

  # gradient descent.
  steps_taken = 0
  while max_iters is None or steps_taken < max_iters:
    steps_taken += 1

    # forward pass. Make beta predictions and score them.
    pred = model(alpha, country_one_hot)
    loss = F.mse_loss(pred, beta)
    loss_value = loss.item()

    # A nan/inf loss means the fit has diverged. Stop before the
    # step so the parameters are not overwritten with garbage.
    if not math.isfinite(loss_value): break

    # Record the argmin BEFORE stepping: this loss belongs to the
    # theta that produced `pred`. optimizer.step() updates theta in
    # place, so we store a detached copy; a bare reference would
    # silently track the current values instead of the best ones.
    if loss_value < minLoss:
      bestParams = model.theta.detach().clone()
      minLoss = loss_value

    # backwards pass. Calculate gradients and take a step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # a little console output. Loss doesn't always go down, but
    # if it goes too long without going down you either have a
    # bug or you are done...
    output(loss_value, bestParams, model.theta)

  return bestParams


########### Helper methods #########

# Builds the n x m selector matrix (n students, m = N_COUNTRIES).
# A student's row is all zeros apart from a single 1 in the
# column of their country. Multiplying this matrix by the theta
# vector picks each student's country theta in a way that stays
# differentiable.
def make_one_hot(data, key):
  """Return the one-hot country matrix for the records in data.

  key names the field of each record that holds the country
  index; the value is cast to int and used as the column.
  """
  num_students = len(data)
  encoded = torch.zeros(num_students, N_COUNTRIES)
  for row, datum in enumerate(data):
    column = int(datum[key])
    encoded[row, column] = 1
  return encoded

# Console output after each epoch.
def output(loss, bestParams, curr_params):
  """Print the best parameters seen so far, then the current loss.

  Args:
    loss: current loss as a plain float.
    bestParams: iterable of elements that support .item() (for
      example a 1-D tensor), or None when no best has been
      recorded yet.
    curr_params: accepted for interface compatibility; not printed.
  """
  # Without this guard, iterating None would raise a TypeError.
  if bestParams is not None:
    for value in bestParams:
      print(value.item())
  print('---')
  print('loss = {:.4f} '.format(loss))

# Pull one field out of every student record and stack the
# values into a float matrix.
# Ex: make_tensor(data, 'beta') gathers every student's beta.
def make_tensor(data, key):
  """Return an n x 1 float tensor of datum[key], n = len(data)."""
  column = [record[key] for record in data]
  as_vector = torch.tensor(column)
  # The unsqueeze is essential: the result must be an n x 1
  # matrix, not a flat vector of length n.
  return as_vector.unsqueeze(1).float()

# Let's make some fake countries! One row per country, given as
# (learningRate, preMu, preStd). These dicts are what the
# simulate* functions hand to sampleStudent.
COUNTRIES = [
  {'learningRate': rate, 'preMu': mu, 'preStd': std}
  for rate, mu, std in (
    (2.0, -1, 2),
    (2.0, 1, 2),
    (1.0, -1, 2),
    (1.0, 2, 2),
    (1.5, 0, 1),
    (1.5, 0, 3),
    (1.75, -1, 1),
    (1.75, 1, 2),
  )
]

# task1: start easy with a beginner-happy learning curve.
def simulateGoodForBegginer():
  """Simulate every country in COUNTRIES on the beginner task.

  The records are also pickled to 'goodForBeg.pkl'.

  Returns:
    The list of simulated student records, which addSimulations
    fills in one country at a time.
  """
  task = {
    'startDifficulty':-5,
    'zoneSigma':1,
    'beginnerWeight':1
  }
  simulations = []
  # The position in COUNTRIES doubles as the countryId label.
  for countryId, country in enumerate(COUNTRIES):
    addSimulations(simulations, countryId, country, task)
  # The with-block flushes and closes the file even if the dump
  # raises, rather than leaving that to garbage collection.
  with open('goodForBeg.pkl', 'wb') as out_file:
    pickle.dump(simulations, out_file)
  return simulations

# task2: start medium with a slightly less beginner-happy
# learning curve.
def simulateGoodForMedium():
  """Simulate every country in COUNTRIES on the medium task.

  The records are also pickled to 'goodForMedium.pkl'.

  Returns:
    The list of simulated student records, which addSimulations
    fills in one country at a time.
  """
  task = {
    'startDifficulty':0,
    'zoneSigma':1,
    'beginnerWeight':0.9
  }
  simulations = []
  # The position in COUNTRIES doubles as the countryId label.
  for countryId, country in enumerate(COUNTRIES):
    addSimulations(simulations, countryId, country, task)
  # The with-block flushes and closes the file even if the dump
  # raises, rather than leaving that to garbage collection.
  with open('goodForMedium.pkl', 'wb') as out_file:
    pickle.dump(simulations, out_file)
  return simulations

# Simulate students (2K by default) for one country on one task
# and append the results to simulations. Each datapoint is
# labelled with countryId, which is later used as the index into
# the country theta... so it had better be in the range
# 0 -> numCountries - 1.
def addSimulations(simulations, countryId, country, task, numStudents=2000):
  """Append numStudents sampled records for one country.

  Args:
    simulations: list that is extended in place.
    countryId: label stored on every record that gets appended.
    country: country description passed through to sampleStudent.
    task: task description passed through to sampleStudent.
    numStudents: how many students to sample (default 2000).

  Prints the mean of beta - alpha over the sampled students. With
  numStudents=0 that mean is taken over an empty list (numpy
  yields nan).
  """
  learnings = []
  for _ in range(numStudents):
    alpha, beta = sampleStudent(country, task)
    simulations.append({
      'countryId':countryId, 
      'alpha':alpha, 
      'beta':beta
    })
    learnings.append(beta-alpha)
  print(np.mean(learnings))


def load_NWEA_data(path='nwea_100.csv'):
  """Read student records from a header-less csv file.

  Each non-empty row is read as: column 0 -> alpha (float),
  column 1 -> beta (float), column 2 -> countryId (int).

  Args:
    path: csv file to read; defaults to 'nwea_100.csv'.

  Returns:
    A list of dicts with keys 'countryId', 'alpha' and 'beta'.

  Raises:
    ValueError / IndexError: if a non-empty row is malformed.
  """
  data = []
  # newline='' is what the csv module asks for on files it reads;
  # the with-block closes the handle even if a row fails to parse.
  with open(path, newline='') as csv_file:
    for row in csv.reader(csv_file):
      # csv.reader yields [] for blank lines (e.g. a trailing one)
      if not row:
        continue
      data.append({
        'countryId': int(row[2]),
        'alpha': float(row[0]),
        'beta': float(row[1]),
      })
  return data

# Script entry point: run main() only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
  main()