import json

from cpt import CPT
from const import Const
from inference import Inference
from node import Node



class Basic:
    """Small Bayesian network in which ``grade`` is the sole parent of
    ``enjoy`` and ``time``.

    ``learnParams`` fills the three conditional probability tables (CPTs)
    from per-student feature dicts. ``predictGrade`` then runs inference
    with a student's enjoyment and time for one assignment as evidence.
    It returns the grade with the highest posterior value.
    """

    def __init__(self):
        # CPT for enjoy given grade, time given grade, and the grade prior.
        self.enjoyCpt = CPT(Const.ENJOY_DOMAIN, [Const.GRADE_DOMAIN])
        self.timeCpt = CPT(Const.TIME_DOMAIN, [Const.GRADE_DOMAIN])
        self.gradeCpt = CPT(Const.GRADE_DOMAIN, [])
        # Diagnostic counters updated by addFeatures. The (misspelled)
        # attribute names are kept because outside code may read them.
        self.countCheckPlust = 0
        self.totalGradez = 0

    def learnParams(self, data):
        """Fit all three CPTs from ``data``.

        Args:
            data: mapping of sunetId -> feature dict (see ``getVar`` for the
                key format). Every assignment in ``Const.ASSN_IDS`` is read
                for every student.
        """
        for sunetId in data:
            self.addFeatures(sunetId, data)
        # Convert the accumulated observations into distributions
        # (presumably count -> probability; see CPT.normalize).
        self.enjoyCpt.normalize()
        self.gradeCpt.normalize()
        self.timeCpt.normalize()

    def predictGrade(self, data, sunetId, assnId):
        """Return the most likely grade for one student's assignment.

        Both observable children of ``grade`` (enjoyment and time) are
        passed to inference as evidence.

        Args:
            data: mapping of sunetId -> feature dict (see ``getVar``).
            sunetId: student key into ``data``.
            assnId: assignment identifier used to build the feature keys.

        Returns:
            The grade value chosen by ``getMle``.

        Raises:
            KeyError: if the student or one of the needed features is missing.
        """
        enjoy = self.getEnjoy(data, sunetId, assnId)
        timeSpent = self.getTime(data, sunetId, assnId)

        # The 'time' node is part of the network, so its observed value is
        # evidence about the grade and must be conditioned on as well.
        observed = {
            'enjoy': enjoy,
            'time': timeSpent,
        }

        network = [
            Node('grade', self.gradeCpt, []),
            Node('enjoy', self.enjoyCpt, ['grade']),
            Node('time', self.timeCpt, ['grade']),
        ]
        assignments = Inference.infer(network, observed)
        return self.getMle(assignments)

    def getMle(self, assignments):
        """Return the key of ``assignments['grade']`` with the largest value.

        Ties resolve to the first maximal key in iteration order. An empty
        distribution yields ``None``.
        """
        probDist = assignments['grade']
        argMax = None
        maxValue = None
        for key in probDist:
            prob = probDist[key]
            # Test the best score, not the best key, so a legitimate None key
            # cannot reset the running maximum.
            if maxValue is None or prob > maxValue:
                argMax = key
                maxValue = prob
        return argMax

    def addFeatures(self, sunetId, data):
        """Record one student's (grade, enjoy, time) triple per assignment.

        For every id in ``Const.ASSN_IDS``, add an observation to each CPT
        and update the diagnostic counters.
        """
        for assnId in Const.ASSN_IDS:
            grade = self.getGrade(data, sunetId, assnId)
            if grade == 'checkPlus':
                self.countCheckPlust += 1
            enjoy = self.getEnjoy(data, sunetId, assnId)
            time = self.getTime(data, sunetId, assnId)
            self.enjoyCpt.addObservation(enjoy, [grade])
            self.timeCpt.addObservation(time, [grade])
            self.gradeCpt.addObservation(grade, [])
            self.totalGradez += 1

    def getVar(self, data, sunetId, assnId, varId):
        """Return feature ``varId + assnId`` for ``sunetId`` from ``data``.

        The key is the one-letter variable prefix followed by the assignment
        id, for example ``'e' + '1' -> 'e1'``. ``assnId`` is converted with
        ``str`` so integer ids work as well as string ids.

        Raises:
            KeyError: if the student or the feature key is missing.
        """
        features = data[sunetId]
        var = varId + str(assnId)
        return features[var]

    def getEnjoy(self, data, sunetId, assnId):
        """Return the enjoyment feature (prefix ``'e'``) for an assignment."""
        return self.getVar(data, sunetId, assnId, 'e')

    def getTime(self, data, sunetId, assnId):
        """Return the time feature (prefix ``'t'``) for an assignment."""
        return self.getVar(data, sunetId, assnId, 't')

    def getGrade(self, data, sunetId, assnId):
        """Return the grade feature (prefix ``'g'``) for an assignment."""
        return self.getVar(data, sunetId, assnId, 'g')

