#!/usr/bin/env julia
"""
Prepare offline embedding datasets for the ENGR 108 lectures on Embeddings by Stephen Boyd and Aqib Syed
Creates vocab.txt, embeddings_hd_unit.csv, and embeddings_2d.csv.
"""
# Activate project environment
import Pkg
Pkg.activate(joinpath(@__DIR__, ".."))
Pkg.instantiate()

using DelimitedFiles
using LinearAlgebra
using Statistics

# Load GloVe embeddings: the 300-dimensional vectors from the "6B" release (trained on a 6-billion-token corpus)
# Location of the 300-dimensional GloVe text file that the rest of the script reads.
glove_file = joinpath(@__DIR__, "..", "data", "glove.6B.300d.txt")

# Fetch and unpack the GloVe archive only when the text file is not already present.
# Requires the external `curl` and `unzip` commands to be available on PATH.
if !isfile(glove_file)
    glove_url = "https://nlp.stanford.edu/data/glove.6B.zip"
    zip_file = joinpath(@__DIR__, "..", "data", "glove.6B.zip")

    # Create data directory if it doesn't exist
    data_dir = dirname(glove_file)
    isdir(data_dir) || mkpath(data_dir)

    try
        # `--fail` makes curl exit non-zero on an HTTP error instead of saving the
        # server's error page as the "zip", so the failure surfaces here rather than
        # as a confusing unzip error.
        println("Downloading GloVe 6B embeddings (822 MB)...")
        run(`curl --fail -L -o $zip_file $glove_url`)

        # `-o` overwrites existing files without prompting.
        println("Extracting GloVe data...")
        run(`unzip -o $zip_file -d $data_dir`)
    finally
        # Remove the archive whether or not the steps above succeeded, so that a
        # partial or corrupt download is never left behind.
        rm(zip_file; force=true)
    end

    # Fail early with a clear message if the archive did not contain the expected file.
    isfile(glove_file) || error("Expected $glove_file after extraction, but it was not found")
end
"""
    load_glove(io::IO)

Parse whitespace-separated embedding rows from `io`, one row per line in the form
`word v1 v2 ... vd`.

Lines with fewer than two whitespace-separated fields are skipped. The stream is read
line by line, so the whole file is never held in memory as text.

# Returns
A tuple `(words, coords)` where `words::Vector{String}` holds the first field of each
row and `coords::Vector{Float32}` holds all parsed values concatenated in row order.

# Throws
- `DimensionMismatch` if a row has a different number of values than the first row.
- `ArgumentError` (from `parse`) if a value field is not a valid number.
"""
function load_glove(io::IO)
    words = String[]
    coords = Float32[]
    dim = 0
    for (lineno, line) in enumerate(eachline(io))
        parts = split(line)
        length(parts) > 1 || continue
        nvals = length(parts) - 1
        if isempty(words)
            # The first data row fixes the dimensionality for the whole file.
            dim = nvals
        elseif nvals != dim
            throw(DimensionMismatch(
                "line $lineno has $nvals values, expected $dim"))
        end
        push!(words, parts[1])
        for field in @view parts[2:end]
            push!(coords, parse(Float32, field))
        end
    end
    return words, coords
end

# `open(f, path)` guarantees the file is closed even if parsing throws.
vocab_full, vectors = open(load_glove, glove_file)
isempty(vocab_full) && error("No embedding rows could be read from $glove_file")

# Reshape into matrix
# `vectors` stores each word's values contiguously, so reshaping to (n_dims, n_words)
# puts one word per column; the adjoint then presents one word per row without copying.
n_words = length(vocab_full)
n_dims = length(vectors) ÷ n_words
W_full = reshape(vectors, n_dims, n_words)'


# Filter to keep top 10,000 English words by frequency
"""
    is_english_word(w::AbstractString)

Return `true` if `w` passes a simple "looks like a word" heuristic.

All three conditions must hold:
- every character satisfies `isletter`,
- `w` has at least two characters,
- at least one character, once lowercased, is one of `a`, `e`, `i`, `o`, `u`, `y`.

Accepts any `AbstractString`, so the `SubString`s returned by `split` can be passed
directly.

# Examples
```julia
julia> is_english_word("hello")
true

julia> is_english_word("u2")
false
```
"""
function is_english_word(w::AbstractString)
    # Lowercasing per character gives the same answer as lowercasing the whole
    # string first, without allocating a copy.
    return all(isletter, w) && length(w) >= 2 && any(c -> lowercase(c) in "aeiouy", w)
end

# Keep first 10,000 valid words
# Lazily walk the vocabulary in file order, keeping positions whose word passes
# `is_english_word`, and stop as soon as 10,000 positions have been collected.
english_positions = Iterators.filter(
    i -> is_english_word(vocab_full[i]), eachindex(vocab_full))
keep_idx = collect(Int, Iterators.take(english_positions, 10000))

# Restrict both the word list and the embedding rows to the selected positions.
W = W_full[keep_idx, :]
vocab = vocab_full[keep_idx]

# Scale every row of W to unit Euclidean length. The divisor is floored at machine
# epsilon for W's element type so an all-zero row does not divide by zero.
tiny = eps(eltype(W))
row_norms = sqrt.(sum(abs2, W; dims=2))
Xnorm = W ./ max.(row_norms, tiny)

# 2D projection: subtract the column means, take the thin SVD of the centered data,
# and project onto the first two right singular vectors.
Xc = Xnorm .- mean(Xnorm; dims=1)
F = svd(Xc; full=false)
Z2 = Xc * F.V[:, 1:2]

# Write the three output files into the project's data directory, creating the
# directory first when nothing exists at that path.
data_dir = joinpath(@__DIR__, "..", "data")
if !ispath(data_dir)
    mkpath(data_dir)
end

vocab_path = joinpath(data_dir, "vocab.txt")
hd_path = joinpath(data_dir, "embeddings_hd_unit.csv")
proj_path = joinpath(data_dir, "embeddings_2d.csv")

writedlm(vocab_path, vocab)        # one word per line
writedlm(hd_path, Xnorm, ',')      # unit-normalized embeddings, comma-separated
writedlm(proj_path, Z2, ',')       # 2D projection, comma-separated

println("Data saved successfully")