"""
Authors: Neal Jean, Michael Xie, Chris Yeh, Anthony Perez
"""

from .base_model import BaseModel
from . import layers
from .init_from_pretrained_utils import init_vggf_from_numpy

import tensorflow as tf


class VGGF(BaseModel):
    '''VGGF Model

    Adapted from what Neal and Michael used in the Science paper
    http://science.sciencemag.org/content/353/6301/790.long
    - add regularization losses to the regularization losses collection, not "losses"
    - add option to use dilated convolution in 1st layer
    - take (224, 224) input images, instead of (400, 400), which requires removing
        the pool7 layer

    Based on original VGG-F model
    https://arxiv.org/abs/1405.3531
    '''
    def __init__(self, inputs, is_training, fc_reg, conv_reg=0.0005, device='/gpu:0', reuse=False,
                 use_dilated_conv_in_first_layer=False):
        '''
        Args:
            inputs: tf.Tensor with shape [batch_size, img_height, img_width, img_depth], dtype tf.float32
            is_training: bool
            fc_reg: float, regularization for weights in the fully-connected layer
            conv_reg: float, regularization for weights in the conv layers
            device: str, '/cpu:i' or '/gpu:i' where i is nonnegative int
            reuse: bool (reuse variable weights or create new weights)
            use_dilated_conv_in_first_layer: bool
        '''
        super(VGGF, self).__init__(
            inputs=inputs,
            is_training=is_training,
            fc_reg=fc_reg,
            conv_reg=conv_reg)

        self.device = device
        # Collection key used by training code to fetch batch-norm-style
        # update ops (standard TF1 convention).
        self.update_ops_collection_key = tf.GraphKeys.UPDATE_OPS

        # All variables live under the 'vggf' scope so pretrained weights can
        # be located by name (see init_from_numpy).
        with tf.variable_scope('vggf'):
            self.set_up_graph(images=inputs, is_training=is_training, reuse=reuse, reg=conv_reg,
                              use_dilation=use_dilated_conv_in_first_layer)

    def init_from_numpy(self, path, sess, hs_weight_init='random'):
        '''Initializes model weights from a NumPy file of pretrained VGGF weights.

        Args:
            path: str, path to saved weights
            sess: tf.Session
            hs_weight_init: str, one of ['random', 'same']
        '''
        init_vggf_from_numpy(path, sess, hs_weight_init=hs_weight_init)

    def set_up_graph(self, images, is_training, reuse, reg, use_dilation):
        """
        Builds computation graph.

        Layer outputs are stored as attributes (self.conv1 ... self.conv7,
        self.features_layer) so callers can tap intermediate activations.

        :param images: Tensor
        :param is_training: bool, whether weights are trainable and to use dropout
        :param reuse: bool (reuse variable weights or create new weights)
        :param reg: float, regularization for weights in the conv layers
        :param use_dilation: bool, whether to use dilated convolution in first layer
        """
        with tf.device(self.device):
            self.images = images
            if use_dilation:
                # In shape: [B, 224, 224, 9].  Out shape: [B, 224, 224, 64], [B, 54, 54, 64]
                # (shapes per original comment — TODO confirm against layers.dilated_conv)
                self.conv1 = layers.dilated_conv(
                    self.images, 11, 64, 4, stride_pad="VALID",
                    bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                    reuse=reuse, trainable=is_training, name='conv1')
            else:
                # In shape: [B, 224, 224, C].  Out shape: [B, 54, 54, 64]
                self.conv1 = layers.conv2d(
                    self.images, 11, 64, 4, pad='VALID',
                    bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                    reuse=reuse, trainable=is_training, name='conv1')
            self.norm1 = layers.lrn(
                self.conv1, local_size=5, alpha=0.0005, beta=0.75, bias=2.0,
                name='norm1')
            self.pool1 = layers.max_pool(
                self.norm1, 3, 2, pad='SAME', name='pool1')
            self.conv2 = layers.conv2d(
                self.pool1, 5, 256, 1, pad='SAME',
                bias_init=tf.constant_initializer(value=0.1),
                l2loss=reg, reuse=reuse, trainable=is_training,
                name='conv2')
            self.norm2 = layers.lrn(
                self.conv2, local_size=5, alpha=0.0005, beta=0.75, bias=2.0,
                name='norm2')
            self.pool2 = layers.max_pool(
                self.norm2, 3, 2, pad='VALID', name='pool2')
            self.conv3 = layers.conv2d(
                self.pool2, 3, 256, 1, pad='SAME',
                bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                reuse=reuse, trainable=is_training, name='conv3')
            self.conv4 = layers.conv2d(
                self.conv3, 3, 256, 1, pad='SAME',
                bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                reuse=reuse, trainable=is_training, name='conv4')
            self.conv5 = layers.conv2d(
                self.conv4, 3, 256, 1, pad='SAME',
                bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                reuse=reuse, trainable=is_training, name='conv5')
            self.pool5 = layers.max_pool(
                self.conv5, 3, 2, pad='SAME', name='pool5')
            self.conv6 = layers.conv2d(
                self.pool5, 6, 4096, 6, pad='VALID',
                bias_init=tf.constant_initializer(value=0.1), l2loss=reg,
                reuse=reuse, trainable=is_training, name='conv6')

            # after conv6:
            # - if images are (400, 400), then self.conv6 is (2, 2, 4096)
            # - if images are (224, 224), then self.conv6 is (1, 1, 4096)

            # Dropout is applied only at training time; conv7 itself is the
            # same layer either way, so build it once from the selected input.
            if is_training:
                self.dropout6 = layers.dropout(
                    self.conv6, prob=0.5, name='dropout6')
                conv7_input = self.dropout6
            else:
                conv7_input = self.conv6
            self.conv7 = layers.conv2d(
                conv7_input, 1, 4096, 1, pad='VALID',
                bias_init=tf.constant_initializer(value=0.1),
                l2loss=reg, reuse=reuse, trainable=is_training,
                name='conv7')

            # Extract features: remove the singleton spatial dimensions of
            # conv7 so features have shape [B, 4096].
            # NOTE: `axis` is the supported kwarg; `squeeze_dims` is a
            # deprecated alias in TF1.
            self.features_layer = tf.squeeze(
                self.conv7, axis=[1, 2], name='features')

            # NOTE: the original (400, 400)-input model had an average-pool
            # 'pool7' layer here; it was removed for (224, 224) inputs (see
            # class docstring).
