###################################################
#   Solution code for EE367 HW6, task 2
#
#   Gordon Wetzstein, 10/2021
###################################################

# import packages
import numpy as np
from tqdm import tqdm
from itertools import product
from numpy.fft import fft2, ifft2
from pypher.pypher import psf2otf

import torch
# Run DnCNN inference on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# DnCNN model from Kai Zhang
from network_dncnn import DnCNN as net

def deconv_hqs_dncnn(b, c, lam, rho, num_iters, denoiser_model_filename='../release/dncnn_25.pth'):
    """Deconvolve an image with half-quadratic splitting (HQS) and a DnCNN denoiser prior.

    Approximately solves  min_x 1/2 ||c * x - b||^2 + lam * Psi(x)  by splitting
    x = z and alternating:

        x <- argmin_x 1/2 ||c * x - b||^2 + rho/2 ||x - z||^2   (closed form, Fourier domain)
        z <- DnCNN(x)                                           (pre-trained denoiser)

    Args:
        b: Observed (blurry, noisy) image. Must be 2D, because each iterate is
            reshaped to (1, 1, H, W) before it is passed to the network.
        c: Blur kernel (PSF), converted to an OTF of size b.shape via psf2otf.
        lam: Prior weight. Not used here. The DnCNN built below takes only the
            image as input (in_nc=1), so a noise level derived from lam / rho
            cannot be passed to it. The parameter is kept for interface
            compatibility.
        rho: HQS penalty weight coupling x and z. Larger values pull x harder
            toward the denoised estimate z.
        num_iters: Number of HQS iterations.
        denoiser_model_filename: Path to the DnCNN state dict to load.

    Returns:
        The final x iterate, a real array of shape b.shape. If num_iters is 0,
        returns np.zeros_like(b).
    """
    # Blur kernel as an optical transfer function (Fourier-domain C) and its
    # conjugate (Fourier-domain C^T).
    cFT = psf2otf(c, b.shape)
    cTFT = np.conj(cFT)

    # Fourier transform of b
    bFT = fft2(b)

    # HQS only needs x and z (no scaled dual variable u as in ADMM); start from zeros.
    x = np.zeros_like(b)
    z = np.zeros_like(b)

    # load pre-trained DnCNN model
    model = net(in_nc=1, out_nc=1, nc=64, nb=17, act_mode='R')
    # map_location allows a checkpoint saved on a GPU to be loaded when `device`
    # falls back to the CPU.
    model.load_state_dict(torch.load(denoiser_model_filename, map_location=device), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # The x-update solves the normal equations (C^T C + rho I) x = C^T b + rho z.
    # C^T b and the denominator |cFT|^2 + rho do not change across iterations
    # (rho is fixed), so compute them once.
    cTbFT = cTFT * bFT
    denom = np.abs(cFT) ** 2 + rho

    for _ in tqdm(range(num_iters)):

        # x update - inverse filtering: Fourier multiplications and divisions.
        # For real b and c the exact solution is real, so the imaginary part is
        # FFT round-off and is discarded.
        x = np.real(ifft2((cTbFT + rho * fft2(z)) / denom))

        # z update - run DnCNN denoiser on the current x
        x_tensor = torch.as_tensor(x, dtype=torch.float32, device=device)
        x_tensor = x_tensor.reshape(1, 1, x.shape[0], x.shape[1])
        with torch.no_grad():
            x_tensor_denoised = model(x_tensor)

        # Index out batch/channel explicitly; squeeze() would also drop a
        # spatial axis of size 1.
        z = x_tensor_denoised[0, 0].cpu().numpy()

    return x


