#include "StringSet.h"
#include "queue.h"

// Selects which of the three recursive add helpers StringSet::add(string)
// dispatches to (see the #ifdef / #elif chain inside StringSet::add).
// Define exactly one of these macros and leave the other two commented out.
// NOTE: if none is defined, add(string) compiles to an empty body and silently
// inserts nothing; if several are defined, the first branch of the chain wins
// (ADD_ARMS_LENGTH, then ADD_REF_TO_POINTER, then ADD_POINTER_TO_POINTER).

#define ADD_ARMS_LENGTH
// #define ADD_REF_TO_POINTER
// #define ADD_POINTER_TO_POINTER

// Constructs an empty set: no root node and a node count of zero.
StringSet::StringSet() : root(nullptr), count(0) {}

// Destructor: detach the tree from the set, then free every node with a
// post-order walk so children are released before their parent.
StringSet::~StringSet() {
    Node *doomed = root;
    root = nullptr; // the set no longer refers to any node
    postOrderClear(doomed);
}

// Frees every node in the subtree rooted at node. Both children are released
// before the node itself (post-order), so a node's child pointers are read
// before it is deleted. A nullptr subtree is a no-op.
void StringSet::postOrderClear(Node *node) {
    if (node != nullptr) {
        postOrderClear(node->left);
        postOrderClear(node->right);
        delete node;
    }
}

// Adds s to the set. Adding a string that is already present leaves the set
// unchanged. Which recursive helper does the work is chosen at compile time
// by the ADD_* macros at the top of this file.
void StringSet::add(string s) {
#ifdef ADD_ARMS_LENGTH
    // Arms-length recursion: the helper only ever sees non-null nodes, so the
    // empty-tree case has to be handled here before recursing.
    if (root == nullptr) {
        root = new Node(s);
        count++;
    } else {
        // Call addArmsLength by name. A plain add(s, root) resolves to the
        // add(string, Node *&) overload and silently bypasses this variant.
        addArmsLength(s, root);
    }
#elif defined ADD_REF_TO_POINTER
    // call the add function with a reference to a pointer; it assigns to
    // root directly when the tree is empty
    add(s, root);
#elif defined ADD_POINTER_TO_POINTER
    // call the add function with a pointer to a pointer
    // we need to pass in the address of the root itself
    add(s, &root);
#endif
}

// overloaded add helper for recursion
void StringSet::addArmsLength(string s, Node *node) {
    if (node->str > s) {
        if (node->left == nullptr) {
            node->left = new Node(s);
            count++;
        } else {
            addArmsLength(s, node->left);
        }
    } else if (node->str < s) {
        if (node->right == nullptr) {
            node->right = new Node(s);
            count++;
        } else {
            addArmsLength(s, node->right);
        }
    }
}

// Reference-to-pointer insertion helper.
// node is an alias for the actual pointer slot (root, or some node's
// left/right member). Assigning to it therefore links the new node straight
// into the tree. Duplicates are ignored.
// Tip: walk through this with drawings; the reference is the tricky part!
void StringSet::add(string s, Node *&node) {
    if (node != nullptr) {
        // Descend into whichever child slot s belongs under. Equal means the
        // string is already present, so there is nothing to do.
        if (node->str > s) {
            add(s, node->left);
        } else if (node->str < s) {
            add(s, node->right);
        }
        return;
    }
    // Empty slot reached: writing through the reference updates the parent's
    // child pointer (or root) in place.
    node = new Node(s);
    count++;
}

// overloaded add helper for recursion
// Note: this version uses a pointer-to-pointer
// which can also be tricky to understand!
// You will need to understand it for the next assignment!
void StringSet::add(string s, Node **node) {
    if (*node == nullptr) {
        *node = new Node(s);
        count++;
    } else if ((*node)->str > s) {
        add(s, &((*node)->left));
    } else if ((*node)->str < s) {
        add(s, &((*node)->right));
    }
}

// returns the string earliest  in the alphabet (empty string if not found)
string StringSet::findMin() { return findMin(root); }

// Returns the smallest string in the subtree rooted at node by following left
// links to the leftmost node. Returns "" for an empty (nullptr) subtree.
string StringSet::findMin(Node *node) {
    if (node == nullptr) {
        return ""; // nothing to find
    }
    Node *leftmost = node;
    while (leftmost->left != nullptr) {
        leftmost = leftmost->left;
    }
    return leftmost->str;
}

// returns the string farthest in the alphabet (throws error if empty)
string StringSet::findMax() { return findMax(root); }

// Returns the largest string in the subtree rooted at node by following right
// links to the rightmost node. Returns "" for an empty (nullptr) subtree.
string StringSet::findMax(Node *node) {
    if (node == nullptr) {
        return ""; // nothing to find
    }
    Node *rightmost = node;
    while (rightmost->right != nullptr) {
        rightmost = rightmost->right;
    }
    return rightmost->str;
}

// Returns true if s is in the set.
bool StringSet::contains(string s) {
    // The short-circuit skips the helper entirely for an empty tree.
    return root != nullptr && contains(s, root);
}

// overloaded contains helper for recursion
bool StringSet::contains(string &s, Node *node) {
    // base cases
    if (node == nullptr) {
        return false;
    }
    if (node->str == s) {
        return true;
    }
    // recursive cases
    if (s < node->str) {
        return contains(s, node->left);
    }
    return contains(s, node->right);
}

// removes string from the set
void StringSet::remove(string s) {
    if (root != nullptr) {
        Node *removedNode;
        if (root->str == s) { // root holds value to remove
            // create dummy node, and set root to its left
            Node dummyNode("");
            dummyNode.left = root;
            removedNode = remove(s, root, &dummyNode);
            root = dummyNode.left;
        } else {
            removedNode = remove(s, root, nullptr);
        }
        if (removedNode != nullptr) {
            delete removedNode;
            count--;
        }
    }
}

// Recursive remove helper.
// Searches the subtree rooted at node (which must be non-null) for s and
// unlinks a node from the tree. Returns the unlinked node, or nullptr if s is
// not present. The caller is responsible for deleting the returned node. When
// the matching node has two children, the node actually unlinked is its
// in-order successor, whose string has been copied up into the match.
// parent is node's parent. It is dereferenced only once the matching node is
// found, so the caller must pass a real (or stand-in) parent whenever node
// itself may hold s. StringSet::remove(string) uses a dummy node for the root.
// NOTE: in the two-children case the recursive call binds s to node->str of
// the node whose value was just overwritten, so s must not be modified here.
Node *StringSet::remove(string &s, Node *node, Node *parent) {
    // traverse left if we aren't at the correct node to remove.
    if (s < node->str) {
        if (node->left != nullptr) {
            return remove(s, node->left, node);
        } else { // the node we want to remove doesn't exist
            return nullptr;
        }
    } else if (s > node->str) {
        // traverse right if we aren't at the correct node
        if (node->right != nullptr) {
            return remove(s, node->right, node);
        } else { // the node we want to remove doesn't exist
            return nullptr;
        }
    } else { // we found the node to remove
        if (node->left != nullptr && node->right != nullptr) {
            // two children: instead of unlinking this node, overwrite its data
            // with the smallest string in its right subtree (its in-order
            // successor); this keeps the BST ordering intact
            node->str = findMin(node->right);

            // then recursively unlink the successor's original node from the
            // right subtree. It is the leftmost node there, so it has no left
            // child and lands in one of the zero/one-child cases below.
            return remove(node->str, node->right, node);
        } else if (parent->left == node) {
            // replace the parent's left with either the right or left child of
            // this node, depending on which one exists (and the right will be
            // nullptr if it has zero children)
            parent->left = (node->left != nullptr) ? node->left : node->right;
            return node;
        } else if (parent->right == node) {
            // replace the parent's right with either the right or left child of
            // this node, depending on which one exists (and the right will be
            // nullptr if it has zero children)
            parent->right = (node->left != nullptr) ? node->left : node->right;
            return node;
        }
    }
    return nullptr; // not reached when parent really is node's parent (one of
                    // the two branches above then matches); kept so every
                    // path returns a value.
}

// Returns how many strings the set currently holds. add() increments this
// count and remove() decrements it.
int StringSet::size() {
    return count;
}

// Returns true when the set holds no strings, i.e. there is no root node.
// (Checking count == 0 would be equivalent.)
bool StringSet::isEmpty() {
    return !root;
}

// Streams the set in sorted order as "[a, b, c]", or "[]" when empty.
// The largest string is looked up first so the traversal knows which element
// gets no trailing ", ".
ostream &operator<<(ostream &out, StringSet &set) {
    string last = set.findMax();
    out << "[";
    set.inOrderTraversal(out, set.root, last);
    return out << "]";
}

// Traverses the tree in order and adds each node's value to the stream
// Note: needs "max" so we don't print a comma after the last node's value
//       (because we want it to be pretty!)
void StringSet::inOrderTraversal(ostream &out, Node *node, string &max) {
    if (node == nullptr) {
        return;
    }
    inOrderTraversal(out, node->left, max);
    out << node->str;
    if (node->str != max) {
        out << ", ";
    }
    inOrderTraversal(out, node->right, max);
}

// Returns the number of levels in the tree rooted at tree: 0 for an empty
// tree, 1 for a single node.
int treeHeight(Node *tree) {
    if (tree == nullptr) {
        return 0;
    }
    const int leftLevels = treeHeight(tree->left);
    const int rightLevels = treeHeight(tree->right);
    return 1 + (leftLevels > rightLevels ? leftLevels : rightLevels);
}

// Returns the length of the longest string stored in the tree rooted at tree
// (0 for an empty tree). prettyPrint uses it as the width of every cell.
int longestData(Node *tree) {
    if (tree == nullptr) {
        return 0;
    }
    const int childLongest = max(longestData(tree->left), longestData(tree->right));
    const int ownLength = static_cast<int>(tree->str.length());
    return max(ownLength, childLongest);
}

// Prints the tree to cout level by level, with the root on the first line.
// Each node's string is centered (via padString) in a cell maxLength wide, and
// missing nodes are drawn as a run of '-'. Cells are positioned so every
// parent sits midway between its two children's cells.
// An empty tree prints just a newline.
void StringSet::prettyPrint() {
    // Guard the empty tree. With numLevels == 0 the position math below would
    // evaluate 1 << -1 (undefined behavior). Also, level starts at 0 and only
    // increases, so the `level == numLevels` exit would never fire and the
    // loop would keep enqueuing nullptrs forever.
    if (root == nullptr) {
        cout << endl;
        return;
    }

    // calculate levels
    int numLevels = treeHeight(root);
    int maxLength = longestData(root);

    // go through each level (breadth-first) and mark locations of strings;
    // nullptr placeholders keep absent subtrees occupying their columns
    Queue<Node *> q;
    q.enqueue(root);
    int level = 0;
    int numberCount = 1 << level; // cells on the current level
    int nextNumberOfCount = 0;    // index of the next cell on this level
    int spacesSoFar = 0;          // columns already written on this line
    while (!q.isEmpty()) {
        Node *curr = q.dequeue();
        //             xxx
        //     xxx             xxx
        // xxx     xxx     xxx     xxx
        // first: offset (in cells) of the first cell on this level
        // elementMult: distance (in cells) between neighbouring cells
        // (level < numLevels here, so both shift counts are non-negative)
        int first = (1 << (numLevels - level - 1)) - 1;
        int elementMult = 1 << (numLevels - level);
        int nextPos = (first + (nextNumberOfCount * elementMult)) * maxLength;

        printChars(nextPos - spacesSoFar, ' ');
        spacesSoFar = nextPos;
        if (curr != nullptr) {
            cout << padString(curr->str, maxLength);
        } else {
            printChars(maxLength, '-');
        }
        spacesSoFar += maxLength;
        if (curr != nullptr) {
            q.enqueue(curr->left);
            q.enqueue(curr->right);
        } else {
            q.enqueue(nullptr);
            q.enqueue(nullptr);
        }
        nextNumberOfCount++;
        if (nextNumberOfCount == numberCount) {
            nextNumberOfCount = 0;
            level++;
            numberCount = 1 << level;
            spacesSoFar = 0;
            if (level == numLevels)
                break; // don't print last level, which will be empty
            cout << endl;
        }
    }
    cout << endl;
}

// Writes the character c to cout n times; n <= 0 writes nothing.
void StringSet::printChars(int n, char c) {
    for (int remaining = n; remaining > 0; --remaining) {
        cout << c;
    }
}

// Centers s in a field len characters wide by padding with spaces on both
// sides. When the total padding is odd, the extra space goes on the right.
// Strings already len or more characters long are returned unchanged.
string StringSet::padString(string s, int len) {
    const int gap = len - static_cast<int>(s.length());
    if (gap <= 0) {
        return s;
    }
    const int leftPad = gap / 2;
    const int rightPad = gap - leftPad; // gets the extra space when gap is odd
    return string(leftPad, ' ') + s + string(rightPad, ' ');
}
