/* Anthony Giardullo
 * Danish Lakhani
 * CS259 Project
 * 
 * This file contains the code we used to transform our specially annotated
 * PRISM code into runnable PRISM code.  See the comments in
 * prismHelper.java to see how we 'annotated' our PRISM code.
 *
 * By giving the appropriate command line parameters and using the
 * keywords "*GENERATECONSTANTS*" and "*GENERATEGRAPH*", this code will
 * create PRISM code specific to our project.
 *
 * This program will create a runnable .pm and .pctl file.
 *
 *
 * Usage:
 * 
 * java preparse in-file out-file seed badProb numPaths maxLength minLength 
 *               userCount mixCount [minSearchLength maxSearchLength]
 *
 * in-file: annotated prism-code (use the file entropy.pm)
 * out-file: the name of the created .pm and .pctl files
 * seed: any number to seed the rng
 * badProb: probability a node is bad from 0.0 to 1.0
 * numPaths: number of messages in the network
 * maxLength: maximum length of a path of a message
 * minLength: minimum length of a path of a message
 * userCount: number of users
 * mixCount: number of mix nodes
 * minSearchLength(optional): Assumes attacker doesn't consider paths
 *   shorter than this number
 * maxSearchLength(optional): Assumes attacker doesn't consider paths
 *   longer than this number
 *
 *
 */

import java.io.*;
import java.lang.*;
import java.util.*;


/**
 * Pre-processor that expands an annotated PRISM source file into a runnable
 * .pm file and writes an accompanying .pctl file.
 *
 * Lines of the input are copied to the output unchanged except for:
 *   - "*GENERATECONSTANTS*": replaced by declarations built from the
 *     command line arguments (see addArgs);
 *   - "*GENERATEGRAPH*": replaced by randomly generated path constants
 *     (see GenerateGraph);
 *   - lines starting with "{": expanded as a loop template (see doScript);
 *   - lines starting with "2{": expanded together with the following line
 *     as a nested template (see doScript).
 *
 * Every "const name = value;" line seen in the input is also recorded so
 * that later templates can refer to the constant by name.
 */
public class preparse{

    /** Constant name -> value (both Strings) collected from args and input. */
    private static Hashtable constants = new Hashtable();

    /** Random source for graph generation; created in main. */
    private static Random r;

    /**
     * Looks up a required integer constant.
     *
     * @param name name of the constant
     * @return its integer value
     * @throws Exception if the constant has not been defined
     */
    private static int getIntConstant(String name) throws Exception{
        Object value = constants.get(name);
        if(value == null)
            throw new Exception("missing constant: " + name);
        return Integer.parseInt((String)value);
    }

    /**
     * Emits, for every path n in 1..numPaths, one "const node<i>path<n>"
     * declaration per node i giving the predecessor of node i on that path
     * (0 when the node has no predecessor on the path).
     *
     * Each path starts at a user node (user 1 for the first path, a random
     * user otherwise), passes through distinct random mix nodes and ends at
     * a different random user node.  For the first path the destination is
     * additionally emitted as "const initNode".
     *
     * Reads the constants numPaths, maxLength, minLength, userCount and
     * mixCount.
     *
     * @param pStream stream the declarations are written to
     * @throws Exception if a required constant is missing or the constants
     *         are inconsistent
     */
    private static void GenerateGraph(PrintStream pStream) throws Exception{
        int numPaths  = getIntConstant("numPaths");
        int maxLength = getIntConstant("maxLength");
        int minLength = getIntConstant("minLength");
        int userCount = getIntConstant("userCount");
        int mixCount  = getIntConstant("mixCount");

        if(maxLength > mixCount + 1 || minLength > maxLength ||
           maxLength < 1 || minLength < 1 || userCount < 2)
            throw new Exception("bad consts!");

        for(int n = 1; n <= numPaths; n++){

            // Index 0 is unused; users are 1..userCount, mixes follow.
            // A new boolean array is already all false.
            boolean[] beenThere = new boolean[userCount+mixCount+1];

            int src;
            int dest;
            if(n == 1)
                src = 1;
            else
                src = pickUser(userCount, beenThere);

            beenThere[src] = true;

            for(int l = 1; ;l++){
                // Stop when the maximum length is reached, or at random
                // once the minimum length has been reached.
                if((l == maxLength) ||
                   (l >= minLength && (r.nextInt(maxLength - l + 1) == 0))) {

                    dest = pickUser(userCount, beenThere);
                    printEdge(dest, n, src, pStream);

                    if(n == 1)
                        pStream.println("const initNode = " + dest + ";");
                    break;
                }
                else{
                    dest = pickMix(userCount, mixCount, beenThere);
                    printEdge(dest, n, src, pStream);
                    beenThere[dest] = true;

                    src = dest;
                }
            }

            // Every user other than the destination has no predecessor.
            for(int i = 1; i <= userCount; i++){
                if(i != dest)
                    printEdge(i, n, 0, pStream);
            }
            // Every mix that is not on the path has no predecessor.
            for(int i = userCount+1; i < beenThere.length; i++){
                if(!beenThere[i])
                    printEdge(i, n, 0, pStream);
            }
        }
    }

    /**
     * Writes "const node<dest>path<path> = <src>;".
     */
    private static void printEdge(int dest, int path, int src, PrintStream
                                  pStream){
        pStream.println("const node"+dest+"path"+path+" = "+src+";");
    }

    /**
     * Picks a random user index in 1..userCount that is not marked in
     * beenThere.
     *
     * @throws Exception if every user is already marked
     */
    private static int pickUser(int userCount, boolean[] beenThere)
        throws Exception{

        boolean allTaken = true;
        for(int i = 1; i <= userCount; i++)
            allTaken &= beenThere[i];

        if(allTaken)
            throw new Exception("No users to pick");

        // At least one user is free, so this loop terminates.
        while(true){
            int user = r.nextInt(userCount) + 1;
            if(!beenThere[user])
                return user;
        }
    }

    /**
     * Picks a random mix index in userCount+1..userCount+mixCount that is
     * not marked in beenThere.
     *
     * @throws Exception if every mix is already marked
     */
    private static int pickMix(int userCount, int mixCount, boolean[]
                               beenThere) throws Exception{

        boolean allTaken = true;
        for(int i = userCount+1; i < beenThere.length; i++)
            allTaken &= beenThere[i];

        if(allTaken)
            throw new Exception("No mixes to pick");

        // At least one mix is free, so this loop terminates.
        while(true){
            int mix = r.nextInt(mixCount) + 1 + userCount;
            if(!beenThere[mix])
                return mix;
        }
    }

    /** Prints the command line synopsis to standard output. */
    public static void printUseage(){
        System.out.println("java preparse in-file out-file [seed badProb numPaths"
                           + " maxLength minLength userCount mixCount"
                           + " [minSearchLength maxSearchLength] ]");
    }

    /**
     * Writes "const <name> = <value>;" and records the constant so that
     * later template lines can refer to it by name.
     */
    public static void addConstant(String name, int value, PrintStream pStream){
        String line = "const " + name + " = " + value +";";
        pStream.println(line);
        constants.put(name, String.valueOf(value));
    }

    /** Writes "rate <name> = <value>;". */
    public static void addRate(String name, double value, PrintStream
                               pStream){
        String line = "rate " + name + " = " + value + ";";
        pStream.println(line);
    }

    /**
     * Writes the declarations derived from the command line.
     *
     * Uses args[3] (badProb), args[4] (numPaths), args[5] (maxLength),
     * args[6] (minLength), args[7] (userCount), args[8] (mixCount) and,
     * only when exactly 11 arguments are given, args[9] (minSearchLength)
     * and args[10] (maxSearchLength).  Without the search lengths both
     * default to 1 and useSearchLength is 0.
     *
     * @param args    the program's command line arguments
     * @param pStream stream the declarations are written to
     * @throws IllegalArgumentException if fewer than 9 arguments are given
     */
    public static void addArgs(String[] args, PrintStream pStream){
        if(args.length < 9)
            throw new IllegalArgumentException(
                "*GENERATECONSTANTS* needs at least 9 command line arguments, got "
                + args.length);

        double badP = Double.parseDouble(args[3]);
        addRate("badP", badP, pStream);
        addRate("goodP", 1.0 - badP, pStream);

        int numPaths = Integer.parseInt(args[4]);
        addConstant("numPaths", numPaths, pStream);

        addConstant("maxLength", Integer.parseInt(args[5]), pStream);
        addConstant("minLength", Integer.parseInt(args[6]), pStream);

        int userCount = Integer.parseInt(args[7]);
        int mixCount = Integer.parseInt(args[8]);
        addConstant("userCount", userCount, pStream);
        addConstant("mixCount", mixCount, pStream);
        addConstant("nodeCount", userCount+mixCount, pStream);

        addRate("numPathsInv", 1.0 / numPaths, pStream);

        int minSearchLength = 1;
        int maxSearchLength = 1;
        int useSearchLength = 0;
        if(args.length == 11){
            minSearchLength = Integer.parseInt(args[9]);
            maxSearchLength = Integer.parseInt(args[10]);
            useSearchLength = 1;
        }

        addConstant("useSearchLength", useSearchLength, pStream);
        // Each constant gets its own value (these two were previously swapped).
        addConstant("minSearchLength", minSearchLength, pStream);
        addConstant("maxSearchLength", maxSearchLength, pStream);
    }

    /**
     * Entry point.  Accepts 2, 3, 9 or 11 arguments (see printUseage);
     * any other count prints the synopsis and returns.  Reads args[0],
     * writes args[1] + ".pm" and then args[1] + ".pctl".
     */
    public static void main(String[] args) throws Exception{
        if(!(args.length == 11 || args.length == 9 || args.length == 3 ||
             args.length == 2)){

            printUseage();
            return;
        }

        String in = args[0];
        String out = args[1];

        if(args.length > 2)
            r = new Random(Integer.parseInt(args[2]));
        else
            r = new Random();

        BufferedReader reader = null;
        PrintStream pStream = null;
        try{
            reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(in))));
            pStream = new PrintStream(
                new FileOutputStream(new File(out + ".pm")));

            String line;
            while((line = reader.readLine()) != null){
                String trimmed = line.trim();

                if(trimmed.startsWith("const ")){
                    int index1 = line.indexOf("const ");
                    int index2 = line.indexOf("=");
                    int index3 = line.indexOf(";");
                    String c = line.substring(index1 + 6, index2).trim();
                    String cval = line.substring(index2 + 1, index3).trim();
                    // Result unused: only checks that the value is an integer.
                    Integer.parseInt(cval);

                    constants.put(c, cval);
                }

                if(trimmed.equals("*GENERATEGRAPH*")){
                    GenerateGraph(pStream);
                }
                else if(trimmed.equals("*GENERATECONSTANTS*")){
                    addArgs(args, pStream);
                }
                else if(trimmed.startsWith("{")){
                    doScript(trimmed, pStream, null);
                }
                else if(trimmed.startsWith("2{")){
                    // The template's second line is the next input line.
                    String second = reader.readLine();
                    if(second == null)
                        throw new Exception(
                            "\"2{\" line must be followed by another line");
                    doScript(trimmed.substring(1), pStream, second.trim());
                }
                else{
                    pStream.println(line);
                }
            }
        }
        finally{
            // Release the files even when processing fails.
            if(pStream != null)
                pStream.close();
            if(reader != null)
                reader.close();
        }

        createPCTL(out + ".pctl");
    }

    /**
     * Resolves a loop bound: the value of the recorded constant with that
     * name if there is one, otherwise the token parsed as an integer.
     */
    private static int resolveBound(String token){
        Object cval = constants.get(token);
        if(cval != null)
            return Integer.parseInt((String)cval);
        return Integer.parseInt(token);
    }

    /**
     * Expands a template line of the form "{v1 s1 e1 [v2 s2 e2]}body".
     *
     * The header lists loop variables with inclusive start and end bounds;
     * a bound is either an integer or the name of a recorded constant.  The
     * body is printed once per combination of the first (at most two)
     * variables with every "{v}" replaced by the variable's value.  On the
     * last combination a trailing "+" in the body becomes ";".
     *
     * When line2 is not null it is a second template: after each printed
     * body, line2 (with the same replacements applied) is expanded through
     * a nested call.
     *
     * @param line    template starting with its "{...}" header
     * @param pStream stream the expansion is written to
     * @param line2   optional nested template, or null
     * @throws IllegalArgumentException if the header is malformed
     */
    private static void doScript(String line, PrintStream pStream, String line2) {

        int close = line.indexOf("}");
        if(close < 1)
            throw new IllegalArgumentException(
                "template header is missing '}': " + line);

        String dec = line.substring(1, close);
        line = line.substring(close + 1);

        StringTokenizer tok = new StringTokenizer(dec, " ");

        int tokenCount = tok.countTokens();
        if(tokenCount == 0 || tokenCount % 3 != 0)
            throw new IllegalArgumentException(
                "template header must hold <var start end> triples: {" + dec + "}");

        int numVars = tokenCount / 3;
        String[] vars = new String[numVars];
        int[] varStarts = new int[numVars];
        int[] varEnds = new int[numVars];

        for(int i = 0; i < numVars; i++){
            vars[i] = tok.nextToken();
            varStarts[i] = resolveBound(tok.nextToken());
            varEnds[i] = resolveBound(tok.nextToken());
        }

        //only for 2 vars for now
        for(int i = varStarts[0]; i <= varEnds[0]; i++){

            String temp =
                replaceSubstring(line, "{"+vars[0]+"}", String.valueOf(i));
            String line2temp =
                replaceSubstring(line2, "{"+vars[0]+"}", String.valueOf(i));

            if(numVars > 1){
                for(int j = varStarts[1]; j <= varEnds[1]; j++){

                    String temp2 =
                        replaceSubstring(temp, "{"+vars[1]+"}", String.valueOf(j));
                    String line2temp2 =
                        replaceSubstring(line2temp, "{"+vars[1]+"}", String.valueOf(j));

                    if(i == varEnds[0] && j == varEnds[1])
                        temp2 = replacePlusWithSemi(temp2);

                    pStream.println(temp2);
                    if(line2 != null)
                        doScript(line2temp2, pStream, null);
                }
            }else{

                if(i == varEnds[0])
                    temp = replacePlusWithSemi(temp);

                pStream.println(temp);
                if(line2 != null)
                    doScript(line2temp, pStream, null);
            }
        }
    }

    /**
     * Writes the .pctl file: one property line per user 1..userCount.
     *
     * @param fileName name of the file to create
     * @throws Exception if userCount is undefined or the file cannot be
     *         opened
     */
    private static void createPCTL(String fileName) throws Exception{
        int userCount = getIntConstant("userCount");

        PrintStream pStream =
            new PrintStream(new FileOutputStream(new File(fileName)));
        try{
            for(int i = 1; i <= userCount; i++){
                pStream.println("P>0.00001 [true U (done & currentNode=" +
                                i + ") {start} ]");
            }
        }
        finally{
            pStream.close();
        }
    }

    /**
     * If the line, ignoring surrounding whitespace, ends with "+", returns
     * the text before that last "+" followed by ";".  Otherwise returns the
     * line unchanged.
     */
    private static String replacePlusWithSemi(String line){
        if(line.trim().endsWith("+")){
            int index = line.lastIndexOf("+");
            return line.substring(0, index) + ";";
        }
        return line;
    }

    /**
     * Returns s with every occurrence of old replaced by replace, scanning
     * left to right.  Returns null when s is null; returns s unchanged when
     * old is empty.
     */
    private static String replaceSubstring(String s, String old, String replace)
    {
        if(s == null)
            return null;

        // An empty pattern matches at every position and would never advance.
        if(old.length() == 0)
            return s;

        StringBuffer result = new StringBuffer();
        int from = 0;
        int position;
        while((position = s.indexOf(old, from)) >= 0){
            result.append(s.substring(from, position));
            result.append(replace);
            from = position + old.length();
        }
        result.append(s.substring(from));

        return result.toString();
    }
}
