/* Maze generator in C++
 * Maze creation and backtracking solution by 
 * Joe Wingbermuehle, 19990805
 */

#include <errno.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#include <climits>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

using namespace std;

/* State of a single grid cell in a Maze. */
enum Location {
    PATH,     /* open cell; showMaze() prints it as blank */
    WALL,     /* blocked cell; printed as "[]" */
    FINISH,   /* exit marker placed by generateMaze(); printed as "FF" */
    SOLUTION  /* cell on a found route; printed as "<>" */
};

/* Integer (x, y) grid position.
 * Ordered lexicographically (x first, then y) so it can be kept in a
 * std::set; streams as "x,y". */
struct coordinate {
    int x, y;

    bool operator<(const coordinate &c) const {
        if (x != c.x)
            return x < c.x;
        return y < c.y;
    }

    bool operator==(const coordinate &c) const { return x == c.x && y == c.y; }

    bool operator!=(const coordinate &c) const { return !(*this == c); }

    friend std::ostream& operator<<(std::ostream& os, const coordinate& c) {
        return os << c.x << "," << c.y;
    }
};

class Maze {
    public:
        Maze(int width, int height) : width(width * 2 + 3), height(height * 2 + 3) {
            maze = (Location *)malloc(this->width * this->height * sizeof(Location));
        }
        Location getVal(coordinate c) const { return maze[c.y * width + c.x]; }
        void setVal(coordinate c, Location val) { maze[c.y * width + c.x] = val; }
        int getWidth() const { return width; }
        int getHeight() const { return height; }

    private:
        int width, height;
        Location *maze;
};

/* Display the maze. */
void showMaze(const Maze &maze) {
    int x, y;
    int width = maze.getWidth();
    int height = maze.getHeight();
    for(y = 0; y < height; y++) {
        for(x = 0; x < width; x++) {
            switch(maze.getVal({x,y})) {
                case WALL:  cout << "[]";  break;
                case SOLUTION:  cout << "<>";  break;
                case FINISH: cout << "FF"; break;
                default: cout << "  ";  break;
            }
        }
        cout << endl;
    }
    cout << endl;
}

/* Carve passages outward from (x, y) with a random walk.
 *
 * For each step a starting heading is drawn with rand() % 4. Headings are
 * then tried cyclically in the order east, south, west, north. A heading is
 * taken when the cell two steps away lies strictly inside the grid and both
 * the intermediate cell and that far cell are still WALL. Both are opened to
 * PATH and the walk moves to the far cell. The walk stops after four
 * consecutive rejected headings. */
void carveMaze(Maze &maze, int x, int y) {
    /* Heading index -> unit step: 0 east, 1 south, 2 west, 3 north. */
    static const int kStepX[4] = {1, 0, -1, 0};
    static const int kStepY[4] = {0, 1, 0, -1};

    const int cols = maze.getWidth();
    const int rows = maze.getHeight();

    int heading = rand() % 4;
    int rejected = 0;
    while (rejected < 4) {
        const int wallX = x + kStepX[heading];
        const int wallY = y + kStepY[heading];
        const int cellX = wallX + kStepX[heading];
        const int cellY = wallY + kStepY[heading];

        /* The bounds test comes first so getVal() is never called out of range. */
        const bool inside = cellX > 0 && cellX < cols && cellY > 0 && cellY < rows;
        if (inside
                && maze.getVal({wallX, wallY}) == WALL
                && maze.getVal({cellX, cellY}) == WALL) {
            maze.setVal({wallX, wallY}, PATH);
            maze.setVal({cellX, cellY}, PATH);
            x = cellX;
            y = cellY;
            heading = rand() % 4;
            rejected = 0;
        } else {
            heading = (heading + 1) % 4;
            ++rejected;
        }
    }
}

/* Generate maze in matrix maze with size width, height. */
void generateMaze(Maze &maze) {

    int x, y;
    int width = maze.getWidth();
    int height = maze.getHeight();

    /* Initialize the maze. */
    for(x = 0; x < width; x++) {
        for (y = 0; y < height; y++) {
            maze.setVal({x,y},WALL);
        }
    }

    maze.setVal({1,1},PATH);

    /* Seed the random number generator. */
    //srand(time(0));
    srand(12345);

    /* Carve the maze. */
    for(y = 1; y < height; y += 2) {
        for(x = 1; x < width; x += 2) {
            carveMaze(maze, x, y);
        }
    }

    /* Set up the entry and exit. */
    //maze[0 * width + 1] = PATH;
    maze.setVal({width-2,height-1},FINISH);
}

/* Solve the maze in place using a single-process backtracking walk.
 *
 * The walk starts at (1,1) and aims for (width-2, height-2), the cell just
 * above the exit. Neighbours are tried in the order east, south, west, north.
 * - Stepping forward onto a PATH cell marks the cell being left as SOLUTION.
 * - When no PATH neighbour remains, the walk backs up onto a neighbouring
 *   SOLUTION cell and tags the abandoned cell so it is never re-entered.
 *
 * Location has no dedicated "visited" value, so abandoned cells are tagged
 * FINISH during the search. Afterwards they are reset to PATH; otherwise
 * showMaze() would draw every dead end as "FF".
 *
 * On success the target cell and the exit below it, (width-2, height-1), are
 * marked SOLUTION. If the walk gets stuck (no PATH neighbour and no SOLUTION
 * neighbour to back up to), no route was found. The search then stops instead
 * of looping forever, clears its SOLUTION marks, and restores the entry and
 * exit cells to their previous values.
 */
void solveMazeBacktracking(Maze &maze) {

    int dir, count;
    int x, y;
    int dx, dy;
    int forward;
    bool solved = true;
    int width = maze.getWidth();
    int height = maze.getHeight();

    /* Remove the entry and exit, remembering them in case no route exists. */
    const Location savedEntry = maze.getVal({1,0});
    const Location savedExit = maze.getVal({width-2,height-1});
    maze.setVal({1,0}, WALL);
    maze.setVal({width-2,height-1},WALL);

    forward = 1;
    dir = 0;
    count = 0;
    x = 1;
    y = 1;
    while(x != width - 2 || y != height - 2) {
        dx = 0; dy = 0;
        switch(dir) {
            case 0:  dx = 1;  break;
            case 1:  dy = 1;  break;
            case 2:  dx = -1; break;
            default: dy = -1; break;
        }
        if(   (forward  && maze.getVal({x+dx,y+dy}) == PATH)
                || (!forward && maze.getVal({x+dx,y+dy}) == SOLUTION)) {
            /* Moving forward extends the route; moving back abandons this cell. */
            maze.setVal({x,y}, forward ? SOLUTION : FINISH);
            x += dx;
            y += dy;
            forward = 1;
            count = 0;
            dir = 0;
        } else {
            dir = (dir + 1) % 4;
            count += 1;
            if(count > 3) {
                if(!forward) {
                    /* Backing up failed as well: nowhere left to go. */
                    solved = false;
                    break;
                }
                forward = 0;
                count = 0;
            }
        }
    }

    /* Reopen abandoned cells; without a route also erase the partial one. */
    for(int cy = 0; cy < height; cy++) {
        for(int cx = 0; cx < width; cx++) {
            const Location v = maze.getVal({cx,cy});
            if(v == FINISH || (!solved && v == SOLUTION)) {
                maze.setVal({cx,cy}, PATH);
            }
        }
    }

    if(!solved) {
        maze.setVal({1,0}, savedEntry);
        maze.setVal({width-2,height-1}, savedExit);
        return;
    }

    /* Replace the entry and exit. */
    maze.setVal({width-2,height-2},SOLUTION);
    maze.setVal({width-2,height-1},SOLUTION);

}

/* Generate a new maze into `maze`. When `solve` is true, additionally print
 * the unsolved maze, announce "solving", run solveMazeBacktracking() and
 * print the result. */
void createMaze(Maze &maze, bool solve) {
    generateMaze(maze);

    if (!solve)
        return;

    showMaze(maze);
    cout << "solving" << endl;
    solveMazeBacktracking(maze);
    showMaze(maze);
}

/* Search for the FINISH cell, splitting the search across processes.
 *
 * The calling process walks from `start`, appending each cell it stands on to
 * `path`. Neighbours are scanned north, south, west, east.
 * - If one of them is FINISH, "Found finish!" is printed, that coordinate is
 *   appended, and true is returned.
 * - Otherwise, for the unvisited non-WALL neighbours, a child is fork()ed for
 *   every one except the last. The current process follows the last one
 *   itself, so a plain corridor costs no extra processes. Previously every
 *   step forked, giving one process per visited cell.
 * - A process whose walk dead-ends waits for all of its children and returns
 *   false.
 *
 * Each branch runs in its own process, so true is returned (with `path`
 * holding the route) only in the process that reached the finish. Every other
 * process, including the original caller, gets false.
 * NOTE(review): the finishing process returns without reaping children it
 * forked at earlier branch points; those get re-parented and exit on their own.
 *
 * If fork() fails, the error is reported with perror() and the branch is
 * deferred: the current process walks it later instead of silently losing
 * that part of the maze.
 */
bool solveMazeFork(Maze &maze, coordinate start, std::vector<coordinate> &path) {
    /* Neighbour offsets in scan order: north, south, west, east. */
    static const int kDx[4] = {0, 0, -1, 1};
    static const int kDy[4] = {-1, 1, 0, 0};

    std::set<coordinate> visited;
    /* Branches this process must walk itself because fork() failed:
     * first cell of the branch, and the length of `path` at the branch point. */
    std::vector<std::pair<coordinate, std::size_t> > deferred;

    while (true) {
        path.push_back(start);
        visited.insert(start);

        std::vector<coordinate> open;
        for (int d = 0; d < 4; ++d) {
            const coordinate next = {start.x + kDx[d], start.y + kDy[d]};
            const Location val = maze.getVal(next);
            if (val == FINISH) {
                std::cout << "Found finish!" << std::endl;
                path.push_back(next);
                return true;
            }
            if (val != WALL && visited.find(next) == visited.end()) {
                open.push_back(next);
            }
        }

        bool isChild = false;
        for (std::size_t k = 0; k + 1 < open.size(); ++k) {
            const pid_t pid = fork();
            if (pid == 0) {
                /* Child: responsible for branch k only, not the parent's backlog. */
                start = open[k];
                deferred.clear();
                isChild = true;
                break;
            }
            if (pid < 0) {
                std::perror("fork");
                deferred.push_back(std::make_pair(open[k], path.size()));
            }
            /* Keep this process and later children out of branch k. */
            visited.insert(open[k]);
        }
        if (isChild) {
            continue;
        }

        if (!open.empty()) {
            start = open.back();
        } else if (!deferred.empty()) {
            /* Dead end: rewind the route to the branch point of a deferred branch. */
            path.resize(deferred.back().second);
            start = deferred.back().first;
            deferred.pop_back();
        } else {
            break;
        }
    }

    /* Reap every child: retry when interrupted by a signal, stop on ECHILD. */
    for (;;) {
        if (waitpid(-1, nullptr, 0) == -1 && errno != EINTR) {
            break;
        }
    }
    return false; // no solution from this path
}

/* Mark every coordinate listed in `path` as SOLUTION in `maze`, in order. */
void populateMazeWithPath(Maze &maze, vector<coordinate> &path) {
    for (vector<coordinate>::size_type i = 0; i < path.size(); ++i) {
        maze.setVal(path[i], SOLUTION);
    }
}

int main(int argc, char *argv[]) {
    if (argc < 3) {
        cout << "Usage:\n\t" << argv[0] << " width height --use-backtracking" << endl;
        return 0;
    }
    const int width = atoi(argv[1]);
    const int height = atoi(argv[2]);
    Maze maze(width,height);
    createMaze(maze, false);
    showMaze(maze); 
    if (argc < 4) {
        coordinate start = {1,1};
        //cout << "end: " << end << endl;
        vector<coordinate> finishedPath;
        bool solved = solveMazeFork(maze,start,finishedPath);
        if (solved) {
            populateMazeWithPath(maze,finishedPath);
            showMaze(maze);
        }
    } else {
        solveMazeBacktracking(maze);
        showMaze(maze);
    }
    return 0;
}
