#!/usr/bin/env julia
"""
Prepare offline embedding datasets for the ENGR 108 lectures on Embeddings by Stephen Boyd and Aqib Syed
Creates word clusters plot and nearest neighbors table for lecture slides.
"""
# Activate project environment
import Pkg
Pkg.activate(joinpath(@__DIR__, ".."))

using CSV
using DataFrames
using VMLS
using LinearAlgebra
using Statistics
using Printf
import PlotlyJS

# Load WordEmbeddings module
Base.include(Main, "WordEmbeddings.jl")
using .WordEmbeddings

# Load the word embeddings data: the vocabulary plus the 300-d and 2-d embeddings.
# Both embedding collections are indexed by word position (`emb2[i]` belongs to `words[i]`).
words, emb300, emb2 = word_embeddings_data()

# Map each word to its position in `words`
idxmap = Dict(w => i for (i, w) in enumerate(words))

# Define groups for the plot. Note: we can change this to try different group plots.
groups = Dict(
    "professions" => ["lawyer", "doctor", "teacher"],
    "colors" => ["blue", "red", "green"],
    "china" => ["china", "beijing", "shanghai"]
)

# Marker color for each known group, chosen to match the LaTeX minipage.
group_colors = Dict(
    "professions" => "red",
    "colors" => "blue",
    "china" => "green"
)

# Fallback palette for groups that have no entry in `group_colors`.
# The third entry is the Plotly spelling of the LaTeX color `green!60!black`
# (60% green mixed with black); the xcolor syntax itself is not a Plotly color string.
colors = ["red", "blue", "rgb(0,153,0)"]

# Text positions to rotate through to avoid overlap
textpositions = ["top center", "bottom center", "top left", "top right",
                 "bottom left", "bottom right", "middle left", "middle right"]

traces = PlotlyJS.GenericTrace[]

# NOTE: `groups` is a Dict, so the iteration order (and therefore `i`, which only
# matters for the fallback palette) is not the order in which the groups are written.
for (i, (group_name, group_words)) in enumerate(groups)
    # Filter the words first and derive the indices from the surviving words, so that
    # every label stays attached to its own point even when some word is missing.
    found_words = [word for word in group_words if haskey(idxmap, word)]
    found_idx = [idxmap[word] for word in found_words]

    missing_words = [word for word in group_words if !haskey(idxmap, word)]
    if !isempty(missing_words)
        @warn "Skipping words without an embedding" group_name missing_words
    end

    marker_color = get(group_colors, group_name, colors[mod1(i, length(colors))])

    # Cycle through the label positions so neighbouring labels do not overlap
    label_positions = [
        textpositions[mod1(j, length(textpositions))] for j in eachindex(found_words)
    ]

    # Extract x and y coordinates from the vector of 2-d vectors
    xs = [emb2[k][1] for k in found_idx]
    ys = [emb2[k][2] for k in found_idx]

    trace = PlotlyJS.scattergl(
        x = xs,
        y = ys,
        mode = "markers+text",
        text = found_words,
        textposition = label_positions,
        marker = PlotlyJS.attr(
            size=14,
            color=marker_color,
            opacity=0.9,
            line=PlotlyJS.attr(width=1, color="white")
        ),
        name = group_name
    )
    push!(traces, trace)
end

# Square plot with equal axis scaling, light grid and no tick labels
layout = PlotlyJS.Layout(
    xaxis = PlotlyJS.attr(scaleanchor="y", scaleratio=1, showgrid=true, gridwidth=0.5, gridcolor="lightgray", showticklabels=false),
    yaxis = PlotlyJS.attr(constrain="domain", showgrid=true, gridwidth=0.5, gridcolor="lightgray", showticklabels=false),
    plot_bgcolor="white",
    paper_bgcolor="white",
    font = PlotlyJS.attr(size=12, family="Arial"),
    showlegend=false,
    width=600,
    height=600
)

plt = PlotlyJS.plot(traces, layout)

# Save plot as PNG. The path is relative to the current working directory;
# create the target directory first so a fresh checkout does not fail here.
png_path = "data/word_clusters.png"
mkpath(dirname(png_path))
PlotlyJS.savefig(plt, png_path, format="png", width=600, height=600)
println("Plot saved: $png_path")

# Query words: each one becomes a row of the nearest-neighbors table
table_words = [
    "lawyer",
    "university",
    "car",
    "guitar",
    "cat",
    "happy",
]

"""
    get_k_nearest_neighbors(word, k)

Return the words that `VMLS.k_nearest_neighbors` reports as the `k` nearest
neighbors of `word` among the 300-dimensional embeddings.

Reads the script-level globals `idxmap`, `emb300` and `words`.
Returns an empty `String[]` when `word` is not in `idxmap`.
"""
function get_k_nearest_neighbors(word, k)
    idx = get(idxmap, word, 0)
    idx == 0 && return String[]

    # One embedding per column. `reduce(hcat, ...)` builds the matrix without
    # splatting every embedding into a single varargs call, and keeping the
    # column layout avoids transposing back and forth.
    emb300_matrix = reduce(hcat, emb300)

    # Get the word's embedding
    word_embedding = emb300_matrix[:, idx]

    # NOTE(review): the query word is itself one of the columns, so it is presumably
    # returned as its own nearest neighbor -- confirm against VMLS.k_nearest_neighbors
    # whether the table should skip it.
    nbrs = VMLS.k_nearest_neighbors(emb300_matrix, word_embedding, k)

    # Accept either a bare index collection or a tuple whose first element is the indices
    neighbor_indices = isa(nbrs, Tuple) ? nbrs[1] : nbrs

    # Convert indices to words
    return [words[i] for i in neighbor_indices]
end

# Generate LaTeX table on stdout: one row per query word, followed by its 5 nearest
# neighbors. The column spec and header below are written for exactly 5 neighbors.
println("\nLaTeX Table:")
println("\\begin{table}")
println("\\begin{center}")
println("\\begin{tabular}{l|lllll}")
println("Word & NN  & 2-NN & 3-NN & 4-NN & 5-NN \\\\ \\midrule")

for word in table_words
    neighbors = get_k_nearest_neighbors(word, 5)
    if length(neighbors) < 5
        # Rows are still left out of the table, but no longer silently: the warning
        # goes to the log stream, so the LaTeX on stdout is unaffected.
        @warn "Omitting table row: fewer than 5 neighbors found" word found = length(neighbors)
        continue
    end
    println(join(vcat(word, neighbors[1:5]), " & "), " \\\\")
end

println("\\end{tabular}")
println("\\end{center}")
println("\\end{table}")

# Leave the plot as the final expression so it is the value of the script when it is
# `include`d from a REPL or notebook (a plain `julia script.jl` run does not display it).
plt