'''
Created on Jun 27, 2013

@author: Chris
'''

import heapq
import Queue
from sets import Set

from Puzzle import Puzzle
from Path import Path

class StarSolver(object):

    # Function: Constructor
    # ---------------------
    # Set the counter of nodes expanded
    # to be zero.
    def __init__(self):
        self.nodesExpanded = 0

    # Function: Run
    # ---------------------
    # Create an instance of an N-Puzzle problem,
    # call the solver, and report the solution.
    # If the search exhausts the state space without
    # reaching the goal, report that instead of
    # crashing on a missing path.
    def run(self):
        puzzle = Puzzle()
        puzzle.printPuzzle()
        print('Find solution...')
        path = self.getPath(puzzle)
        print('')
        if path is None:
            # getPath returns None when the fringe empties
            # without ever reaching a solution state.
            print('No solution found.')
            print('Nodes expanded: ' + str(self.nodesExpanded))
            return
        print('Solution found!')
        print('Solution length: ' + str(len(path.actions)))
        print('Nodes expanded: ' + str(self.nodesExpanded))
        print('Solution:')
        for action in path.actions:
            print(action)

    # Function: Report Nodes Expanded
    # ---------------------
    # Prints to the console if the number of
    # nodes expanded so far is a multiple of 1000.
    def reportNodesExpanded(self):
        if self.nodesExpanded % 1000 == 0:
            print(str(self.nodesExpanded))

    # Function: Get Heuristic
    # ---------------------
    # Return a heuristic value for the cost to
    # get from the passed in state to the solution.
    # The heuristic must be consistent, because getPath
    # never re-opens a state once it has been expanded.
    def getHeuristic(self, state):
        return self.getHeuristic2(state)

    # Function: Get Heuristic 2
    # -------------------------
    # Manhattan distance heuristic: the sum, over every
    # non-blank tile, of the row and column distance
    # between its current cell and its goal cell.
    # Assumes the goal places tile k at row-major index
    # k - 1 (the mapping the original hard-coded 3x3
    # version used); confirm against Puzzle.getSolutionPieces.
    def getHeuristic2(self, state):
        statePieces = state.getPieces()
        distance = 0
        for rowIndex in range(Puzzle.ROWS):
            for colIndex in range(Puzzle.COLS):
                stateValue = statePieces[rowIndex][colIndex]
                # The blank is not a tile and has no goal cost.
                if stateValue == ' ':
                    continue
                # Dividing by the board width (instead of a literal 3)
                # keeps the goal-cell mapping right for any board size;
                # divmod also avoids relying on Python 2 integer '/'.
                solnRow, solnCol = divmod(stateValue - 1, Puzzle.COLS)
                distance += abs(solnRow - rowIndex)
                distance += abs(solnCol - colIndex)
        return distance

    # Function: Get Heuristic 1
    # -------------------------
    # Misplaced Tile Heuristic: the number of non-blank
    # tiles that are not in their goal cell.
    def getHeuristic1(self, state):
        solutionPieces = Puzzle.getSolutionPieces()
        statePieces = state.getPieces()
        different = 0
        for rowIndex in range(Puzzle.ROWS):
            for colIndex in range(Puzzle.COLS):
                solnValue = solutionPieces[rowIndex][colIndex]
                stateValue = statePieces[rowIndex][colIndex]
                # The blank is skipped: one slide away from the goal
                # leaves both the blank and the moved tile out of place,
                # so counting the blank would score 2 for a true cost
                # of 1 and make the heuristic inadmissible.
                if stateValue != ' ' and stateValue != solnValue:
                    different += 1
        return different

    # Function: Get Priority
    # -------------------------
    # Return the f score for the current path:
    # path length (g) plus heuristic estimate (h).
    def getPriority(self, path):
        gScore = path.getLength()
        hScore = self.getHeuristic(path.getLastState())
        return gScore + hScore

    # Function: Get Path
    # ---------------------
    # Finds the shortest path from the start state
    # to the solution using A* search with a closed set.
    # Returns the solution Path, or None if the fringe
    # is exhausted without reaching a solution.
    def getPath(self, startState):
        startPath = Path()
        startPath.setStartState(startState)
        # Fringe entries are (fScore, insertionOrder, path). The
        # strictly increasing insertionOrder breaks f-score ties, so
        # Path objects are never compared with each other, and equal
        # f-scores are popped in FIFO order.
        insertionOrder = 0
        fringe = [(self.getPriority(startPath), insertionOrder, startPath)]
        visited = set()
        while fringe:
            _, _, currPath = heapq.heappop(fringe)
            self.nodesExpanded += 1
            self.reportNodesExpanded()
            currState = currPath.getLastState()
            if currState.isSolution():
                return currPath
            if currState in visited:
                continue
            visited.add(currState)
            for action in currState.getLegalActions():
                successor = currState.getSuccessorState(action)
                # With a consistent heuristic, an already-expanded
                # state was reached optimally; re-queueing it can
                # never yield a shorter path and only bloats the fringe.
                if successor in visited:
                    continue
                newPath = Path(currPath)
                newPath.addAction(action, successor)
                insertionOrder += 1
                heapq.heappush(
                    fringe,
                    (self.getPriority(newPath), insertionOrder, newPath))
        return None



if __name__ == "__main__":
    # Script entry point: build a solver and let it
    # generate, solve, and report one puzzle.
    solver = StarSolver()
    solver.run()