/*
 * File: kNearestNeighbors.java
 * ---------------------
 * This class implements the k-Nearest-Neighbors
 * classification algorithm on MNIST, a frequently
 * used dataset of handwritten digits (0-9).
 * 
 */

import java.util.*;
import acm.graphics.*;
import acm.program.*;
import java.io.File;
import java.io.IOException;

public class kNearestNeighbors extends GraphicsProgram {

	// The number of "nearest neighbors" to check
	private static final int K = 5;

	// Padding to use around the GImages and labels
	private static final int PADDING = 20;

	public void run() {		
		setCanvasSize(1000, 600);

		// Step 1: read in datasets
		ArrayList<MLImage> trainImages = loadDataset("res/train", "res/train_labels.txt", 300);
		ArrayList<MLImage> testImages = loadDataset("res/test", "res/test_labels.txt", 50);
		System.out.println("Loaded training and testing data.");

		// Step 2: classify each image, tracking how many we get incorrect
		ArrayList<MLImage> incorrectImages = new ArrayList<>();
		ArrayList<MLImage> correctImages = new ArrayList<>();
		for (MLImage testImage : testImages) {
			boolean isCorrect = classifyImage(testImage, trainImages);
			if (isCorrect) {
				correctImages.add(testImage);
			}
			else{
				incorrectImages.add(testImage);
			}
		}

		// Step 3: display the results
		displayImages(incorrectImages, "Incorrect", 0);
		displayImages(correctImages, "Correct", getWidth() / 2);
		double accuracy = 100.0 * (testImages.size() - incorrectImages.size()) / testImages.size();
		System.out.println(accuracy + "% of the test data were classified correctly");
	}

	/*
	 * This method classifies the specified image given an ArrayList of training
	 * images.  It looks through and makes a list of the K closest training images
	 * to this image.  Then, it finds the most frequent label among those K training
	 * images, and uses that label as its prediction.  Finally, it compares this
	 * prediction against the actual label, and returns true if it was correct,
	 * and false if it was incorrect.
	 */
	private boolean classifyImage(MLImage image, ArrayList<MLImage> trainImages) {
		MLImage[] topKNeighbors = new MLImage[K];

		// Find the k "nearest" train images
		for (MLImage trainImage : trainImages) {
			double dist = computeDistance(trainImage, image);

			for(int i = 0; i < topKNeighbors.length; i++) {
				// If we don't have all K neighbors yet, just add this one
				if (topKNeighbors[i] == null) {
					topKNeighbors[i] = trainImage;
					break;
					// Otherwise, only add it if it's better than one of our existing neighbors
				} else if (dist < computeDistance(topKNeighbors[i], image)) {
					for(int j = topKNeighbors.length-1; j > i; j--){
						topKNeighbors[j] = topKNeighbors[j-1];
					}
					topKNeighbors[i] = trainImage;
					break;
				}
			}
		}

		image.setKNearestNeighbors(topKNeighbors);
		return mostFrequentLabel(topKNeighbors) == image.getLabel();
	}

	/*
	 * This method returns the Euclidean distance between the pixels of the
	 * two given images.
	 */
	private double computeDistance(MLImage image1, MLImage image2) {
		int[][] pixels1 = image1.getPixelArray();
		int[][] pixels2 = image2.getPixelArray();
		double dist = 0;
		for(int r = 0; r < pixels1.length; r++){
			for(int c = 0; c < pixels1[0].length; c++){
				dist += (Math.pow(pixels1[r][c] - pixels2[r][c], 2)); 
			}
		}
		return Math.sqrt(dist);
	}

	/*
	 * This method returns the most frequent label out of the neighbors
	 * specified by topKNeighbors.
	 */
	private int mostFrequentLabel(MLImage[] topKNeighbors) {
		// Build up a histogram of the labels of our K neighbors
		int[] hist = new int[10];
		for (int i = 0; i < topKNeighbors.length; i++) {
			int label = topKNeighbors[i].getLabel();
			hist[label]++;
		}

		// Find the label that occurred the most
		int mostFreqLabel = -1;
		int highestFreq = -1;
		for(int i = 0; i < hist.length; i++) {
			if (hist[i] > highestFreq) {
				highestFreq = hist[i];
				mostFreqLabel = i;
			}
		}
		return mostFreqLabel;
	}

	/*
	 * This method displays the given images, along with each image's K 
	 * nearest neighbors.  It draws a dividing line between them, and on 
	 * the left displays each incorrect image, and on the right its K nearest neighbors.
	 * Displays starting at startX.
	 */
	private void displayImages(ArrayList<MLImage> images, String label, int startX) {

		// Dividing line between the image and its neighbors
		add(new GLine(startX + getWidth() / 4, 0, startX + getWidth() / 4, getHeight()));

		GLabel imagesLabel = new GLabel(label + " Images (" + images.size() + ")");
		add(imagesLabel, startX + getWidth() / 4 - imagesLabel.getWidth() - PADDING, imagesLabel.getAscent());

		// e.g. "5 Nearest Neighbors"
		GLabel neighborsLabel = new GLabel(K + " Nearest Neighbors");
		add(neighborsLabel, startX + getWidth() / 4 + PADDING, neighborsLabel.getAscent());

		for (int i = 0; i < images.size(); i++) {
			MLImage currentImage = images.get(i);

			// Step 1: Display the incorrect image
			double x = startX + getWidth() / 4 - PADDING - currentImage.getImage().getWidth();
			double y = PADDING + i * (currentImage.getImage().getWidth() + PADDING);
			add(currentImage.getImage(), x, y);

			// Step 2: Display the K nearest neighbors of this image
			for (int j = 0; j < currentImage.getKNearestNeighbors().length; j++) {
				MLImage neighbor = currentImage.getKNearestNeighbors()[j];
				x = startX + getWidth() / 4 + PADDING + j * (neighbor.getImage().getWidth() + PADDING);

				// We need to make a *copy* of this GImage because it may appear onscreen multiple times!
				add(new GImage(neighbor.getPixelArray()), x, y);
			}
		}
	}

	/*
	 * This method returns an ArrayList of the given number of images read
	 * from the given directory, with the given labels text file.  
	 * This method assumes images are named using sequential numbering, 
	 * e.g. "1.png", "2.png", ... etc.  
	 */
	public ArrayList<MLImage> loadDataset(String imagesDirectory, String labelsFilename, int numImages) {
		ArrayList<MLImage> images = new ArrayList<MLImage>();
		try {
			Scanner scanner = new Scanner(new File(labelsFilename));
			while (scanner.hasNextInt() && images.size() < numImages) {
				int label = scanner.nextInt();
				int imageNumber = images.size();
				GImage datasetImage = new GImage(imagesDirectory + "/" + imageNumber + ".png");
				MLImage image = new MLImage(datasetImage, label);
				images.add(image);
			}
			scanner.close();
		} catch(Exception ex){
			System.out.println("Error!");
		}
		return images;
	}
}
