import numpy

class CPT():
    # Class: CPT
    # ----------
    # A conditional probability table over discrete variables. The table is
    # a numpy array with one axis per parent (in the order of parentDomains)
    # followed by a final axis for the node itself, so table[p1][p2]...[v]
    # holds the entry for node value v given parent values p1, p2, ...
    # Every cell starts at 1 and observations add to it, so after
    # normalize() each row is an add-one-smoothed distribution.

    # Function Constructor
    # --------------------
    # Create a cpt over discrete variables that can take on the given
    # domains. nodeDomain is a sequence of the node's possible values;
    # parentDomains is a sequence of such sequences, one per parent. Values
    # are mapped to table positions with .index(), so each domain must
    # support it. Important: the order of parents is expected to stay the
    # same throughout the use of this table; only the NUMBER of parent
    # values is checked (in addObservation), never their order.
    def __init__(self, nodeDomain, parentDomains):
        self.nodeDomain = nodeDomain
        self.parentDomains = parentDomains
        # One axis per parent, then one final axis for the node's values.
        self.tableShape = [len(domain) for domain in parentDomains]
        self.tableShape.append(len(nodeDomain))
        # Start every count at 1 so no combination ends up with zero
        # probability after normalization.
        self.table = numpy.ones(self.tableShape)

    # Function: Add Observation
    # -------------------------
    # Record that nodeValue was seen together with parentValues by adding 1
    # to the matching cell. parentValues must hold exactly one value per
    # parent, in the same order as parentDomains.
    # Raises ValueError if the number of parent values differs from the
    # number of parents, or if any value is not in its domain.
    def addObservation(self, nodeValue, parentValues):
        # Without this check a too-short list would index a whole sub-array
        # and the += below would silently bump many cells at once.
        if len(parentValues) != len(self.parentDomains):
            raise ValueError(
                'expected %d parent values, got %d'
                % (len(self.parentDomains), len(parentValues)))
        # Translate each value into its position within its own domain.
        index = tuple(domain.index(value)
                      for domain, value in zip(self.parentDomains, parentValues))
        index += (self.nodeDomain.index(nodeValue),)
        self.table[index] += 1

    # Function: Normalize
    # -------------------
    # Rescale the table in place so that, for every combination of parent
    # values, the entries over the node's values sum to 1. Note that
    # observations added afterwards are added to these probabilities, not
    # to the original counts.
    def normalize(self):
        self._recursiveNormalize(self.table, [])

    # Function: Get Domain
    # --------------------
    # Return the node's domain (the same object passed to the constructor).
    def getDomain(self):
        return self.nodeDomain

    # Function: Get Table
    # -------------------
    # Return the underlying numpy table itself (not a copy); mutating the
    # result mutates this CPT.
    def getTable(self):
        return self.table

    # Function: Recursive Normalize
    # -----------------------------
    # Walk the parent axes of table depth-first; parentValues holds the
    # parent values fixed so far. Once every parent axis is fixed, what
    # remains is a single row over the node's values, normalized in place.
    def _recursiveNormalize(self, table, parentValues):
        parentIndex = len(parentValues)
        if parentIndex == len(self.parentDomains):
            self._normalizeRow(table, parentValues)
            return
        domain = self.parentDomains[parentIndex]
        # enumerate instead of domain.index(value): index() is a linear
        # search and always returns the first match, so a repeated value
        # would leave its later slot un-normalized.
        for index, value in enumerate(domain):
            # Integer indexing of an ndarray yields a view, so the recursive
            # call normalizes self.table in place.
            self._recursiveNormalize(table[index], parentValues + [value])

    # Function: Normalize Row
    # -----------------------
    # Divide a row of counts (a numpy view into the table), in place, by its
    # total so it sums to 1. A zero total yields nan/inf with a numpy
    # RuntimeWarning. parentValues is accepted but not used here.
    def _normalizeRow(self, row, parentValues):
        row /= row.sum()
