import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# ============================================================
# 0. CONFIG: HIGH-LEVEL KNOBS (constants only)
# ============================================================

# Pick the compute device. MPS (Apple Silicon) stays the first choice, but we
# fall back to CUDA and then CPU so the script also runs on machines where the
# MPS backend is missing or unavailable.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

IMAGE_SIZE = 28
CHANNELS = 1

BATCH_SIZE   = 64
TRAIN_STEPS  = 400_000          # total gradient steps
SAVE_EVERY   = 2000             # save checkpoint + samples every N steps

T            = 300              # number of diffusion timesteps (how many rounds of noising/denoising)
LR           = 2e-4             # learning rate

# Single-class training (e.g. only digit "2") vs full MNIST.
# TARGET_CLASS is always embedded in the checkpoint/sample directory names below.
# When ONLY_TARGET_CLASS is True it is also used as the digit label, so it must
# then be an integer rather than the string 'all'.
TARGET_CLASS      = 'all'
ONLY_TARGET_CLASS = False
NUM_LABELS        = 10

# Diffusion schedule hyperparameters (how fast we destroy images)
BETA_START = 1e-4
BETA_END   = 0.02

# ---- Checkpoint config ----
# Defaults: train from scratch.
START_STEP      = 0
CHECKPOINT_PATH = None
# Example to resume.
# NOTE: the two assignments below are live and override the defaults above;
# comment them out to start from step 0 without a checkpoint.
START_STEP      = 8000
CHECKPOINT_PATH = f"mnist_diffusion/checkpoint_{TARGET_CLASS}/mnist_cond_step_{START_STEP}.pt"

# Paths (created eagerly at import time so later saves cannot fail on a missing directory)
ckpt_dir = f"mnist_diffusion/checkpoint_{TARGET_CLASS}"
samp_dir = f"mnist_diffusion/samples_{TARGET_CLASS}"
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(samp_dir, exist_ok=True)


# ============================================================
# 1. DATA: MNIST
# ============================================================

def create_mnist_dataloader(
    batch_size: int,
    only_target_class: bool,
    target_class: int,
    device: torch.device,
) -> DataLoader:
    """
    Build a shuffled DataLoader over the MNIST training split.

    Images are converted to tensors and normalized with mean 0.5 / std 0.5,
    which maps ToTensor's [0, 1] range to [-1, 1].

    Args:
        batch_size: number of images per batch; the last incomplete batch is dropped.
        only_target_class: if True, keep only images whose label equals target_class.
        target_class: digit label to keep; only read when only_target_class is True.
        device: not used by this function; kept so existing callers keep working.

    Returns:
        A DataLoader yielding (image, label) batches.

    Raises:
        ValueError: if only_target_class is True and target_class is not an
            integer label, or no training image carries that label.
    """

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Normalize so that pixel values lie roughly in [-1, 1]
        transforms.Normalize((0.5,), (0.5,)),
    ])

    full_dataset = datasets.MNIST(
        root="./mnist_diffusion/data",
        train=True,
        download=True,
        transform=transform,
    )

    if only_target_class:
        try:
            label = int(target_class)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"target_class must be an integer digit label, got {target_class!r}"
            ) from err
        if label != target_class:
            raise ValueError(
                f"target_class must be an integer digit label, got {target_class!r}"
            )

        # Filter on the label tensor directly. Enumerating the dataset would
        # decode and transform every image just to read its label.
        targets = torch.as_tensor(full_dataset.targets)
        indices = torch.nonzero(targets == label).flatten().tolist()
        if not indices:
            raise ValueError(f"No MNIST training images found with label {label}")
        train_dataset = Subset(full_dataset, indices)
    else:
        train_dataset = full_dataset

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    print("Number of training images:", len(train_dataset))
    return train_loader


# ============================================================
# 2. DIFFUSION OBJECT: schedule + q(x_t | x_0) + p_θ(x_{t-1} | x_t)
# ============================================================

class Diffusion:
    """
    DDPM noise schedule together with the two processes built on it:
      - schedule tables β_t, α_t = 1 - β_t and ᾱ_t = ∏ α_s
      - forward noising   q(x_t | x_0)
      - reverse denoising p_θ(x_{t-1} | x_t), driven by a noise-predicting model
    """

    def __init__(self, T: int, beta_start: float, beta_end: float, device: torch.device):
        self.T = T
        self.device = device

        # Linear β schedule of length T, stored as float32 on `device`.
        schedule = torch.linspace(beta_start, beta_end, T, dtype=torch.float32, device=device)
        keep_fraction = 1.0 - schedule

        self.betas = schedule
        self.alphas = keep_fraction
        # Running product of α up to and including each timestep.
        self.alpha_bars = keep_fraction.cumprod(dim=0)

    # ---------- Simple helpers ----------

    @staticmethod
    def _per_sample(table: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Look up table[t] and reshape to [B,1,1,1] so it broadcasts over images."""
        return table[t].view(-1, 1, 1, 1)

    def sample_timesteps(self, batch_size: int) -> torch.Tensor:
        """
        Draw one integer timestep per batch element, uniform over {0, ..., T-1}.
        """
        return torch.randint(0, self.T, (batch_size,), device=self.device)

    def sample_gaussian_noise(self, x: torch.Tensor) -> torch.Tensor:
        """
        Draw standard-normal noise with the shape, dtype and device of x.
        """
        return torch.randn_like(x)

    # ---------- Access α_t, β_t, ᾱ_t for a batch of timesteps ----------

    def get_alpha_bar(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return ᾱ_t for each timestep in t (shape [B]) as a [B,1,1,1] tensor.
        """
        return self._per_sample(self.alpha_bars, t)

    def get_diffusion_coefficients(self, t: torch.Tensor):
        """
        Return the tuple (α_t, β_t, ᾱ_t) for timesteps t (shape [B]),
        each reshaped to [B,1,1,1].
        """
        alpha_t, beta_t, alpha_bar_t = (
            self._per_sample(table, t)
            for table in (self.alphas, self.betas, self.alpha_bars)
        )
        return alpha_t, beta_t, alpha_bar_t

    # ---------- Forward process q(x_t | x_0) ----------

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Noise clean images to timestep t in one shot:
            x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε

        x0:    clean images          [B, C, H, W]
        t:     timestep per image    [B]
        noise: the ε to mix in       [B, C, H, W]
        """
        abar = self.get_alpha_bar(t)           # [B,1,1,1]
        signal_scale = abar.sqrt()
        noise_scale = (1.0 - abar).sqrt()
        return signal_scale * x0 + noise_scale * noise

    # ---------- Reverse process p_θ(x_{t-1} | x_t) ----------

    def ddpm_p_mean_variance(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor
    ):
        """
        Mean and variance of p_θ(x_{t-1} | x_t, t, y) under the DDPM parameterization.

        The model is called as model(x_t, t, y) and its output is treated as the
        predicted noise ε_θ. The mean is
            μ_θ = 1/sqrt(α_t) * (x_t - β_t / sqrt(1 - ᾱ_t) * ε_θ)
        and the variance is taken to be β_t.

        Returns:
            mu_theta:  [B, C, H, W]
            beta_t:    [B, 1, 1, 1]
        """
        eps_hat = model(x_t, t, y)

        alpha_t, beta_t, alpha_bar_t = self.get_diffusion_coefficients(t)

        inv_sqrt_alpha = 1.0 / torch.sqrt(alpha_t)
        eps_weight = beta_t / torch.sqrt(1.0 - alpha_bar_t)

        mu_theta = inv_sqrt_alpha * (x_t - eps_weight * eps_hat)
        return mu_theta, beta_t

    @torch.no_grad()
    def p_sample_step(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t_scalar: int,
        y: torch.Tensor
    ) -> torch.Tensor:
        """
        Take a single reverse step x_t -> x_{t-1}, using the same timestep
        t_scalar for every element of the batch.

        For t_scalar > 0 the result is μ_θ + sqrt(β_t) * z with fresh z ~ N(0, I);
        otherwise the mean μ_θ is returned without added noise.
        """
        # Broadcast the scalar timestep to one long index per batch element.
        t = x_t.new_full((x_t.shape[0],), t_scalar, dtype=torch.long)

        mu_theta, beta_t = self.ddpm_p_mean_variance(model, x_t, t, y)

        if t_scalar > 0:
            z = torch.randn_like(x_t)
            return mu_theta + torch.sqrt(beta_t) * z

        # Final step: no extra noise.
        return mu_theta


# ============================================================
# 3. TIME EMBEDDING (SINUSOIDAL)
# ============================================================

def sinusoidal_embedding(t, dim):
    """
    Create a sinusoidal embedding for timesteps t.
    This is similar to positional encodings in Transformers.

    The first dim // 2 columns are sin(t * f_k) and the next dim // 2 columns
    are cos(t * f_k), with f_k = 10000 ** (-k / (dim // 2)). If dim is odd,
    one trailing zero column is appended so the output always has width dim
    (including dim == 1, where the result is a single zero column).

    t:   [B] (timesteps as ints or floats)
    dim: embedding dimension
    returns: [B, dim] sinusoidal timestep embedding (float32, on t's device)
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=t.device).float() / half
    )  # [half]

    # Outer product: each timestep times each frequency
    args = t.float()[:, None] * freqs[None, :]      # [B, half]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # [B, 2*half]

    # If dim is odd, pad one dimension with zeros. The pad column is built
    # explicitly instead of slicing emb, because emb has zero columns when
    # dim == 1 and a slice of it would also be empty.
    if dim % 2 == 1:
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)

    return emb


# ============================================================
# 4. Neural Net MODEL: PREDICTS THE NOISE
# ============================================================

class NoisePredictor(nn.Module):
    """
    Small UNet-style convolutional network that predicts the noise in x_t.

    Conditioning:
      - timestep t: sinusoidal embedding -> MLP -> added as a per-channel bias
        after the first down block
      - class label: one-hot label planes concatenated to the input image

    Inputs:
        x_t: [B, in_channels, H, W]  noisy image
        t:   [B]                     timestep indices
        cls: [B]                     class labels in [0, num_labels)

    Output:
        tensor with out_channels channels and the same spatial size as x_t
    """

    @staticmethod
    def _conv_block(c_in, c_out):
        """5x5 same-padding conv, then single-group GroupNorm, then SiLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
            nn.GroupNorm(1, c_out),
            nn.SiLU(),
        )

    def __init__(self, in_channels=1, out_channels=1, base_ch=64, num_labels=10):
        super().__init__()

        self.base_ch = base_ch
        self.num_labels = num_labels

        wide = base_ch * 2

        # ---- Encoder: the first conv sees image channels plus one plane per label ----
        encoder_widths = [
            (in_channels + num_labels, base_ch),
            (base_ch, wide),
            (wide, wide),
        ]
        self.down_layers = nn.ModuleList(
            [self._conv_block(c_in, c_out) for c_in, c_out in encoder_widths]
        )

        # ---- Decoder: last stage is a bare conv (no norm, no activation) ----
        self.up_layers = nn.ModuleList([
            self._conv_block(wide, wide),
            self._conv_block(wide, base_ch),
            nn.Sequential(
                nn.Conv2d(base_ch, out_channels, kernel_size=5, padding=2)
            ),
        ])

        # ---- MLP applied to the sinusoidal timestep features ----
        self.time_embedding_layer = nn.Sequential(
            nn.Linear(base_ch, base_ch),
            nn.LayerNorm(base_ch),
            nn.SiLU(),
            nn.Linear(base_ch, base_ch),
            nn.LayerNorm(base_ch)
        )

        # Resolution changes between stages
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2, mode='nearest')

    # ---------- helpers ----------

    def embed_time(self, t, batch_size, device):
        """
        Map timesteps t to a learned [B, base_ch] embedding.
        """
        steps = t.view(batch_size).to(device)
        features = sinusoidal_embedding(steps, dim=self.base_ch)
        return self.time_embedding_layer(features)  # [B, base_ch]

    def make_class_maps(self, cls, H, W, device):
        """
        Expand labels [B] into constant one-hot planes [B, num_labels, H, W]:
        for label k, plane k is all ones and every other plane is all zeros.
        """
        labels = cls.view(-1).to(device)                                      # [B]
        one_hot = F.one_hot(labels, num_classes=self.num_labels).float()      # [B, num_labels]
        return one_hot[:, :, None, None].expand(-1, -1, H, W)

    # ---------- forward ----------

    def forward(self, x, t, cls):
        """
        x:   [B, in_channels, H, W]  noised image x_t
        t:   [B]                     timesteps
        cls: [B]                     digit labels

        H and W are halved twice on the way down and doubled twice on the way
        up, so the skip additions require them to be divisible by 4.
        """
        batch, _, height, width = x.shape
        dev = x.device

        # Time embedding, shaped to broadcast over the spatial dimensions
        time_bias = self.embed_time(t, batch, dev)[:, :, None, None]   # [B, base_ch, 1, 1]

        # Class conditioning as extra input channels
        label_planes = self.make_class_maps(cls, height, width, dev)
        feats = torch.cat([x, label_planes], dim=1)

        enc0, enc1, enc2 = self.down_layers
        dec0, dec1, dec2 = self.up_layers

        # --- Encoder: time bias is injected only where channels == base_ch ---
        skip_full = enc0(feats) + time_bias
        skip_half = enc1(self.downscale(skip_full))
        bottom = enc2(self.downscale(skip_half))

        # --- Decoder: upsample, add the matching skip, then convolve ---
        out = dec0(bottom)
        out = dec1(self.upscale(out) + skip_half)
        out = dec2(self.upscale(out) + skip_full)
        return out

# ============================================================
# 5. SAMPLING: REVERSE DIFFUSION LOOP
# ============================================================

@torch.no_grad()
def sample_grid(model, diffusion: Diffusion, step: int, save_path: str, device: torch.device):
    """
    Generate images by running the full reverse diffusion chain and save
    them as a single grid image at save_path.

    If ONLY_TARGET_CLASS:
        sample 16 images of TARGET_CLASS (4x4 grid).
    Else:
        sample one image per label 0..NUM_LABELS-1 (5 images per row).

    Args:
        model:     noise predictor, called as model(x_t, t, labels).
        diffusion: provides T and p_sample_step.
        step:      training step, used only for the progress-bar label and title.
        save_path: output file path for the figure.
        device:    device on which labels and the initial noise are created.

    Note: the model is switched to eval mode and is not switched back here.
    """
    model.eval()

    if ONLY_TARGET_CLASS:
        num_samples = 16  # 4x4 grid, all TARGET_CLASS
        labels = torch.full(
            (num_samples,),
            TARGET_CLASS,
            device=device,
            dtype=torch.long,
        )
        nrow = 4
    else:
        labels = torch.arange(0, NUM_LABELS, device=device, dtype=torch.long)
        num_samples = len(labels)           # one sample per class
        nrow = 5                            # 5 images per row

    # Start from pure Gaussian noise at time T-1. The shape comes from the
    # config constants so it stays in sync with the training data settings.
    x_t = torch.randn(num_samples, CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=device)

    # Go backwards in time: T-1 → ... → 0
    for t_scalar in tqdm(
        range(diffusion.T - 1, -1, -1),
        leave=False,
        desc=f"sampling@{step}"
    ):
        x_t = diffusion.p_sample_step(model, x_t, t_scalar, labels)

    # Map from [-1, 1] back to [0, 1] for visualization
    imgs = (x_t.clamp(-1, 1) + 1) / 2.0

    grid = torchvision.utils.make_grid(imgs, nrow=nrow)
    # Channels-last layout for imshow
    np_img = grid.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(6, 3))
    title = f"samples at step {step}"
    if ONLY_TARGET_CLASS:
        title += f" (class {TARGET_CLASS})"
    plt.title(title)
    plt.imshow(np_img, cmap="gray")
    plt.axis("off")
    plt.savefig(save_path, bbox_inches="tight")
    # Close the figure so repeated calls do not accumulate open figures
    plt.close()


# ============================================================
# 6. CHECKPOINT HELPERS
# ============================================================

def load_checkpoint_if_available(model, path, device):
    """
    Load model weights from a checkpoint file if one is configured and present.

    Args:
        model:  module whose state is restored in place via load_state_dict.
        path:   checkpoint file path, or None to skip loading.
        device: passed to torch.load as map_location.

    Returns:
        The "step" value stored in the checkpoint (0 if the checkpoint has no
        "step" entry), or 0 when nothing was loaded.

    A path that is set but does not point to an existing file is reported
    explicitly, so a typo in the configured path is not mistaken for an
    intentional fresh start.
    """
    if path is None:
        print("No checkpoint loaded; starting from scratch.")
        return 0

    if not os.path.isfile(path):
        print(f"Checkpoint file not found: {path}")
        print("No checkpoint loaded; starting from scratch.")
        return 0

    print(f"Loading checkpoint from {path}")
    # NOTE: torch.load deserializes the file; only load checkpoints you trust.
    data = torch.load(path, map_location=device)
    model.load_state_dict(data["model"])
    step = data.get("step", 0)
    print(f"Resuming from step {step}")
    return step


def save_checkpoint(model, step, ckpt_dir):
    """
    Write the model's state dict and the training step to
    <ckpt_dir>/mnist_cond_step_<step>.pt and print the resulting path.
    """
    filename = f"mnist_cond_step_{step}.pt"
    destination = os.path.join(ckpt_dir, filename)

    # Stored layout: {"step": int, "model": state_dict}
    payload = {"step": step, "model": model.state_dict()}
    torch.save(payload, destination)

    print(f"Saved checkpoint to {destination}")


# ============================================================
# 7. MAIN: "ADD NOISE, THEN PREDICT THE NOISE"
# ============================================================

def main():
    """
    Train the conditional noise predictor on MNIST.

    Each step: draw a batch, pick a random timestep per image, noise the
    images with q_sample, let the network predict the added noise, and
    minimize the MSE between predicted and true noise. Every SAVE_EVERY
    steps a checkpoint and a sample grid are written.

    Only model weights and the step counter are restored from a checkpoint;
    the optimizer is always created fresh.
    """
    # ----- Setup data, diffusion, model, optimizer -----
    train_loader = create_mnist_dataloader(
        batch_size=BATCH_SIZE,
        only_target_class=ONLY_TARGET_CLASS,
        target_class=TARGET_CLASS,
        device=device,
    )

    diffusion = Diffusion(
        T=T,
        beta_start=BETA_START,
        beta_end=BETA_END,
        device=device,
    )

    # Build the network from the config constants so CHANNELS / NUM_LABELS
    # are honoured by the model and not only by the sampler.
    unet = NoisePredictor(
        in_channels=CHANNELS,
        out_channels=CHANNELS,
        base_ch=64,
        num_labels=NUM_LABELS,
    ).to(device)

    # quick shape sanity check
    with torch.no_grad():
        x_test  = torch.randn(5, CHANNELS, IMAGE_SIZE, IMAGE_SIZE).to(device)
        t_test  = torch.randint(0, T, (5,), device=device)
        cls_test = torch.randint(0, NUM_LABELS, (5,), device=device)
        out = unet(x_test, t_test, cls_test)
        print("NoisePredictor output shape:", out.shape)

    loss_fn   = nn.MSELoss()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=LR)

    # ----- Load checkpoint if any -----
    step_start = load_checkpoint_if_available(unet, CHECKPOINT_PATH, device)
    # Start counting from whichever is larger: the configured START_STEP or
    # the step recorded in the checkpoint (0 if none was loaded).
    step = max(START_STEP, step_start)

    loader_iter = iter(train_loader)

    print("Starting training: learn to predict the noise we add to images...")

    # ----- Training loop -----
    while step < TRAIN_STEPS:
        try:
            x0, y = next(loader_iter)
        except StopIteration:
            # Restart the DataLoader if we've reached the end of the dataset
            loader_iter = iter(train_loader)
            x0, y = next(loader_iter)

        # sample_grid leaves the model in eval mode, so re-enable train mode every step
        unet.train()

        x0 = x0.to(device).float()   # [B,C,H,W]
        y  = y.to(device).long()     # [B]

        # === 1. Choose a random timestep for each image ===
        b = x0.shape[0]
        t = diffusion.sample_timesteps(batch_size=b)  # [B]

        # === 2. Add Gaussian noise to create x_t ===
        noise   = diffusion.sample_gaussian_noise(x0)  # ε ~ N(0, I)
        x_noisy = diffusion.q_sample(x0, t, noise)     # q(x_t | x_0)

        # === 3. Let the model predict the noise we just added ===
        pred_noise = unet(x_noisy, t, y)

        # === 4. Train the model so that predicted noise ≈ true noise ===
        loss_val = loss_fn(pred_noise, noise)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        step += 1

        if step % 100 == 0:
            print(f"step {step}/{TRAIN_STEPS}, loss {loss_val.item():.4f}")

        if step % SAVE_EVERY == 0:
            # Save checkpoint
            save_checkpoint(unet, step, ckpt_dir)

            # Save samples from the current model
            samp_path = os.path.join(samp_dir, f"samples_step_{step}.png")
            sample_grid(unet, diffusion, step, samp_path, device)

    print("Training finished.")


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()