import json

from cpt import CPT
from const import Const
from inference import Inference
from node import Node

# Assignment ids. Each id is appended to a one-letter variable prefix
# ('e', 't', 'g') to form a key into a student's feature mapping.
ASSN_IDS = [
    '1',
    '2',
    '3',
]

# Value domain of the enjoyment variable (the first argument to the
# enjoyment CPT built by Enjoy).
ENJOY_DOMAIN = [
    'one',
    'two',
    'three',
    'four',
    'five',
]

class Enjoy():
    """Two-node model relating an assignment's grade to the student's enjoyment.

    The network built in ``predictGrade`` has a parentless ``grade`` node and
    an ``enjoy`` node whose only parent is ``grade``. Parameters are fitted in
    ``learnParams`` by adding one observation per student per assignment id
    in ``ASSN_IDS``.

    Every ``data`` argument is a mapping ``sunetId -> features``. Each
    ``features`` mapping is keyed by a one-letter variable prefix followed by
    the assignment id, e.g. ``'e1'`` (enjoyment), ``'g1'`` (grade) and
    ``'t1'`` (time) for assignment ``'1'``.
    """

    def __init__(self):
        # CPT(values, parentDomains): presumably the value domain followed by
        # the list of parent domains -- confirm against cpt.CPT.
        # enjoy has grade as its single parent; grade has no parents.
        self.enjoyCpt = CPT(ENJOY_DOMAIN, [Const.GRADE_DOMAIN])
        self.gradeCpt = CPT(Const.GRADE_DOMAIN, [])

    def learnParams(self, data):
        """Fit both tables from every student in ``data``.

        Adds every student's per-assignment observations (see
        ``addFeatures``), then normalizes each table once.
        """
        for sunetId in data:
            self.addFeatures(sunetId, data)
        # Normalize only after every observation has been added.
        self.enjoyCpt.normalize()
        self.gradeCpt.normalize()

    def predictGrade(self, data, sunetId, assnId):
        """Return the most likely grade for one student's assignment.

        The student's enjoyment value for ``assnId`` is the only observed
        variable. Inference runs over the grade -> enjoy network, and the
        grade with the highest score in the result is returned
        (see ``getMle``).
        """
        observed = {'enjoy': self.getEnjoy(data, sunetId, assnId)}

        gradeNode = Node('grade', self.gradeCpt, [])
        enjoyNode = Node('enjoy', self.enjoyCpt, ['grade'])
        network = [gradeNode, enjoyNode]
        assignments = Inference.infer(network, observed)
        return self.getMle(assignments)

    def getMle(self, assignments):
        """Return the key of ``assignments['grade']`` with the largest value.

        ``assignments['grade']`` must support iteration over keys and
        ``[key]`` lookup. Its values are presumably probabilities; confirm
        against ``Inference.infer``.

        Ties go to the first key encountered in iteration order. Returns
        ``None`` when the distribution is empty.
        """
        probDist = assignments['grade']
        bestGrade = None
        bestProb = None
        found = False
        for grade in probDist:
            prob = probDist[grade]
            # Track "seen a candidate yet" separately from bestGrade. A grade
            # key that is itself None must not be mistaken for "nothing
            # chosen yet", otherwise the next key would always replace it.
            if not found or prob > bestProb:
                bestGrade = grade
                bestProb = prob
                found = True
        return bestGrade

    def addFeatures(self, sunetId, data):
        """Add one (grade, enjoy) observation per assignment for ``sunetId``.

        Each enjoyment value is recorded conditioned on that assignment's
        grade. Each grade is also recorded in the parentless grade table.
        """
        for assnId in ASSN_IDS:
            grade = self.getGrade(data, sunetId, assnId)
            enjoy = self.getEnjoy(data, sunetId, assnId)
            self.enjoyCpt.addObservation(enjoy, [grade])
            self.gradeCpt.addObservation(grade, [])

    def getVar(self, data, sunetId, assnId, varId):
        """Return ``data[sunetId][varId + str(assnId)]``.

        ``assnId`` goes through ``str`` so an integer id (``1``) reads the
        same key as the string id (``'1'``).

        Raises:
            KeyError: when ``data`` is a dict and the student or the feature
                key is missing.
        """
        features = data[sunetId]
        return features[varId + str(assnId)]

    def getEnjoy(self, data, sunetId, assnId):
        """Return the enjoyment feature (``'e' + assnId``) for a student."""
        return self.getVar(data, sunetId, assnId, 'e')

    def getTime(self, data, sunetId, assnId):
        """Return the time feature (``'t' + assnId``) for a student."""
        return self.getVar(data, sunetId, assnId, 't')

    def getGrade(self, data, sunetId, assnId):
        """Return the grade feature (``'g' + assnId``) for a student."""
        return self.getVar(data, sunetId, assnId, 'g')

