# visualize_denoising.py
# ----------------------
# Visualizes the reverse diffusion (denoising) process using your trained model.
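# Results: mnist_diffusion/vis/denoising_process_class_{TARGET_CLASS}_{i}.png for i = 0..19
# (one independently sampled strip per file; TARGET_CLASS is 8 by default).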

import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm


# ============================================================
# CONFIG
# ============================================================

def _pick_device() -> torch.device:
    """Return CUDA if available, else Apple MPS if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Device selection
device = _pick_device()
print(f"[denoising] Using device: {device}")

# Sample geometry: square images of IMAGE_SIZE x IMAGE_SIZE with CHANNELS channels.
IMAGE_SIZE = 28
CHANNELS = 1

# Linear beta schedule: T steps from BETA_START to BETA_END.
T = 300
BETA_START, BETA_END = 1e-4, 0.02

# Conditioning: main() samples TARGET_CLASS out of NUM_LABELS classes.
# NOTE(review): ONLY_TARGET_CLASS is not read anywhere in this script.
TARGET_CLASS = 8
ONLY_TARGET_CLASS = True
NUM_LABELS = 10

# --- Checkpoint config ---
# Adjust START_STEP and CHECKPOINT_PATH if your checkpoint is different
START_STEP = 200000
CHECKPOINT_PATH = f"mnist_diffusion/checkpoint_{TARGET_CLASS}/mnist_cond_step_{START_STEP}.pt"

# Output dir (created eagerly so later saves cannot fail on a missing folder)
vis_dir = "mnist_diffusion/vis"
os.makedirs(vis_dir, exist_ok=True)


# ============================================================
# DIFFUSION OBJECT: schedule + q(x_t | x_0) + p_θ(x_{t-1} | x_t)
# ============================================================

class Diffusion:
    """Linear-beta DDPM schedule with forward noising and single reverse steps.

    Attributes:
        T: number of diffusion steps.
        device: device on which the schedule tensors live.
        betas: float32 tensor ``[T]``, linearly spaced from ``beta_start`` to ``beta_end``.
        alphas: ``1 - betas``.
        alpha_bars: running product of ``alphas``.
    """

    def __init__(self, T: int, beta_start: float, beta_end: float, device: torch.device):
        self.T = T
        self.device = device

        schedule = torch.linspace(beta_start, beta_end, T, dtype=torch.float32, device=device)
        self.betas = schedule
        self.alphas = 1.0 - schedule
        self.alpha_bars = self.alphas.cumprod(dim=0)

    @staticmethod
    def _per_sample(table: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Index *table* at timesteps *t* and reshape to ``[B, 1, 1, 1]`` for broadcasting."""
        return table[t].view(-1, 1, 1, 1)

    def sample_timesteps(self, batch_size: int) -> torch.Tensor:
        """Draw ``batch_size`` uniform integer timesteps in ``[0, T)``."""
        return torch.randint(0, self.T, (batch_size,), device=self.device)

    def sample_gaussian_noise(self, x: torch.Tensor) -> torch.Tensor:
        """Standard normal noise with the same shape/dtype/device as *x*."""
        return torch.randn_like(x)

    def get_alpha_bar(self, t: torch.Tensor) -> torch.Tensor:
        """``alpha_bar[t]`` shaped ``[B, 1, 1, 1]``."""
        return self._per_sample(self.alpha_bars, t)

    def get_diffusion_coefficients(self, t: torch.Tensor):
        """Return ``(alpha_t, beta_t, alpha_bar_t)``, each shaped ``[B, 1, 1, 1]``."""
        return (
            self._per_sample(self.alphas, t),
            self._per_sample(self.betas, t),
            self._per_sample(self.alpha_bars, t),
        )

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Forward process: ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``."""
        a_bar = self.get_alpha_bar(t)
        signal_scale = a_bar.sqrt()
        noise_scale = (1.0 - a_bar).sqrt()
        return signal_scale * x0 + noise_scale * noise

    def ddpm_p_mean_variance(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor
    ):
        """Mean and variance of ``p(x_{t-1} | x_t)`` using the model's noise prediction.

        mean = (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat),
        variance = beta_t.
        """
        eps_hat = model(x_t, t, y)
        alpha_t, beta_t, alpha_bar_t = self.get_diffusion_coefficients(t)

        inv_sqrt_alpha = 1.0 / alpha_t.sqrt()
        eps_scale = beta_t / (1.0 - alpha_bar_t).sqrt()

        mean = inv_sqrt_alpha * (x_t - eps_scale * eps_hat)
        return mean, beta_t

    @torch.no_grad()
    def p_sample_step(
        self,
        model: nn.Module,
        x_t: torch.Tensor,
        t_scalar: int,
        y: torch.Tensor
    ) -> torch.Tensor:
        """One ancestral sampling step from timestep ``t_scalar``; no noise is added at t <= 0."""
        t = torch.full(
            (x_t.shape[0],),
            t_scalar,
            device=x_t.device,
            dtype=torch.long,
        )
        mean, variance = self.ddpm_p_mean_variance(model, x_t, t, y)

        if t_scalar <= 0:
            return mean
        return mean + variance.sqrt() * torch.randn_like(x_t)


# ============================================================
# TIME EMBEDDING (SINUSOIDAL)
# ============================================================

def sinusoidal_embedding(t, dim):
    """Return sinusoidal embeddings of timesteps *t* with width *dim*.

    With ``half = dim // 2`` and frequencies ``f_k = 10000 ** (-k / half)``
    for k = 0 .. half-1, the row for timestep ``t_i`` is
    ``[sin(t_i * f_0), ..., sin(t_i * f_{half-1}), cos(t_i * f_0), ..., cos(t_i * f_{half-1})]``,
    followed by a single zero column when *dim* is odd.

    Args:
        t: 1-D tensor of timesteps ``[B]``; a 0-d scalar tensor is treated as ``[1]``.
        dim: embedding width.

    Returns:
        float32 tensor of shape ``[B, dim]`` on ``t.device``.
    """
    if t.dim() == 0:
        # Allow a bare scalar timestep; ``[:, None]`` below needs a batch axis.
        t = t.view(1)

    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=t.device).float() / half
    )

    args = t.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    if dim % 2 == 1:
        # Pad with an explicit [B, 1] zero column. Slicing emb[:, :1] would
        # yield zero columns when dim == 1 (emb is then [B, 0]).
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)

    return emb


# ============================================================
# Neural Net MODEL: PREDICTS THE NOISE
# ============================================================

class NoisePredictor(nn.Module):
    """Class-conditional convolutional encoder/decoder that predicts diffusion noise.

    The class label is one-hot encoded and broadcast to ``num_labels`` extra
    input channels. The timestep embedding (sinusoidal + small MLP) is added
    to the output of the first encoder block. The encoder halves the spatial
    size twice with max-pooling. The decoder doubles it twice with
    nearest-neighbour upsampling and adds (not concatenates) the matching
    encoder features. All convolutions are 5x5 with padding 2, so spatial size
    is preserved inside each block.
    """

    def __init__(self, in_channels=1, out_channels=1, base_ch=64, num_labels=10):
        super().__init__()

        self.base_ch = base_ch
        self.num_labels = num_labels

        wide = base_ch * 2

        # NOTE: attribute names, Sequential indices and construction order
        # determine state_dict keys (and init RNG order); saved checkpoints
        # depend on them, so keep them stable.
        self.down_layers = nn.ModuleList([
            self._conv_norm_act(in_channels + num_labels, base_ch),
            self._conv_norm_act(base_ch, wide),
            self._conv_norm_act(wide, wide),
        ])

        self.up_layers = nn.ModuleList([
            self._conv_norm_act(wide, wide),
            self._conv_norm_act(wide, base_ch),
            nn.Sequential(
                nn.Conv2d(base_ch, out_channels, kernel_size=5, padding=2),
            ),
        ])

        self.time_embedding_layer = nn.Sequential(
            nn.Linear(base_ch, base_ch),
            nn.LayerNorm(base_ch),
            nn.SiLU(),
            nn.Linear(base_ch, base_ch),
            nn.LayerNorm(base_ch),
        )

        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2, mode="nearest")

    @staticmethod
    def _conv_norm_act(c_in, c_out):
        """5x5 conv (size-preserving) -> single-group GroupNorm -> SiLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
            nn.GroupNorm(1, c_out),
            nn.SiLU(),
        )

    def embed_time(self, t, batch_size, device):
        """Map timesteps (reshaped to ``[batch_size]``) to ``[batch_size, base_ch]`` features."""
        flat_t = t.view(batch_size).to(device)
        return self.time_embedding_layer(sinusoidal_embedding(flat_t, dim=self.base_ch))

    def make_class_maps(self, cls, H, W, device):
        """One-hot encode labels and broadcast them to ``[B, num_labels, H, W]`` (expanded view)."""
        labels = cls.view(-1).to(device)
        one_hot = F.one_hot(labels, num_classes=self.num_labels).float()
        return one_hot.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)

    def forward(self, x, t, cls):
        """Predict the noise contained in *x* at timesteps *t* for class labels *cls*.

        Args:
            x: noisy images ``[B, in_channels, H, W]``; H and W must be divisible
                by 4 so the upsampled features line up with the skip tensors.
            t: timesteps, reshaped to ``[B]``.
            cls: integer class labels, flattened to ``[B]``.

        Returns:
            Tensor of shape ``[B, out_channels, H, W]``.
        """
        batch, _, height, width = x.shape
        dev = x.device

        time_vec = self.embed_time(t, batch, dev)[:, :, None, None]  # [B, base_ch, 1, 1]
        cond = self.make_class_maps(cls, height, width, dev)
        feats = torch.cat([x, cond], dim=1)

        enc_top, enc_mid, enc_low = self.down_layers
        dec_low, dec_mid, dec_out = self.up_layers

        # Encoder: full resolution -> 1/2 -> 1/4.
        skip_full = enc_top(feats) + time_vec
        skip_half = enc_mid(self.downscale(skip_full))
        bottleneck = enc_low(self.downscale(skip_half))

        # Decoder mirrors the encoder; skip features are added element-wise.
        out = dec_low(bottleneck)
        out = dec_mid(self.upscale(out) + skip_half)
        return dec_out(self.upscale(out) + skip_full)


# ============================================================
# CHECKPOINT HELPERS
# ============================================================

def load_checkpoint_if_available(model, path, device):
    """Load ``checkpoint["model"]`` into *model* when *path* names an existing file.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        path: checkpoint file path, or ``None``.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The checkpoint's ``"step"`` entry (0 if absent), or 0 when nothing was loaded.
    """
    if path is None or not os.path.isfile(path):
        print("[denoising] No checkpoint found, please check CHECKPOINT_PATH.")
        return 0

    print(f"[denoising] Loading checkpoint from {path}")
    # torch.load unpickles the file: only point this at trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model"])

    saved_step = checkpoint.get("step", 0)
    print(f"[denoising] Resuming from step {saved_step}")
    return saved_step


# ============================================================
# VISUALIZATION: REVERSE DENOISING
# ============================================================

@torch.no_grad()
def visualize_reverse_denoising(
    model: nn.Module,
    diffusion: Diffusion,
    labels: torch.Tensor,
    num_snapshots: int,
    save_path: str,
    device: torch.device,
    *,
    image_size: int = 28,
    channels: int = 1,
):
    """Run the full reverse diffusion chain for one sample and save a snapshot strip.

    Starting from Gaussian noise, applies ``diffusion.p_sample_step`` for
    t = T-1, ..., 0. The state produced by the step at each of ``num_snapshots``
    evenly spaced timesteps is kept; t = 0 (the final sample) is always kept.
    Snapshots are laid out left (noisiest) to right (final), mapped from
    [-1, 1] to [0, 1], and saved to *save_path*.

    Args:
        model: noise predictor called as ``model(x_t, t, y)``.
        diffusion: provides ``T`` and ``p_sample_step``.
        labels: the class label to condition on; must contain exactly one element.
        num_snapshots: number of evenly spaced timesteps to capture.
        save_path: output image path.
        device: device to sample on.
        image_size: height/width of the sampled image (default 28).
        channels: channel count of the sampled image (default 1).
    """
    model.eval()
    labels = labels.view(1).to(device)
    # Title from the label actually sampled, not a module-level constant.
    class_id = int(labels.item())

    # Start from pure Gaussian noise
    x_t = torch.randn(1, channels, image_size, image_size, device=device)

    # Timesteps whose step output gets recorded; t = 0 is always included.
    snapshot_ts = set(torch.linspace(diffusion.T - 1, 0, num_snapshots).long().tolist())
    snapshot_ts.add(0)

    snapshots = []
    print("[denoising] Running reverse diffusion...")
    for t_scalar in tqdm(
        range(diffusion.T - 1, -1, -1),
        desc="reverse diffusion",
        leave=False
    ):
        x_t = diffusion.p_sample_step(model, x_t, t_scalar, labels)
        if t_scalar in snapshot_ts:
            # This is the output of the step at t_scalar (one step less noisy).
            snapshots.append(x_t.clone())

    # The loop runs from high t to low t, so snapshots are already ordered
    # noisiest -> cleanest.
    imgs = torch.cat(snapshots, dim=0)  # [K, C, H, W]

    # [-1,1] -> [0,1]
    imgs = (imgs.clamp(-1, 1) + 1) / 2.0

    nrow = len(snapshots)
    grid = torchvision.utils.make_grid(imgs, nrow=nrow, padding=2)
    np_img = grid.permute(1, 2, 0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(1.5 * nrow, 2))
    try:
        ax.set_title(f"Reverse denoising (class {class_id})")
        ax.imshow(np_img.squeeze(), cmap="gray")
        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails (this is called in a loop).
        plt.close(fig)
    print(f"[denoising] Saved visualization to: {save_path}")


# ============================================================
# MAIN
# ============================================================

def main(num_samples: int = 20, num_snapshots: int = 8):
    """Build the model, load the checkpoint, and save reverse-denoising strips.

    Args:
        num_samples: number of independent strips to generate. File ``i`` is
            ``vis_dir/denoising_process_class_{TARGET_CLASS}_{i}.png``.
        num_snapshots: snapshots per strip, forwarded to
            ``visualize_reverse_denoising``.
    """
    # 1. Build diffusion + model and load checkpoint
    diffusion = Diffusion(
        T=T,
        beta_start=BETA_START,
        beta_end=BETA_END,
        device=device,
    )

    unet = NoisePredictor(
        in_channels=CHANNELS,
        out_channels=CHANNELS,
        base_ch=64,  # must match the width the checkpoint was trained with
        num_labels=NUM_LABELS,
    ).to(device)

    load_checkpoint_if_available(unet, CHECKPOINT_PATH, device)

    # 2. Condition every sample on TARGET_CLASS
    label = torch.tensor([TARGET_CLASS], dtype=torch.long, device=device)

    # 3. Visualize the reverse process; each call draws fresh starting noise,
    #    so every file is an independent sample.
    for i in range(num_samples):
        save_path = os.path.join(vis_dir, f"denoising_process_class_{TARGET_CLASS}_{i}.png")
        visualize_reverse_denoising(unet, diffusion, label, num_snapshots, save_path, device)


if __name__ == "__main__":
    main()