"""
Authors: Michael Xie, Neal Jean, Anthony Perez
"""

import tensorflow as tf
import math
import numpy as np

def lrn(input_layer, local_size, alpha, beta, bias=1.0, name='lrn'):
    """Local Response Normalization using Caffe-style hyper-parameters.

    TensorFlow's ``local_response_normalization`` takes a half-width
    (``depth_radius``) and a per-element alpha, whereas this wrapper accepts
    the full window size and an alpha that is divided by that window size
    before being handed to TensorFlow.

    :param input_layer: Tensor
    :param local_size: int, full window size; converted to
        ``depth_radius = local_size // 2``
    :param alpha: alpha * local_size, same definition of alpha as in Caffe;
        divided by ``local_size`` before use
    :param beta: float
    :param bias: float [default=1.0]
    :param name: string [default='lrn']
    :return: Tensor
    """
    radius = local_size // 2
    per_element_alpha = float(alpha) / local_size
    return tf.nn.local_response_normalization(
        input_layer,
        depth_radius=radius,
        alpha=per_element_alpha,
        beta=beta,
        bias=bias,
        name=name,
    )


def max_pool(input_layer, size, stride, pad='SAME', name='pool'):
    '''
    Max pooling over square spatial windows.

    The ``size`` x ``size`` window and the ``stride`` are applied to
    dimensions 1 and 2 of the input; dimensions 0 and 3 use a window and
    stride of 1. No ``data_format`` is passed, so TensorFlow's default
    (NHWC) applies.

    :param input_layer: Tensor
    :param size: int, window height and width
    :param stride: int, spatial stride
    :param pad: string, 'SAME' or 'VALID' [default='SAME']
    :param name: string [default='pool']
    :return: Tensor
    '''
    window = [1, size, size, 1]
    step = [1, stride, stride, 1]
    return tf.nn.max_pool(input_layer, ksize=window, strides=step,
                          padding=pad, name=name)


def average_pool(input_layer, size, stride, pad='SAME', name='pool'):
    '''
    Average pooling over square spatial windows.

    The ``size`` x ``size`` window and the ``stride`` are applied to
    dimensions 1 and 2 of the input; dimensions 0 and 3 use a window and
    stride of 1. No ``data_format`` is passed, so TensorFlow's default
    (NHWC) applies.

    :param input_layer: Tensor
    :param size: int, window height and width
    :param stride: int, spatial stride
    :param pad: string, 'SAME' or 'VALID' [default='SAME']
    :param name: string [default='pool']
    :return: Tensor
    '''
    window = [1, size, size, 1]
    step = [1, stride, stride, 1]
    return tf.nn.avg_pool(input_layer, ksize=window, strides=step,
                          padding=pad, name=name)


def conv2d(input_layer, size, depth, stride, pad='SAME',
           activation_fn=tf.nn.relu,
           bias_init=tf.constant_initializer(0.1),
           l2loss=None, reuse=False,
           trainable=True, name='conv'):
    '''
    2D convolution layer with bias, optional L2 weight decay and activation.

    Creates a ``size`` x ``size`` filter mapping the input's channel count
    (dimension 3 of ``input_layer``) to ``depth`` output channels, initialized
    with ``xavier_init``, plus a ``depth``-length bias vector.

    :param input_layer: 4-D Tensor; dimension 3 is read as the channel count
    :param size: int, filter height and width
    :param depth: int, number of output channels
    :param stride: int, spatial stride (applied to dimensions 1 and 2)
    :param pad: string, 'SAME' or 'VALID' [default='SAME']
    :param activation_fn: callable applied to the biased output, or None for
        no activation [default=tf.nn.relu]
    :param bias_init: bias initializer [default=tf.constant_initializer(0.1)]
    :param l2loss: float; when not None, ``l2loss * tf.nn.l2_loss(weights)``
        is added to the REGULARIZATION_LOSSES collection [default=None]
    :param reuse: bool, passed to tf.variable_scope [default=False]
    :param trainable: bool [default=True]
    :param name: string, variable scope and variable-name prefix
        [default='conv']
    :return: Tensor
    '''
    in_channels = int(input_layer.get_shape()[3])
    full_size = [size, size, in_channels, depth]
    patch_size = size * size
    # Glorot fan-in / fan-out for a conv filter: channels x receptive field.
    weight_initializer = xavier_init(
        in_channels * patch_size, depth * patch_size)
    with tf.variable_scope(name, reuse=reuse):
        filt = tf.get_variable(
            name + '_weights', shape=full_size, initializer=weight_initializer,
            trainable=trainable)
        if l2loss is not None:
            weight_decay = tf.multiply(
                tf.nn.l2_loss(filt), l2loss, name='weight_decay')
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, weight_decay)
        conv = tf.nn.conv2d(
            input_layer, filt, [1, stride, stride, 1], padding=pad)
        conv_biases = tf.get_variable(
            name + '_bias', [depth], initializer=bias_init,
            trainable=trainable)
        bias = tf.nn.bias_add(conv, conv_biases)
        # Allow a linear layer, consistent with dilated_conv.
        if activation_fn is None:
            return bias
        return activation_fn(bias)


def dropout(input_layer, prob, name='dropout'):
    '''Dropout layer.

    ``prob`` is forwarded as the second positional argument of
    ``tf.nn.dropout``. Under the TensorFlow 1.x API this file is written
    against (``tf.variable_scope`` / ``tf.get_variable``), that argument is
    ``keep_prob``. Each unit is therefore KEPT, and scaled by ``1 / prob``,
    with probability ``prob``, and zeroed with probability ``1 - prob``.
    NOTE: TensorFlow 2's ``tf.nn.dropout`` takes the drop ``rate`` in that
    position instead, so the meaning flips if this is run under the v2 API.

    :param input_layer: Tensor
    :param prob: float or scalar float Tensor, keep probability (TF 1.x)
    :param name: string [default='dropout']
    :return: Tensor
    '''
    return tf.nn.dropout(input_layer, prob, name=name)


def fully_connected(input_layer, out_size,
                    activation_fn=tf.nn.relu,
                    bias_init=tf.constant_initializer(0.1),
                    l2loss=None, reuse=False,
                    trainable=True, name='fc'):
    '''
    Fully connected (dense) layer with bias, optional L2 weight decay and
    activation.

    :param input_layer: 2-D Tensor; dimension 1 is read as the input size
    :param out_size: int, number of output units
    :param activation_fn: callable applied to the biased output, or None for
        no activation [default=tf.nn.relu]
    :param bias_init: bias initializer [default=tf.constant_initializer(0.1)]
    :param l2loss: float; when not None, ``l2loss * tf.nn.l2_loss(W)`` is
        added to the REGULARIZATION_LOSSES collection [default=None]
    :param reuse: bool, passed to tf.variable_scope [default=False]
    :param trainable: bool [default=True]
    :param name: string, variable scope and variable-name prefix
        [default='fc']
    :return: Tensor
    '''
    in_size = int(input_layer.get_shape()[1])
    full_size = [in_size, out_size]
    weight_initializer = xavier_init(in_size, out_size)
    with tf.variable_scope(name, reuse=reuse):
        W = tf.get_variable(
            name + '_W', shape=full_size, initializer=weight_initializer,
            trainable=trainable)
        b = tf.get_variable(
            name + '_b', shape=[out_size], initializer=bias_init,
            trainable=trainable)
        if l2loss is not None:
            # tf.mul no longer exists in TF >= 1.0; tf.multiply replaces it.
            weight_decay = tf.multiply(
                tf.nn.l2_loss(W), l2loss, name='weight_decay')
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, weight_decay)
        mult = tf.matmul(input_layer, W)
        plus_b = tf.nn.bias_add(mult, b)
        # Allow a linear layer, consistent with dilated_conv.
        if activation_fn is None:
            return plus_b
        return activation_fn(plus_b)


def xavier_init(n_inputs, n_outputs, uniform=True):
    """
    Build a Glorot / Xavier weight initializer.

    Scales the initial weights by the sum of fan-in and fan-out so that the
    scale of the gradients stays roughly the same in all layers.
    Reference: Xavier Glorot and Yoshua Bengio (2010), "Understanding the
    difficulty of training deep feedforward neural networks", International
    Conference on Artificial Intelligence and Statistics.

    :param n_inputs: fan-in, the number of input nodes into each output.
    :param n_outputs: fan-out, the number of output nodes for each input.
    :param uniform: if True, sample from U(-r, r) with
        r = sqrt(6 / (n_inputs + n_outputs)); otherwise sample from a
        truncated normal with stddev = sqrt(3 / (n_inputs + n_outputs)).
    :return: a TensorFlow initializer
    """
    fan_sum = n_inputs + n_outputs
    if not uniform:
        # The truncated normal re-draws samples beyond two standard
        # deviations. The constant 3 is kept as-is; NOTE(review): it gives a
        # variance close to, but not identical to, the uniform branch's
        # 2 / fan_sum.
        return tf.truncated_normal_initializer(stddev=math.sqrt(3.0 / fan_sum))
    # 6 is the constant used in the Glorot & Bengio paper.
    limit = math.sqrt(6.0 / fan_sum)
    return tf.random_uniform_initializer(-limit, limit)

def dilated_conv_helper(input_layer, input_filter, strides, stride_pad, dilation_pad="SAME"):
    """
    Convolve a 9-band image whose bands have different native resolutions.

    The bands are split into three resolution groups. Each group is convolved
    with its own slice of ``input_filter`` using an atrous (dilated)
    convolution, so that each kernel looks at its band at that band's proper
    resolution. The three group responses are summed. For example, if each
    pixel is 15 meters and a band's resolution is 30 meters, that band's
    convolution is dilated with rate 2. Rates used: 1 for the 15 m group,
    2 for the 30 m group and 4 for the 60 m group.

    In order and by index (dimension 3 of ``input_layer``), the bands are
    assumed to be:
    0: Blue (Band 1)
    1: Green (Band 2)
    2: Red (Band 3)
    3: Near Infared (NIR) (Band 4)
    4: Short-wave Infrared 1 (SWIR1) (Band 5)
    5: Short-wave Infrared 2 (SWIR2) (Band 7)
    6: Panchromatic (Band 8)
    7: Thermal 1 (Band 6 VCID 1)
    8: Thermal 2 (Band 6 VCID 2)

    ``tf.nn.atrous_conv2d`` has no stride argument. A stride > 1 is therefore
    applied afterwards with a second convolution whose filter is an identity
    over channels at the centre tap and zero everywhere else.

    :param input_layer: Tensor with 9 bands along dimension 3
    :param input_filter: filter Tensor of shape [height, width, 9, out_channels]
    :param strides: python list of 4 ints, e.g. [1, s, s, 1]
    :param stride_pad: padding used by the striding convolution
    :param dilation_pad: padding used by the atrous convolutions
        [default="SAME"]
    :return: Tensor
    :raises ValueError: if ``strides`` is not a list, the filter does not
        have 9 input bands, a stride exceeds the filter size, or the filter
        is not square
    :raises NotImplementedError: if a stride > 1 is requested with an
        even-sized filter
    """
    # Band indices grouped by native ground resolution.
    _15_meter = [0, 1, 2, 6]
    _30_meter = [3, 4, 5]
    _60_meter = [7, 8]

    # NOTE: this is the shape of the FILTER, not the shape of the input.
    input_shape = input_filter.get_shape().as_list()
    if not isinstance(strides, list):
        raise ValueError("strides must be a regualr python list")
    if input_shape[2] != 9:
        raise ValueError("Attempting to use dilated convolution on image that does not have 9 bands. Is rgb_only True?")
    if strides[1] > input_shape[0] or strides[2] > input_shape[1]:
        raise ValueError("Convolutions do not support stride > filter size.")
    if input_shape[0] != input_shape[1]:
        raise ValueError("The dilated convolution method has not been test for convolutions with different height and width kernel dimensions.  You can probably remove this check without error, the code is written for generic input filter sizes.")

    # Split both the filter (input-channel axis) and the image (channel axis)
    # into one slice per band.
    split_weights = tf.split(axis=2, num_or_size_splits=9, value=input_filter)
    split_x = tf.split(axis=3, num_or_size_splits=9, value=input_layer)

    def do_dilated_conv(indices, rate, name):
        # Re-assemble only the selected bands and their matching filter
        # slices (indices are ascending, so band order is preserved).
        return tf.nn.atrous_conv2d(
            value=tf.concat(axis=3, values=[split_x[i] for i in indices]),
            filters=tf.concat(axis=2, values=[split_weights[i] for i in indices]),
            rate=rate,
            padding=dilation_pad,
            name=name)

    result = (do_dilated_conv(_15_meter, 1, "15_meter_dilated_conv")
              + do_dilated_conv(_30_meter, 2, "30_meter_dilated_conv")
              + do_dilated_conv(_60_meter, 4, "60_meter_dilated_conv"))

    if np.any(np.array(strides) > 1):
        if input_shape[0] % 2 == 0 or input_shape[1] % 2 == 0:
            raise NotImplementedError("Stride > 1 in dilated convolution not implemented for even filter size")
        out_channels = input_shape[3]
        conv_stride_filter = np.zeros(
            (input_shape[0], input_shape[1], out_channels, out_channels),
            dtype=np.float32)
        # For an odd kernel size k, the middle tap is at index k // 2.
        # Setting 1 only there, on the channel diagonal, makes the
        # convolution a pure strided subsample of `result`.
        channel_idx = np.arange(out_channels)
        conv_stride_filter[input_shape[0] // 2, input_shape[1] // 2,
                           channel_idx, channel_idx] = 1
        conv_stride_filter = tf.constant(conv_stride_filter)

        result = tf.nn.conv2d(result, conv_stride_filter, strides,
                              padding=stride_pad, name="dilation_stride_conv")

    return result


def dilated_conv(input_layer, size, depth, stride, stride_pad='SAME',
                 activation_fn=tf.nn.relu,
                 bias_init=tf.constant_initializer(0.1),
                 l2loss=None, reuse=False,
                 trainable=True, name='dilated_conv'):
    '''
    Resolution-aware convolution layer built on ``dilated_conv_helper``.

    Creates a ``size`` x ``size`` filter from the input's channel count
    (dimension 3 of ``input_layer``) to ``depth`` output channels, initialized
    with ``xavier_init``. The filter is applied through
    ``dilated_conv_helper``, which requires exactly 9 input bands. A bias and
    an activation are then optionally added.

    :param input_layer: 4-D Tensor; dimension 3 is read as the channel count
    :param size: int, filter height and width
    :param depth: int, number of output channels
    :param stride: int, spatial stride (applied to dimensions 1 and 2)
    :param stride_pad: string, padding for the striding step [default='SAME']
    :param activation_fn: callable, or None to skip the activation
        [default=tf.nn.relu]
    :param bias_init: bias initializer, or None to skip the bias entirely
        [default=tf.constant_initializer(0.1)]
    :param l2loss: float; when not None, ``l2loss * tf.nn.l2_loss(weights)``
        is added to the REGULARIZATION_LOSSES collection [default=None]
    :param reuse: bool, passed to tf.variable_scope [default=False]
    :param trainable: bool [default=True]
    :param name: string, variable scope and variable-name prefix
        [default='dilated_conv']
    :return: Tensor
    '''
    channels_in = int(input_layer.get_shape()[3])
    filter_shape = [size, size, channels_in, depth]
    receptive_field = size * size
    initializer = xavier_init(channels_in * receptive_field,
                              depth * receptive_field)

    with tf.variable_scope(name, reuse=reuse):
        weights = tf.get_variable(
            name + '_weights', shape=filter_shape, initializer=initializer,
            trainable=trainable)
        if l2loss is not None:
            tf.add_to_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES,
                tf.multiply(tf.nn.l2_loss(weights), l2loss, name='weight_decay'))

        output = dilated_conv_helper(
            input_layer, weights, [1, stride, stride, 1], stride_pad)
        if bias_init is not None:
            biases = tf.get_variable(
                name + '_bias', [depth], initializer=bias_init,
                trainable=trainable)
            output = tf.nn.bias_add(output, biases)
        return output if activation_fn is None else activation_fn(output)
