# visualize_noising.py
# --------------------
# Visualizes the forward diffusion process q(x_t | x_0) on MNIST.
# Result: mnist_diffusion/vis/noising_process_class_8.png

import os
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torchvision


# ============================================================
# CONFIG
# ============================================================

# Backend preference order: CUDA first, then Apple MPS (only probed when this
# torch build exposes `torch.backends.mps`), and CPU as the fallback.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    else "cpu"
)

print(f"[noising] Using device: {device}")

# Image geometry (MNIST digits are 28x28, single channel).
IMAGE_SIZE = 28
CHANNELS = 1

# Diffusion schedule hyperparameters.
T = 300            # number of diffusion steps
BETA_START = 1e-4  # first value of the beta schedule
BETA_END = 0.02    # last value of the beta schedule

# Which digit class to visualize, and whether to restrict the data to it.
TARGET_CLASS = 8
ONLY_TARGET_CLASS = True

# Output directory for the rendered figure; created at import time if missing.
vis_dir = "mnist_diffusion/vis"
os.makedirs(vis_dir, exist_ok=True)


# ============================================================
# DATA: MNIST (only used to grab one x0 example)
# ============================================================

def create_mnist_dataloader(
    batch_size: int,
    only_target_class: bool,
    target_class: int,
) -> DataLoader:
    """Build a shuffled DataLoader over the MNIST training split.

    Images go through ``ToTensor`` followed by ``Normalize((0.5,), (0.5,))``,
    i.e. ``(x - 0.5) / 0.5``, which puts pixel values in roughly [-1, 1].
    The dataset is downloaded to ``./mnist_diffusion/data`` if not present.

    Args:
        batch_size: Number of images per batch. Incomplete trailing batches
            are dropped (``drop_last=True``).
        only_target_class: If True, keep only images labelled ``target_class``.
        target_class: Digit label to keep when ``only_target_class`` is True.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # => roughly [-1, 1]
    ])

    full_dataset = datasets.MNIST(
        root="./mnist_diffusion/data",
        train=True,
        download=True,
        transform=transform,
    )

    if only_target_class:
        # Filter on the label tensor directly. Enumerating the dataset would
        # decode and transform every image just to read its label.
        labels = torch.as_tensor(full_dataset.targets)
        indices = torch.nonzero(labels == target_class).flatten().tolist()
        train_dataset = Subset(full_dataset, indices)
    else:
        train_dataset = full_dataset

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    print("[noising] Number of training images:", len(train_dataset))
    return train_loader


# ============================================================
# DIFFUSION OBJECT: schedule + q(x_t | x_0)
# ============================================================

class Diffusion:
    """Linear beta schedule plus closed-form sampling of q(x_t | x_0)."""

    def __init__(self, T: int, beta_start: float, beta_end: float, device: torch.device):
        """Precompute the schedule tensors on ``device``.

        Args:
            T: Number of diffusion steps (length of every schedule tensor).
            beta_start: First value of the beta schedule.
            beta_end: Last value of the beta schedule.
            device: Device on which the schedule tensors are allocated.
        """
        self.T = T
        self.device = device

        # beta_t: T evenly spaced float32 values from beta_start to beta_end.
        schedule = torch.linspace(
            beta_start, beta_end, T, dtype=torch.float32, device=device
        )
        self.betas = schedule
        # alpha_t = 1 - beta_t (stays float32).
        self.alphas = 1.0 - schedule
        # alpha_bar_t = running product of alpha_0 .. alpha_t.
        self.alpha_bars = self.alphas.cumprod(dim=0)

    def sample_gaussian_noise(self, x: torch.Tensor) -> torch.Tensor:
        """Return standard-normal noise with the same shape, dtype and device as ``x``."""
        return torch.randn_like(x)

    def get_alpha_bar(self, t: torch.Tensor) -> torch.Tensor:
        """Look up alpha_bar at indices ``t``, reshaped to (-1, 1, 1, 1) for broadcasting."""
        selected = self.alpha_bars[t]
        return selected.view(-1, 1, 1, 1)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Compute x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x0: Clean input tensor.
            t: Timestep indices into the schedule.
            noise: Noise tensor combined with ``x0``.

        Returns:
            The noised tensor x_t.
        """
        abar = self.get_alpha_bar(t)
        signal_scale = abar.sqrt()
        noise_scale = (1.0 - abar).sqrt()
        return signal_scale * x0 + noise_scale * noise


# ============================================================
# VISUALIZATION: FORWARD NOISING
# ============================================================

@torch.no_grad()
def visualize_forward_noising(
    diffusion: Diffusion,
    x0: torch.Tensor,
    num_steps_to_show: int,
    save_path: str,
):
    """
    Render x_t at a series of evenly spaced timesteps on a single row and save it.

    The same noise sample is reused for every timestep, so the row shows one
    consistent trajectory with only the schedule coefficients changing.

    Args:
        diffusion: Schedule object providing ``T``, ``sample_gaussian_noise``
            and ``q_sample``.
        x0: [1, 1, 28, 28] single image, normalized to [-1, 1].
        num_steps_to_show: Number of timesteps (grid columns) to render; must
            be at least 1.
        save_path: Destination image file. Its parent directory is created if
            it does not exist.

    Raises:
        ValueError: If ``num_steps_to_show`` is less than 1.
    """
    # Fail early with a clear message instead of an opaque error further down.
    if num_steps_to_show < 1:
        raise ValueError(
            f"num_steps_to_show must be >= 1, got {num_steps_to_show}"
        )

    device = x0.device

    # Select timesteps to visualize: evenly spaced from 0 to T-1, truncated to ints.
    timesteps = torch.linspace(0, diffusion.T - 1, num_steps_to_show).long().tolist()
    print(f"[noising] Visualizing timesteps: {timesteps}")

    # Use a single ε so that the trajectory is consistent as t increases
    noise = diffusion.sample_gaussian_noise(x0)

    # One noised copy of x0 per selected timestep, concatenated along dim 0.
    xt = torch.cat(
        [
            diffusion.q_sample(
                x0,
                torch.full((x0.shape[0],), t_scalar, device=device, dtype=torch.long),
                noise,
            )
            for t_scalar in timesteps
        ],
        dim=0,
    )

    # Map from [-1, 1] -> [0, 1]; values outside [-1, 1] are clipped first.
    imgs = (xt.clamp(-1, 1) + 1) / 2.0

    grid = torchvision.utils.make_grid(imgs, nrow=num_steps_to_show, padding=2)
    # Channels-last layout on the CPU for matplotlib.
    np_img = grid.permute(1, 2, 0).cpu().numpy()

    # Make sure the destination directory exists, so callers may pass any path.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(1.5 * num_steps_to_show, 2))
    try:
        plt.title(f"Forward noising (class {TARGET_CLASS})")
        plt.imshow(np_img.squeeze(), cmap="gray")
        plt.axis("off")
        plt.savefig(save_path, bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    print(f"[noising] Saved visualization to: {save_path}")


# ============================================================
# MAIN
# ============================================================

def main():
    """Pick one MNIST example, build the noise schedule, and save the noising figure."""
    # Step 1: fetch a single batch and keep only its first image as x0.
    loader = create_mnist_dataloader(
        target_class=TARGET_CLASS,
        only_target_class=ONLY_TARGET_CLASS,
        batch_size=16,
    )
    images, labels = next(iter(loader))
    x0 = images[:1].to(device)  # slicing keeps the batch dimension
    print(f"[noising] Using example with label: {labels[0].item()}")

    # Step 2: build the schedule from the module-level hyperparameters.
    diffusion = Diffusion(T, BETA_START, BETA_END, device)

    # Step 3: render the trajectory and write it under the output directory.
    steps_to_show = 8
    out_file = os.path.join(vis_dir, f"noising_process_class_{TARGET_CLASS}.png")
    visualize_forward_noising(
        diffusion=diffusion,
        x0=x0,
        num_steps_to_show=steps_to_show,
        save_path=out_file,
    )


if __name__ == "__main__":
    main()