from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


def main():
    """Train a small fully connected MNIST classifier, printing test accuracy as it goes.

    Loads the data, builds the network on the selected device, reports the
    accuracy of the untrained network, then runs 100 epochs, printing the test
    accuracy and the training log-likelihood after each one.
    """
    train, test = download_data()
    print(f"Training examples: {len(train.dataset)}")
    print(f"Test examples: {len(test.dataset)}")

    # Layer sizes of the (deliberately simple and fast) network.
    pixels = 28 * 28  # images are 28 x 28 pixels
    hidden = 512
    classes = 10
    layers = [
        nn.Flatten(),
        nn.Linear(pixels, hidden),
        nn.Sigmoid(),
        nn.Linear(hidden, hidden),
        nn.Sigmoid(),
        nn.Linear(hidden, classes),
        nn.LogSoftmax(dim=1),
    ]
    model = nn.Sequential(*layers).to(device)

    # Baseline: how well does the network do before it has learned anything?
    accuracy = run_test(test, model)
    print(f"Untrained, Test Accuracy: {accuracy}")

    print("training...")
    for epoch in range(100):
        # One pass of learning, followed by a measurement on the test data.
        likelihood, accuracy = run_train(train, model), run_test(test, model)
        print(f"Epoch {epoch}, Test Accuracy: {accuracy}, LogLikelihood: {likelihood}")

def run_train(train, model, optimizer=None):
    """Train ``model`` for one pass over ``train`` and return the mean log-likelihood.

    Args:
        train: Iterable yielding ``(image_batch, truth_batch)`` pairs, such as
            a ``DataLoader``.
        model: Module whose output is per-class log-probabilities, as
            ``nn.NLLLoss`` requires.
        optimizer: Optional optimizer over ``model``'s parameters to reuse
            across calls. When omitted, a new Adam optimizer with
            ``lr=0.0001`` is created for this call, so its running moment
            estimates do not carry over from one call to the next.

    Returns:
        The summed log-likelihood of the true labels divided by the number of
        examples seen.

    Raises:
        ZeroDivisionError: If ``train`` yields no examples.
    """
    model.train()
    # this function can score how good a prediction is
    loss_function = nn.NLLLoss(reduction="sum")
    # this function can adjust the model to make better predictions
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Inputs must live on the same device as the weights, otherwise the
    # forward pass fails when the model has been moved to a GPU.
    first_param = next(model.parameters(), None)
    model_device = None if first_param is None else first_param.device

    total_log_likelihood = 0.0   # sum of log-likelihood over all examples
    total_samples = 0

    # get labelled examples to train on
    for image_batch, truth_batch in train:
        if model_device is not None:
            image_batch = image_batch.to(model_device)
            truth_batch = truth_batch.to(model_device)

        # Clear the previous batch's gradients; backward() adds to .grad, so
        # without this every step would use the sum over all earlier batches.
        optimizer.zero_grad()

        # predict the label
        pred = model(image_batch)

        # score how well we did
        loss = loss_function(pred, truth_batch)

        # learn from our mistakes
        loss.backward()
        optimizer.step()

        # keep track of the likelihood (loss is negative log likelihood)
        total_log_likelihood -= loss.item()
        total_samples += truth_batch.size(0)

    return total_log_likelihood / total_samples

def run_test(test, model):
    """Return the percentage of examples in ``test`` that ``model`` labels correctly.

    The model is switched to evaluation mode and run without gradient
    tracking; its previous training/evaluation mode is restored afterwards.

    Args:
        test: Iterable yielding ``(image, truth)`` batches, such as a
            ``DataLoader``.
        model: Module producing one score per class for each example; the
            highest-scoring class is taken as the prediction.

    Returns:
        Accuracy as a percentage (0 to 100) of the examples actually seen.

    Raises:
        ZeroDivisionError: If ``test`` yields no examples.
    """
    # Inputs must live on the same device as the weights.
    first_param = next(model.parameters(), None)
    model_device = None if first_param is None else first_param.device

    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        # No gradients are needed to measure accuracy; skipping them saves
        # memory and time.
        with torch.no_grad():
            # loop over the test examples
            for image, truth in test:
                if model_device is not None:
                    image = image.to(model_device)
                    truth = truth.to(model_device)
                # choose the digit with the highest score
                predicted_label = model(image).argmax(1)
                # you are correct if the prediction matches the truth
                correct += (predicted_label == truth).sum().item()
                # count examples as they are seen, so the denominator stays
                # right even if the loader drops or subsamples examples
                total += truth.size(0)
    finally:
        model.train(was_training)

    return 100 * correct / total

def download_data():
    """Fetch MNIST into the current directory and wrap both splits in loaders.

    A grid of random training images is displayed before the test split is
    loaded.

    Returns:
        A ``(train, test)`` pair of ``DataLoader`` objects, each shuffled and
        using batches of 64.
    """
    def load_split(is_train):
        # Download if needed and convert every image to a tensor.
        return datasets.MNIST(root=".", train=is_train, download=True, transform=ToTensor())

    train_split = load_split(True)
    show_images(train_split)
    test_split = load_split(False)

    return (
        DataLoader(train_split, batch_size=64, shuffle=True),
        DataLoader(test_split, batch_size=64, shuffle=True),
    )

def show_images(training_data):
    """Display a 5 x 5 grid of randomly chosen examples, each titled with its label.

    Args:
        training_data: Indexable dataset whose items are ``(image, label)``
            pairs and which supports ``len()``.
    """
    rows, cols = 5, 5
    figure = plt.figure(figsize=(8, 8))

    for slot in range(rows * cols):
        # Pick one example at random (the same one may be drawn twice).
        pick = int(torch.randint(len(training_data), size=(1,)))
        img, label = training_data[pick]

        # Subplot positions are 1-based.
        figure.add_subplot(rows, cols, slot + 1)
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
        plt.title(label)

    plt.show()

# Run the training script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()