"""
Authors: Anthony Perez, Chris Yeh
"""

import tensorflow as tf


class BaseModel(object):
    '''The base class of models.

    Subclasses are expected to build a feature extractor and assign its
    output tensor to `self.features_layer` before `loss()` is called.
    '''

    def __init__(self, inputs, is_training, fc_reg, conv_reg):
        '''
        Args:
            inputs: tf.Tensor with shape [batch_size, img_height, img_width, img_depth], dtype tf.float32
            is_training: bool
            fc_reg: float, regularization for weights in the fully-connected layer
            conv_reg: float, regularization for weights in the conv layers
        '''
        self.inputs = inputs
        self.is_training = is_training
        self.fc_reg = fc_reg
        self.conv_reg = conv_reg
        # Subclasses must set this to the tensor that feeds the final FC layer.
        self.features_layer = None

    def init_from_numpy(self, path, sess):
        '''Initialize model weights from a saved numpy file.

        No-op in the base class; subclasses override this to load
        pretrained weights into the graph.

        Args:
            path: str, path to saved weights
            sess: tf.Session
        '''
        pass

    def loss(self, labels, add_summaries=True, reuse=None, num_classes=3):
        '''Build the classification loss on top of `self.features_layer`.

        Adds a fully-connected classification head, then computes mean
        sparse softmax cross-entropy plus any collected regularization losses.

        Args:
            labels: tf.Tensor, shape [num_examples], integer class labels
            add_summaries: bool, whether or not to create summaries for the loss variables
            reuse: bool, passed through to the FC layer's variable scope
            num_classes: int, number of output classes (default 3, the original
                hard-coded nightlight class count)

        Returns:
            loss_: scalar tf.Tensor, cross-entropy mean plus total regularization loss
            cross_entropy_mean: scalar tf.Tensor, mean cross-entropy only
            logits: tf.Tensor, shape [num_examples, num_classes]

        Raises:
            ValueError: if `self.features_layer` has not been set by the subclass.
        '''
        # Fail fast with a clear message instead of an opaque error inside
        # tf.contrib when the subclass forgot to build the feature extractor.
        if self.features_layer is None:
            raise ValueError('self.features_layer must be set before calling loss()')

        # FC layer for nightlights, with xavier init as default
        logits = tf.contrib.layers.fully_connected(self.features_layer,
                num_outputs=num_classes, activation_fn=None,
                weights_regularizer=tf.contrib.layers.l2_regularizer(scale=self.fc_reg),
                biases_initializer=tf.constant_initializer(0.1),
                reuse=reuse, trainable=True, scope='nightlight_class_FC_layer')
        print("Added fully connected reg of {0}".format(self.fc_reg))

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)

        if (self.fc_reg > 0) or (self.conv_reg > 0):
            # Regularizers registered by the layers above land in this collection.
            regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # "add_n" sums a list of values, whereas "add" sums together two values x + y
            total_regularization_loss = tf.add_n(regularization_losses)
        else:
            total_regularization_loss = tf.constant(0.0)

        loss_ = cross_entropy_mean + total_regularization_loss
        if add_summaries:
            tf.summary.scalar('loss_total', loss_)
            tf.summary.scalar('loss_regularization_only', total_regularization_loss)
            tf.summary.scalar('loss_cross_entropy_nightlights', cross_entropy_mean)

            # Training accuracy: fraction of argmax predictions that match labels.
            correct_predictions = tf.equal(tf.argmax(logits, axis=1), tf.cast(labels, tf.int64))
            tf.summary.scalar("nightlights_training_accuracy", tf.reduce_mean(tf.cast(correct_predictions, tf.float32)))

        return loss_, cross_entropy_mean, logits
