###################################################
#   Solution code for EE367 HW6, task 1
#
#   Instructions: 
#       You don't need to change anything here, please 
#       edit the files deconv_hqs_tv.py and deconv_hqs_dncnn.py
#
#   Gordon Wetzstein, 10/2021
###################################################

# import packages
import numpy as np
from numpy.fft import fft2, ifft2

import skimage.io as io
from skimage.metrics import peak_signal_noise_ratio 
from skimage.color import rgb2gray
from skimage.filters import gaussian

from pypher.pypher import psf2otf
from pathlib import Path
from tqdm import tqdm
from itertools import product
import matplotlib.pyplot as plt

# import our deconvolution solvers (Adam+TV, HQS+TV, and HQS+DnCNN)
from deconv_adam_tv import *
from deconv_hqs_tv import *
from deconv_hqs_dncnn import *

# helper function for computing a 2D Gaussian convolution kernel
def fspecial_gaussian_2d(size, sigma):
    """Build a normalized 2D Gaussian convolution kernel.

    A unit impulse is placed at the center of an array of shape ``size`` and
    smoothed with skimage's ``gaussian`` filter. The result is then rescaled
    so that its entries sum to one.

    Args:
        size: (rows, cols) shape of the kernel.
        sigma: standard deviation of the Gaussian, in pixels.

    Returns:
        A float array of shape ``size`` whose entries sum to 1.
    """
    center_row, center_col = size[0] // 2, size[1] // 2
    impulse = np.zeros(tuple(size))
    impulse[center_row, center_col] = 1
    blurred = gaussian(impulse, sigma)
    # normalize so that convolving with the kernel preserves total intensity
    return blurred / blurred.sum()


# select target image and load it
name = 'birds'
img_raw = io.imread(f'{name}.png')
# scale integer images by their dtype's maximum (255 for 8-bit, 65535 for
# 16-bit PNGs) so intensities land in [0,1]; non-integer input is used as-is
if np.issubdtype(img_raw.dtype, np.integer):
    img = img_raw.astype(float) / np.iinfo(img_raw.dtype).max
else:
    img = img_raw.astype(float)
# the rest of the script works on exactly 3 color channels: replicate a
# grayscale image across channels and drop an alpha channel if present
if img.ndim == 2:
    img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
img = img[:, :, :3]

# create blur kernel: 30x30 pixels, Gaussian with standard deviation 2.5 pixels
c = fspecial_gaussian_2d((30, 30), 2.5)

# compute otf of blur kernel, sized to match one image channel
cFT = psf2otf(c, (img.shape[0], img.shape[1]))


# this is our forward image formation model: the blur is applied to a single
# channel as a pointwise product with the OTF in the Fourier domain
# (i.e. a circular convolution)
def Afun(x):
    return np.real(ifft2(fft2(x) * cFT))


# standard deviation of sensor noise
sigma = 0.1

# fixed seed so the simulated noise, and therefore the reported PSNRs, are
# reproducible from run to run; change it to draw a different noise realization
noise_seed = 0
rng = np.random.default_rng(noise_seed)

# simulated measurements: blur each color channel, then add Gaussian noise
b = np.zeros(np.shape(img))
for ch in range(img.shape[2]):
    b[:, :, ch] = Afun(img[:, :, ch]) + sigma * rng.standard_normal(img.shape[:2])
    
# TV flavor used by the TV-regularized solvers (Adam+TV and HQS+TV):
# True selects the anisotropic TV regularizer, False the isotropic one
b_anisotropic = True

# number of iterations for all 3 solvers (Adam, HQS+TV, HQS+DnCNN)
num_iters = 75

# Adam solver parameters
lam = 0.05              # relative weight of TV term
learning_rate = 5e-2    # learning rate

# run PyTorch-based Adam solver independently on each color channel with the TV regularizer
x_adam_tv = np.zeros(np.shape(b))
for ch in range(b.shape[2]):
    x_adam_tv[:, :, ch] = deconv_adam_tv(b[:, :, ch], c, lam, num_iters, learning_rate, b_anisotropic)
# clip results to make sure they are within the range [0,1]
x_adam_tv = np.clip(x_adam_tv, 0.0, 1.0)
# compute PSNR using skimage (intensities live in [0,1]) and round it to 1 decimal place
PSNR_ADAM_TV = round(peak_signal_noise_ratio(img, x_adam_tv, data_range=1.0), 1)



# HQS+TV solver parameters
rho         = 5.0       # rho parameter of HQS
lam         = 0.025     # relative weight of TV term

# run HQS+TV solver independently on each color channel
x_hqs_tv = np.zeros(np.shape(b))
for ch in range(b.shape[2]):
    x_hqs_tv[:, :, ch] = deconv_hqs_tv(b[:, :, ch], c, lam, rho, num_iters, b_anisotropic)
# clip results to make sure they are within the range [0,1]
x_hqs_tv = np.clip(x_hqs_tv, 0.0, 1.0)
# compute PSNR using skimage (intensities live in [0,1]) and round it to 1 decimal place
PSNR_HQS_TV = round(peak_signal_noise_ratio(img, x_hqs_tv, data_range=1.0), 1)


# HQS+DnCNN solver parameters (these are different from the TV solver)
lam         = 0.01 * 0.5    # relative weight of the regularizer (the DnCNN prior here, not TV)
rho         = 1 * 0.5       # rho parameter of HQS

# run HQS+DnCNN solver independently on each color channel
x_hqs_dncnn = np.zeros(np.shape(b))
for ch in range(b.shape[2]):
    x_hqs_dncnn[:, :, ch] = deconv_hqs_dncnn(b[:, :, ch], c, lam, rho, num_iters)
# clip results to make sure they are within the range [0,1]
x_hqs_dncnn = np.clip(x_hqs_dncnn, 0.0, 1.0)
# compute PSNR using skimage (intensities live in [0,1]) and round it to 1 decimal place
PSNR_HQS_DNCNN = round(peak_signal_noise_ratio(img, x_hqs_dncnn, data_range=1.0), 1)



# show results in a 2x3 grid: inputs on the top row (third slot left empty),
# reconstructions with their PSNR on the bottom row
panels = [
    (1, img, "Target Image"),
    # the noisy measurements fall outside [0,1]; clip a display copy so
    # matplotlib does not clip it itself (and warn) when showing float RGB data
    (2, np.clip(b, 0.0, 1.0), "Blurry and Noisy Image"),
    (4, x_adam_tv, "Adam+TV, PSNR: " + str(PSNR_ADAM_TV)),
    (5, x_hqs_tv, "HQS+TV, PSNR: " + str(PSNR_HQS_TV)),
    (6, x_hqs_dncnn, "HQS+DnCNN, PSNR: " + str(PSNR_HQS_DNCNN)),
]

fig = plt.figure()
for position, image, title in panels:
    ax = fig.add_subplot(2, 3, position)
    ax.imshow(image)
    ax.set_title(title, fontsize=10)
    # hide axis ticks and labels; these panels are images, not plots
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

plt.show()

