#include <iostream>
#include "string.h"
#include "strlib.h"
#include "error.h"
#include "vector.h"
#include "math.h"
#include "tokenscanner.h"
#include "map.h"
#include "queue.h"
using namespace std;

// One node of the binary tree built by buildTree.
struct node{
    node* left;
    node* right;
    node* parent;
    double time;         // assigned by loadTimes
    bool isLeaf;         // true for tip nodes (set by buildTree from the input phrase)
    bool isInSubtree;    // updated by markSubtree (nodeCheck / finishMarking)
    Vector<int> initial; // leaves only: initial sample counts, read as [0] = S, [1] = C
    Vector<int> bounds;  // lineage counts below the node used when marking the subtree and
                         // setting bounds; nodeCheck fills indices 0 (S) and 1 (C)
    // Give every scalar field a defined value so a node created with `new node`
    // never exposes indeterminate pointers or flags (reading an uninitialised
    // bool is undefined behaviour).
    node() : left(NULL), right(NULL), parent(NULL), time(0), isLeaf(false), isInSubtree(false) {}
};

// Returns the next sub-phrase of a tree string, starting at index `pos`.
// A '(' at pos yields everything up to and including its matching ')';
// any other character yields just that character. On return, `pos` is the
// index immediately after the extracted phrase.
// Precondition: parentheses starting at pos are balanced (the scan does not
// stop at the end of the string).
string nextPhrase(string phrase, int& pos){
    int depth = 0;
    int end = pos;
    for(;; end++){
        char ch = phrase[end];
        if(ch == '(') depth++;
        else if(ch == ')') depth--;
        if(depth == 0) break;
    }
    string result = phrase.substr(pos, end - pos + 1);
    pos = end + 1;
    return result;
}


// Allocates a node with every field in a known state: no children, no
// parent, time 0, not a leaf, not in the S subtree.
static node* makeTreeNode(){
    node* fresh = new node;
    fresh->left = NULL;
    fresh->right = NULL;
    fresh->parent = NULL;
    fresh->time = 0;
    fresh->isLeaf = false;
    fresh->isInSubtree = false;
    return fresh;
}

// Builds a tree from an input line such as "((ab)c)".
// The two children are the phrases returned by nextPhrase starting at index 1.
// A child phrase that begins with '(' is built recursively. Any other phrase
// becomes a leaf node (isLeaf == true).
// Every node is fully initialised and linked to its parent. The returned
// root has parent == NULL.
node* buildTree(string phrase){
    int pos = 1;
    string phrase1 = nextPhrase(phrase, pos); //nextPhrase updates pos
    string phrase2 = nextPhrase(phrase, pos);
    node* currNode = makeTreeNode();

    // Allocate a leaf only when the child really is a leaf. This avoids
    // leaking a placeholder node whenever the child is a subtree.
    if(!phrase1.empty() && phrase1[0] == '('){
        currNode->left = buildTree(phrase1);
    }
    else{
        currNode->left = makeTreeNode();
        currNode->left->isLeaf = true;
    }
    if(!phrase2.empty() && phrase2[0] == '('){
        currNode->right = buildTree(phrase2);
    }
    else{
        currNode->right = makeTreeNode();
        currNode->right->isLeaf = true;
    }

    currNode->left->parent = currNode;
    currNode->right->parent = currNode;
    return currNode;
}

// Debug helper: prints the tree in in-order (left subtree, node, right
// subtree), one line per node. Each line holds the address, time, isLeaf and
// isInSubtree, followed directly by `initial` with no separator before it.
void inorder(node* root){
    if(root != NULL){
        inorder(root->left);
        const node& current = *root;
        cout << root << " " << current.time << " " << current.isLeaf << " " << current.isInSubtree << current.initial << endl;
        inorder(current.right);
    }
}

// Gives currNode and every node below it an empty entry in `calcs`, visiting
// the nodes in in-order. Any existing entry for those nodes is replaced by an
// empty map.
void getNodes(Map<node*, Map<string, long double> >& calcs, node* currNode){
    if(currNode == NULL) return;
    getNodes(calcs, currNode->left);
    calcs.put(currNode, Map<string, long double>());
    getNodes(calcs, currNode->right);
}

// Calculates the upper bounds for a particular node (the root). Call it first
// with the node in question as both arguments: getBounds(n, n).
// Walks the subtree below currNode and adds each leaf's `initial` counts,
// type by type, into root->bounds.
// root->bounds is extended with zeros to three slots (indices 0, 1, 2) if it
// is shorter. Existing entries are added to, not reset, so a second call on
// the same root counts the leaves again.
void getBounds(node* currNode, node* root){
    if(currNode == NULL) return;
    getBounds(currNode->left, root);
    getBounds(currNode->right, root);
    if(!currNode->isLeaf) return;
    for(int i = 0; i < 3; i++){
        // Create slot i before adding to it. Previously the first leaf only
        // created the slots, and its own counts were dropped.
        if(root->bounds.size() <= i) root->bounds.add(0);
        // Leaves may carry fewer than three counts (S and C only). A missing
        // type contributes nothing instead of causing an out-of-range get().
        if(i < currNode->initial.size()){
            root->bounds.set(i, root->bounds.get(i) + currNode->initial.get(i));
        }
    }
}

// Returns the sum of the entries of initialS, i.e. the total number of input
// S lineages. The S subtree is marked from this total before the loading
// process consumes the initialS vector.
int getTotalS(Vector<int> initialS){
    int sum = 0;
    int idx = initialS.size();
    while(idx > 0){
        idx--;
        sum += initialS.get(idx);
    }
    return sum;
}

// Number of lineages of type `type` (0 = S, 1 = C, 2 = M) below currNode.
// A leaf reports its own initial count. An internal node adds its two
// children's bounds, so both children must already have bounds computed
// (post-order). The 64 honours Tigers LHP Duane Below.
int numBelow64(node* currNode, int type){
    return currNode->isLeaf
        ? currNode->initial.get(type)
        : currNode->left->bounds.get(type) + currNode->right->bounds.get(type);
}

// Post-order pass that stores, in bounds[0] and bounds[1], the number of
// S and C lineages below each node (via numBelow64). It marks a node as in the
// S subtree when it has strictly between 0 and totalS S lineages below it.
// This marks at least a minimally-defining subset of the S subtree; nodes
// are never unmarked here.
void nodeCheck(node* currNode, int totalS){
    if(currNode == NULL) return;
    nodeCheck(currNode->left, totalS);
    nodeCheck(currNode->right, totalS);
    for(int type = 0; type < 2; type++){
        int below = numBelow64(currNode, type);
        if(type < currNode->bounds.size()) currNode->bounds.set(type, below);
        else currNode->bounds.add(below);
    }
    int sBelow = currNode->bounds.get(0);
    if(sBelow > 0 && sBelow < totalS) currNode->isInSubtree = true;
}

// Marks currNode and every node below it as part of the S subtree
// (pre-order traversal). Nodes that are already marked stay marked.
void finishMarking(node* currNode){
    if(currNode == NULL) return;
    currNode->isInSubtree = true;
    finishMarking(currNode->left);
    finishMarking(currNode->right);
}

// Breadth-first search from `root` for the first node whose left and right
// children both exist and are both marked isInSubtree; that node is the root
// of the S subtree. Returns NULL if no such node exists, which happens when
// the subtree is just a leaf and nothing should be marked.
node* findSubRoot(node* root){
    Queue<node*> pending;
    pending.enqueue(root);
    while(!pending.isEmpty()){
        node* candidate = pending.dequeue();
        node* leftChild = candidate->left;
        node* rightChild = candidate->right;
        bool bothMarked = leftChild != NULL && rightChild != NULL
                          && leftChild->isInSubtree && rightChild->isInSubtree;
        if(bothMarked) return candidate;
        if(leftChild != NULL) pending.enqueue(leftChild);
        if(rightChild != NULL) pending.enqueue(rightChild);
    }
    return NULL;
}

// Marks the S subtree below `root`, where totalS is the total number of S
// lineages in the tree.
// Algorithm:
//  1. nodeCheck counts, for every node, the S lineages among its leaves (the
//     sum of its two children's counts). A node is in the subtree when this
//     count is strictly greater than 0 and strictly less than totalS.
//     A count equal to totalS means the node is not marked. The first such
//     node is technically the subtree's root, but the root itself is
//     deliberately left unmarked. A count of 0 is only provisionally outside.
//  2. findSubRoot locates the node whose two children are both marked; that
//     node is the root of the S subtree.
//  3. finishMarking marks everything below that root, which picks up the
//     provisional zero-count nodes that belong to the subtree. The root is
//     not marked.
void markSubtree(node* root, int totalS){
    nodeCheck(root, totalS);
    node* subRoot = findSubRoot(root);
    if(subRoot == NULL) return;
    finishMarking(subRoot->left);
    finishMarking(subRoot->right);
}

// Parses a '/'-terminated list of integers such as "3/0/12/" into a Vector
// (here: {3, 0, 12}). Each value must be followed by '/'. Text after the last
// '/' is ignored. Each field is passed to stringToInteger as-is.
Vector<int> extractInitialSamples(string initial){
    Vector<int> toReturn;
    // The field buffer must persist across iterations. When it was declared
    // inside the loop it was reset on every character, so every '/' parsed
    // an empty string.
    string currstring;
    for(size_t i = 0; i < initial.size(); i++){
        char currchar = initial[i];
        if(currchar == '/'){
            toReturn.add(stringToInteger(currstring));
            currstring.clear();
        }
        else currstring += currchar;
    }
    return toReturn;
}

// Parses a '/'-terminated list of real numbers such as "0.5/1.25/" into a
// Vector (here: {0.5, 1.25}). Each value must be followed by '/'. Text after
// the last '/' is ignored. Each field is passed to stringToReal as-is.
Vector<double> extractTimes(string times){
    Vector<double> toReturn;
    // Field buffer kept across iterations (it used to be reset per character).
    string currstring;
    for(size_t i = 0; i < times.size(); i++){
        char currchar = times[i];
        if(currchar == '/'){
            // Keep the full real value; it was previously truncated through an int.
            double value = stringToReal(currstring);
            toReturn.add(value);
            currstring.clear();
        }
        else currstring += currchar;
    }
    return toReturn;
}

// Assigns a time to every node, visiting in in-order and taking values from
// the back of `times`: the in-order-first node gets the last element. Each
// value used is removed from `times`. `times` needs one entry per node.
void loadTimes(node* root, Vector<double>& times){
    if(root == NULL) return;
    loadTimes(root->left, times);
    int last = times.size() - 1;
    root->time = times.get(last);
    times.remove(last);
    loadTimes(root->right, times);
}

// Appends one value to the `initial` vector of every leaf, visiting leaves in
// in-order and taking values from the back of `initial` (the argument). Each
// value used is removed from it. Internal nodes are skipped.
void loadInitialSamples(node* root, Vector<int>& initial){
    if(root == NULL) return;
    loadInitialSamples(root->left, initial);
    if(root->isLeaf){
        int last = initial.size() - 1;
        root->initial.add(initial.get(last));
        initial.remove(last);
    }
    loadInitialSamples(root->right, initial);
}

// Debug helper: prints the tree in pre-order (node, left subtree, right
// subtree), one line per node. Each line holds the address, time, isLeaf and
// isInSubtree, followed directly by `initial`.
void pre(node* root){
    if(root == NULL) return;
    const node& current = *root;
    cout << root << " " << current.time << " " << current.isLeaf << " " << current.isInSubtree << current.initial << endl;
    pre(current.left);
    pre(current.right);
}

// Debug helper: prints the tree in post-order (left subtree, right subtree,
// node), one line per node. Each line holds the address, time, isLeaf and
// isInSubtree, followed directly by `initial`.
void post(node* root){
    if(root != NULL){
        post(root->left);
        post(root->right);
        const node& current = *root;
        cout << root << " " << current.time << " " << current.isLeaf << " " << current.isInSubtree << current.initial << endl;
    }
}

// Reads integer tokens from `scanner` into `initial` until a ";" token.
// "," tokens are skipped. Any other token whose first character is
// alphanumeric is converted with stringToInteger and appended. Every token
// read, and the vector after each non-"," token, is echoed to cout as debug
// output.
// If the scanner yields an empty token before ";" (presumably end of input),
// reading stops and the values collected so far are kept.
// NOTE(review): `scanner` is taken by value, so the caller's scanner is not
// advanced by this call. Confirm that this is intended.
void readinitials(TokenScanner scanner, Vector<int>& initial){
    string next;
    while(true){
        next = scanner.nextToken();
        cout << next << endl;
        // No branch below matches an empty token, so without this check the
        // loop would never terminate once the input is exhausted.
        if(next.empty()) return;
        if(next == ",") continue;
        if(next == ";") return;
        // Cast to unsigned char: passing a negative char to isalnum is undefined.
        else if(isalnum(static_cast<unsigned char>(next[0]))) initial.add(stringToInteger(next));
        cout << initial << endl;
    }
}

// Appends every character of timesstr to `times`, parsing each character on
// its own with stringToInteger. Each time must therefore be a single digit.
void getTimes(string timesstr, Vector<double>& times){
    for(size_t idx = 0; idx < timesstr.size(); ++idx){
        int digit = stringToInteger(string(1, timesstr[idx]));
        times.add(digit);
    }
}

// n! as a signed 64-bit integer; returns 1 for n <= 0.
// NOTE: 20! is the largest factorial that fits in long long. For n > 20 the
// multiplication overflows, which is undefined behaviour for signed integers.
long long int fact(int n){
    long long int product = 1;
    int factor = 2;
    while(factor <= n){
        product *= factor;
        ++factor;
    }
    return product;
}

// Falling factorial n (n-1) ... (n-k+1) = n! / (n-k)!, for k >= 0.
// Computed as a direct product, so it stays exact whenever the result fits in
// long long. Going through fact(n) / fact(n-k) overflowed as soon as n > 20.
// Returns 1 for k == 0 and 0 when k > n >= 0, because one factor is then zero.
long long int fallingFact(int n, int k){
    if(n >= 0 && k > n) return 0; // avoids computing n! before reaching the zero factor
    long long int value = 1;
    for(int i = 0; i < k; i++){
        value *= (n - i);
    }
    return value;
}

// Rising factorial n (n+1) ... (n+k-1) = (n+k-1)! / (n-1)!, for k >= 0.
// Computed as a direct product, so it no longer overflows when n + k - 1 > 20
// but the result still fits in long long. It is also correct for n == 0: the
// result is 0 for k >= 1 instead of (k-1)!.
// Returns 1 for k == 0.
long long int risingFact(int n, int k){
    long long int value = 1;
    // Stop once a zero factor has been hit; the product cannot change again.
    for(int i = 0; i < k && value != 0; i++){
        value *= (n + i);
    }
    return value;
}


// Binomial coefficient C(a, b) in floating point, as the product over
// step = 0..b-1 of (a - step) / (b - step), written as 1 + (a - b) / (b - step).
// Returns 1 for b <= 0. Returns 0 when 0 <= a < b, because one factor is 0.
double biCo(int a, int b){
    double coefficient = 1;
    for(int step = 0; step < b; ++step){
        coefficient *= 1.0 + static_cast<double>(a - b) / (b - step);
    }
    return coefficient;
}

// Returns g(n, j, T): an alternating series over k = j..n of
// exp(-k(k-1)T/2) * (-1)^(k-j) * (2k-1) * [ratio of biCo coefficients].
// NOTE(review): this has the shape of Tavare's (1984) probability that n
// lineages have coalesced to exactly j lineages after time T. Confirm that the
// biCo rewriting matches the intended formula.
// Short-circuits:
//   * j > n, n == 0 or j == 0 -> 0 (so g(0, 0, T) is 0 as well).
//   * T > 5000                -> 1 if j == 1, else 0 (presumably "fully coalesced").
//   * (T >= 0.1, n >= 90, j >= 50) or (T >= 1, n >= 20, j >= 10) -> 0.
//     Presumably these are regions where the value is negligible and the
//     series is numerically unstable. TODO confirm the thresholds.
// A final sum below 1e-6 (including small negative values caused by
// cancellation in the alternating series) is clamped to 0.
long double gFunction(int n, int j, double T){
    long double value = 0;
    if(j > n) return 0;
    else if(n == 0 || j == 0) return 0;
    else if(T > 5000) {
        if(j == 1) return 1;
        else return 0;
    }
    else if(T >= 0.1 && n >= 90 && j >= 50) return 0;
    else if(T >= 1 && n >= 20 && j >= 10) return 0;
    else{
        for(int k = j; k <= n; k++){
            long double exponent, val, num, den;
            exponent = exp((-k) * (k - 1) * T / 2);
            int sign;
            // (-1)^(k - j), computed with pow and converted to int.
            sign = pow((-1), k - j);
            // k == 1 only occurs when j == 1 (the loop starts at k = j >= 1).
            // The general expression would be 0 there, since
            // biCo(j + k - 2, j) == biCo(0, 1) == 0, so the term is taken as
            // exp(0) * sign directly.
            if(k == 1){
                val = exponent * sign;
            }
            else{
                num = (2 * k - 1) * biCo(j + k - 2, j) * biCo(n - 1, k - 1) * biCo(k - 1, j - 1);
                den = (n + k - 1) * biCo(n + k - 2, n);
                val = exponent * sign * num / den;
            }
            value += val;
        }
        // Clamp tiny or negative round-off results to exactly 0.
        if(value < 0.000001) value = 0;
        return value;
    }
}


// Returns g(m, x, T) * g(n, k - x, T) divided by the sum over i = 1..k-1 of
// g(m, i, T) * g(n, k - i, T), where g = gFunction. This is the weight of the
// split (x, k - x) among all splits of k lineages that leave at least one on
// each side.
// NOTE(review): presumably the probability that x of k ancestral lineages
// come from the m-lineage side. Confirm.
// Returns 0 when every split has zero weight, instead of evaluating 0 / 0
// (NaN), which would poison any sum it is added to.
long double wFunction(int m, int n, int x, int k, double T){
    long double num = gFunction(m, x, T) * gFunction(n, k - x, T);
    long double den = 0;
    for(int i = 1; i < k; i++){
        den += gFunction(m, i, T) * gFunction(n, k - i, T);
    }
    if(den == 0) return 0;
    return num / den;
}

// n! evaluated in long double, multiplying from n down to 1 like fact(). It
// returns 1 for n <= 0, matching fact(), but cannot overflow for n > 20.
static long double iFunctionFactorial(int n){
    long double value = 1;
    for(int i = n; i > 0; i--) value = value * i;
    return value;
}

// Returns n! (n-1)! / (2^(n-k) k! (k-1)!), with the special cases
// k > n -> 0, n == 0 && k != 0 -> 0 and n == 0 && k == 0 -> 1.
// Factorials with non-positive arguments count as 1, as in fact().
// NOTE(review): this matches the number of ranked coalescent histories
// taking n lineages down to k. Confirm.
// The factorials are computed in long double. For n <= 20 the values are
// identical to the previous fact()-based version; for larger n it no longer
// hits signed 64-bit overflow.
long double iFunction(int n, int k){
    if(k > n) return 0;
    else if(n == 0 && k != 0) return 0;
    else if(n == 0 && k == 0) return 1;
    long double a = iFunctionFactorial(n);
    long double b = iFunctionFactorial(n - 1);
    long double num = a * b;
    long double c = iFunctionFactorial(k);
    long double d = iFunctionFactorial(k - 1);
    long double den = pow(2, n - k) * d * c;
    return num / den;
}

// Multinomial (binomial) coefficient (x + y)! / (x! y!) as a long double,
// for x, y >= 0.
// Built as the running product C(x+1, 1), C(x+2, 2), ..., C(x+y, y). Every
// partial result is an integer, so the value stays exact while it fits in the
// mantissa. Unlike fact(x + y), it does not overflow once x + y > 20.
long double wSub2(int x, int y){
    long double value = 1;
    for(int i = 1; i <= y; i++){
        value = value * (x + i) / i;
    }
    return value;
}

// Multinomial coefficient (x + y + z)! / (x! y! z!) as a long double, for
// x, y, z >= 0. It equals C(x+y, y) * C(x+y+z, z).
// Each factor is built as a running product of integers, so the value stays
// exact while it fits in the mantissa. Unlike fact(x + y + z), it does not
// overflow once the sum exceeds 20.
long double wSub3(int x, int y, int z){
    long double value = 1;
    // value becomes C(x + y, y)
    for(int i = 1; i <= y; i++){
        value = value * (x + i) / i;
    }
    // multiply in C(x + y + z, z)
    for(int i = 1; i <= z; i++){
        value = value * (x + y + i) / i;
    }
    return value;
}

// Transition factor for one internal node, used by calculateProb when rm is
// false.
// Inputs: s_L/s_R, c_L/c_R and m_L/m_R are lineage counts coming from the
// left/right child (totals s_I, c_I, m_I). s, c and m are the counts the node
// passes upward. T is the branch time passed to gFunction.
// NOTE(review): the names suggest S, C and M lineage types. Confirm.
// Returns 0 if fewer lineages come in than go out. Otherwise it returns
// gFunction(s_I + c_I + m_I, s + c + m, T) * K, with K taken from the case
// table below (K = 0 for any combination not listed). gFunction is evaluated
// even when K == 0. isInSubtree only affects the case that turns S + C input
// into a single M lineage.
long double bigF(int s_L, int s_R, int c_L, int c_R, int m_L, int m_R, int s, int c, int m, double T, bool isInSubtree){
    int s_I = s_L + s_R;
    int c_I = c_L + c_R;
    int m_I = m_L + m_R;
    //if(s_I == 1 && c_I == 2 && m_I == 0 && s == 0 && c == 1 && m == 1) cout << "AAHHAHAHA" /*K << " " << gvalue << " " << output*/ << endl;
    if(s_I + c_I + m_I < s + c + m) return 0;
    else{
        long double K;
        // Only C lineages in and out.
        if(s_I == 0 && c_I > 0 && m_I == 0 && c_I >= c && c > 0 && s == 0 && m == 0) {
            K = 1;
        }
        // C lineages plus one M lineage in; C lineages plus that M lineage out.
        else if(s_I == 0 && c_I > 0 && m_I == 1 && c_I >= c && s == 0 && m == 1) {
            K = 1;
        }
        // A lone M lineage passes through.
        else if(s_I == 0 && c_I == 0 && m_I == 1 && c == 0 && s == 0 && m == 1){
            K = 1;
        }
        // S and C in, no S out, one M out; only for nodes outside the S subtree.
        else if(s_I > 0 && isInSubtree == false && m_I == 0 && s == 0 && m == 1 && c_I >= c && c_I > 0){
            K = 0;
            for(int k = c + 1; k <= c_I; k++){
                K += 2 * biCo(c_I - 1, c_I - k) / biCo(s_I + c_I, s_I) / biCo(s_I + c_I - 1, k);
            }
        }
        // Only S lineages in and out.
        else if(s_I > 0 && s > 0 && s_I >= s  && c_I == 0 && c == 0 && m_I == 0 && m == 0){
            K = 1;
        }
        // NOTE(review): identical condition to the first branch, so this
        // branch is never reached.
        else if(s_I == 0 && s == 0 && c_I >= c  && c_I > 0 && c > 0 && m_I == 0 && m == 0){
            K = 1;
        }
        // Both S and C in and out, no M.
        else if(s_I > 0 && c_I > 0 && c > 0 && s > 0 && s_I >= s && c_I >= c && m_I == 0  && m == 0){
            K = biCo(s + c, s) * biCo(s_I - 1, s - 1) * biCo(c_I - 1, c - 1) / biCo(s_I + c_I - 1, s + c - 1) / biCo(s_I + c_I, s_I);
        }
        else K = 0;
        long double output;
        long double gvalue = gFunction(s_I + c_I + m_I, s + c + m, T);
        if(K == 0 || gvalue == 0) output = 0;
        else output = gvalue * K;
        //if(gvalue == 0) cout << s_I + c_I + m_I << " " << s + c + m << endl;
        //if(output != output) cout << gvalue << " " << K << " " << endl;
        //if(output > 0) cout << s_I << " " << c_I << " " << m_I << " | " << s << " " << c << " " << m << " " << output << endl;
        return output;
    }
}

// Transition factor for one internal node, used by calculateProb when rm is
// true. The arguments have the same meaning as in bigF.
// Returns 0 if fewer lineages come in than go out. Otherwise it returns
// gFunction(s_I + c_I + m_I, s + c + m, T) * K, with K taken from the case
// table below (K = 0 for any combination not listed). gFunction is evaluated
// even when K == 0. isRoot only affects the case that turns S + C input into
// a single M lineage, which is allowed only at the root of the computation.
long double bigFRM(int s_L, int s_R, int c_L, int c_R, int m_L, int m_R, int s, int c, int m, double T, bool isRoot){
    int s_I = s_L + s_R;
    int c_I = c_L + c_R;
    int m_I = m_L + m_R;
    if(s_I + c_I + m_I < s + c + m) {
        return 0;
    }
    else{
        long double K;
        // Only C lineages in and out.
        if(s_I == 0 && c_I > 0 && m_I == 0 && c_I >= c && c > 0 && s == 0 && m == 0) {
            K = 1;
        }
        // A lone M lineage passes through.
        else if(s_I == 0 && c_I == 0 && m_I == 1 && c == 0 && s == 0 && m == 1){
            K = 1;
        }
        // Only S lineages in and out.
        else if(s_I > 0 && s > 0 && s_I >= s  && c_I == 0 && c == 0 && m_I == 0 && m == 0){
            K = 1;
        }
        // S and C in, exactly one M out (and nothing else); root only.
        else if(s_I > 0 && c_I > 0 && m_I == 0 && s == 0 && c == 0 && m == 1 && isRoot == true){
            K = 2 / (biCo(s_I + c_I, s_I) * (s_I + c_I - 1));
        }
        // Both S and C in and out, no M.
        else if(s_I > 0 && c_I > 0 && m_I == 0 && s > 0 && c > 0 && s_I >= s && c_I >= c && m == 0){
            K = biCo(s + c, s) * biCo(s_I - 1, s - 1) * biCo(c_I - 1, c - 1) / biCo(s_I + c_I - 1, s + c - 1) / biCo(s_I + c_I, s_I);
        }
        else K = 0;
        long double output;
        long double gvalue = gFunction(s_I + c_I + m_I, s + c + m, T);
        if(K == 0 || gvalue == 0) output = 0;
        else output = gvalue * K;
        //if(output != output) cout << gvalue << " " << K << " " << endl;
        return output;
    }
}

long double calculateProb(node* root, int s, int c, int m, int UB_s, int UB_c, Map<node*, Map<string, long double> >& calcs, bool rm, int level){
    level++;
    bool isRoot;
    if(level == 0){
        isRoot = true;
    }
    else(isRoot = false);
//    for(int  l = 0; l < level; l++){
//        cout << "  ";
//    }
//    cout << root << endl;
//    for(int  l = 0; l < level; l++){
//        cout << "  ";
//    }
//    cout << s << " " << c << " " << m << " " << endl;
    long double value = 0, prod;
    for(int s_L = 0; s_L <= UB_s; s_L++){
        for(int s_R = 0; s_R <= UB_s - s_L; s_R++){
            for(int c_L = 0; c_L <= UB_c; c_L++){
                for(int c_R = 0; c_R <= UB_c - c_L; c_R++){
                    for(int m_L = 0; m_L <= 1; m_L++){
                        for(int m_R = 0; m_R <= 1 - m_L; m_R++){
                            if(s_L + c_L + m_L > 0 || s_R + c_R + m_R > 0){
                                    if(s_L + c_L + m_L + s_R + c_R + m_R >= s + c + m){
                                    long double leftProb, rightProb;
                                    //cout << s_L << " " << c_L << " " << m_L << " | " << s_R << " " << c_R << " " << m_R << " | " << s << " " << c << " " << m << " " << endl;
                                    long double fResult;
                                    if(rm == true){
                                        fResult = bigFRM(s_L, s_R, c_L, c_R, m_L, m_R, s, c, m, root->time, isRoot);
                                    }
                                    else if(rm == false) {
                                        fResult = bigF(s_L, s_R, c_L, c_R, m_L, m_R, s, c, m, root->time, root->isInSubtree);
                                    }
                                    else error("rm not bool");
                                    if(fResult == 0) prod = 0;
                                    else{
                                        if(root->left == NULL && root->right == NULL) {

                                            if(s_L == root->initial[0] && c_R == 0 && m_L == 0 && s_R == 0 && c_L == root->initial[1] && m_R == 0) {
                                                leftProb = 1;
                                                rightProb = 1;
                                            }
                                            else {
                                                leftProb = 0;
                                                rightProb = 0;
                                            }
                                        }
                                        else {
                                        // cout << "call nonleaf" << endl;
                                            string config;
                                            config += integerToString(s_L);
                                            config += integerToString(c_L);
                                            config += integerToString(m_L);
                                            Map<string, long double> nodeCalc = calcs.get(root->left);
                                            if(nodeCalc.containsKey(config)){
                                                leftProb = nodeCalc.get(config);
                                            }
                                            else{
                                                leftProb = calculateProb(root->left, s_L, c_L, m_L, UB_s, UB_c, calcs, rm, level);
                                                nodeCalc.put(config, leftProb);
                                                calcs.put(root->left, nodeCalc);
                                            }
                                            config.clear();
                                            config += integerToString(s_R);
                                            config += integerToString(c_R);
                                            config += integerToString(m_R);
                                            nodeCalc = calcs.get(root->right);
                                            if(nodeCalc.containsKey(config)){
                                                rightProb = nodeCalc.get(config);
                                            }
                                            else{
                                                rightProb = calculateProb(root->right, s_R, c_R, m_R, UB_s, UB_c, calcs, rm, level);
                                                nodeCalc.put(config, rightProb);
                                                calcs.put(root->right, nodeCalc);
                                            }
                                            config.clear();
                                        }
                                        prod = leftProb * rightProb * fResult;
                                        //for(int  l = 0; l < level; l++){
                                        //    cout << "  ";
                                        //}
                                        //cout << prod << endl;

                                      }
                                    value += prod;
                                  }
                            }
                        }
                    }
                }
            }
        }
    }
//    for(int  l = 0; l < level; l++){
//        cout << "  ";
//    }
//    cout << value << endl;
    level--;
    return value;
}
