"""
Author: Anthony Perez


Note: This file is a collection of code snippets that might be educational, but not necessarily directly useful.
"""

def create_optimizer(current_epoch_var, learning_rate, beta2=0.999):
    '''Build an Adam optimizer whose learning rate decays with the epoch.

    Args:
        current_epoch_var: tf.Variable, the current epoch (integer).
        learning_rate: float, the base learning rate.
        beta2: float, Adam's second-moment decay coefficient.

    Returns: tuple (optimizer, l_r)
        optimizer: tf.train.AdamOptimizer
        l_r: tf.Tensor, dtype=tf.float32, the current learning rate
    '''
    with tf.variable_scope("optimizer"):
        # Demonstrates manipulating the learning rate as a function of the epoch:
        # past epoch 30 the base rate is divided by 2.5, and on top of that an
        # exponential decay of 0.98 per epoch is always applied.
        drop_epoch = tf.constant(30)
        base_rate = tf.constant(learning_rate, dtype=tf.float32)
        rate = tf.cond(tf.less(drop_epoch, current_epoch_var),
                       lambda: base_rate / 2.5,
                       lambda: base_rate)
        per_epoch_decay = 0.98
        decay_multiplier = tf.pow(per_epoch_decay, tf.cast(current_epoch_var, tf.float32))
        rate = rate * decay_multiplier
        # Large epsilon relative to Adam's 1e-08 default; kept deliberately.
        adam_epsilon = 1.0

        # These 2 summaries help you visualize the learning-rate schedule.
        tf.summary.scalar("Learning Rate After Decay", rate)
        tf.summary.scalar("Decay At Epoch (lr multiplier)", decay_multiplier)

        optimizer = tf.train.AdamOptimizer(learning_rate=rate, beta1=0.9,
                                           beta2=beta2, epsilon=adam_epsilon)
        return optimizer, rate


def get_train_op(optimizer, all_losses, global_step):
    '''Wrap optimizer.minimize into a single training op.

    Args:
        optimizer:  An optimizer instance, e.g. the result of the function above.
        all_losses:  The loss value you wish to minimize -- see base_model.py
        global_step: a tf.Variable that will be incremented every time the
            train_op is called.

    Returns:
        The op produced by optimizer.minimize.
    '''
    return optimizer.minimize(all_losses, global_step=global_step)

######################
####### SAVING #######
######################

def save(sess, saver, checkpoint_dir, dataset_name, batch_size, step, tag=None):
    '''Save the session's variables to a per-model checkpoint directory.

    The checkpoint is written under
    <checkpoint_dir>/<dataset_name>_<batch_size>[_<tag>]/your_model_name.model,
    creating the directory if necessary.
    '''
    model_name = "your_model_name.model"
    subdir_parts = [str(dataset_name), str(batch_size)]
    if tag is not None:
        subdir_parts.append(str(tag))
    target_dir = os.path.join(checkpoint_dir, "_".join(subdir_parts))

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    saver.save(sess, os.path.join(target_dir, model_name), global_step=step)
                    
class LoadNoFileError(Exception):
    """Raised when the checkpoint metadata names a checkpoint file that is missing on disk."""
                    
def load(sess, saver, checkpoint_dir, dataset_name, batch_size, tag=None):
    '''Restore the latest checkpoint for this dataset/batch-size (and optional tag).

    Args:
        sess: tf.Session to restore into.
        saver: tf.train.Saver built over the variables to restore.
        checkpoint_dir: root directory that save() wrote under; must not be None.
        dataset_name, batch_size, tag: identify the model sub-directory; must
            match the arguments originally passed to save().

    Returns:
        True if a checkpoint was found and restored, False if none exists.

    Raises:
        ValueError: if checkpoint_dir is None.
        LoadNoFileError: if the checkpoint metadata references a file that is
            missing on disk.
    '''
    print(" [*] Reading checkpoints...")

    if checkpoint_dir is None:
        # Fixed error-message typo (stray comma) and raise the idiomatic
        # ValueError (still caught by callers handling Exception).
        raise ValueError("No checkpoint path given, cannot load checkpoint")

    # Mirror the directory layout produced by save().
    model_dir = "%s_%s" % (dataset_name, batch_size)
    if tag is not None:
        model_dir = "%s_%s" % (model_dir, tag)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        ckpt_full_path = os.path.join(checkpoint_dir, ckpt_name)
        if not checkpoint_path_exists(ckpt_full_path):
            raise LoadNoFileError("Checkpoint could not be loaded because it does not exist," + \
                                  " but its information is in the checkpoint meta-data file.")
        saver.restore(sess, ckpt_full_path)
        return True
    else:
        return False

# Keep at most this many recent checkpoints per Saver.
MAX_MODELS_TO_KEEP = 10
# var_list=None defaults to saving/restoring every variable in the graph.
saver = tf.train.Saver(var_list=None, max_to_keep=MAX_MODELS_TO_KEEP)
validation_saver = tf.train.Saver(var_list=None, max_to_keep=MAX_MODELS_TO_KEEP)
# NOTE(review): sess, checkpoint_dir, dataset_name, batch_size and tag are
# assumed to be defined elsewhere in the surrounding code -- this file is a
# snippet collection.
if not load(sess, saver, checkpoint_dir, dataset_name, batch_size, tag=tag):
    # If no checkpoint, initialize from pre-trained weights in numpy file.
    print('No checkpoint file found.')
    
####################################
#             TRAINING             #
####################################

# Create a variable to keep track of the current step; it is incremented by the
# train op (see get_train_op) and is not itself trainable.
global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
# NOTE(review): NUM_DATA_EXAMPLES and batch_size are assumed defined elsewhere.
steps_per_epoch = float(NUM_DATA_EXAMPLES) / batch_size
# Derive the integer epoch from the step count: the float division is truncated
# back to int32, so the epoch advances once every steps_per_epoch steps.
current_epoch_var = tf.cast(tf.cast(global_step, tf.float32) / steps_per_epoch, tf.int32)
tf.summary.scalar("Epoch", current_epoch_var)

# Create a tensorflow node that will train the model for one step when run
all_losses = None # Assume all_losses exists from some previous computation you've done
optimizer, _ = create_optimizer(current_epoch_var, 0.01)
train_op = get_train_op(optimizer, all_losses, global_step)

# Batchnorm's moving mean and variance don't update automatically, so we collect
# the registered UPDATE_OPS and group them into one op to run alongside train_op.
batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
batchnorm_updates_op = tf.group(*batchnorm_updates)

# Function that will run the optimization node.
# In our code, this function was defined inside another
# function that created the train_op and batchnorm_updates_op.
# Additionally, the code we use sets up a node that loads the next data example from disk.
# This means that we don't use any placeholders and thus don't need to supply a feed_dict.
# You may need to modify this code to supply a feed_dict.
def run_train_op(other_ops=None, runtime_trace=False):
    '''Runs train_op, batchnorm_updates_op, and any ops in other_ops. Increments global step.

    Args:
        other_ops: list of ops that will be run at the same time the optimizer
            op is run. Defaults to no extra ops. (A None default replaces the
            original mutable default argument `[]`, which is shared across
            calls in Python.)
        runtime_trace: bool, if true then return a second output that contains
            the runtime metadata

    Returns: tuple of (outputs, run_metadata)
        outputs: result of sess.run(other_ops) -- the results of the trailing
            batchnorm and train ops are stripped off.
        run_metadata: if runtime_trace=True, contains the metadata. Otherwise is None.
    '''
    if other_ops is None:
        other_ops = []
    # The batchnorm/train ops are appended last so their results can be dropped
    # with out[:-2] below.
    fetches = list(other_ops) + [batchnorm_updates_op, train_op]
    if runtime_trace:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.SOFTWARE_TRACE)
        run_metadata = tf.RunMetadata()
        out = sess.run(fetches, options=run_options, run_metadata=run_metadata)
        return out[:-2], run_metadata
    else:
        out = sess.run(fetches)
        return out[:-2], None
        
# The actual training loop would look something like the following:

try: 
    # ('''not done training''') is a non-empty string literal and therefore
    # always truthy: this is an intentionally infinite placeholder loop.
    # Replace the condition with a real stopping criterion.
    while ('''not done training'''):
        # Get the current step
        start_time = time.time()
        step = sess.run(global_step)
        print("Step: %d" % step)

        # run train op and calculate loss
        # Here we get back the loss, learning rate, and current epoch so that we 
        # can print out the current progress.  summary_op deals with tensorboard summaries.
        # NOTE(review): l_r, summary_op, collect_runtime_data, summary_writer,
        # PRINT_EVERY and SAVE_EVERY are assumed to be defined elsewhere.
        output_ops = [all_losses, l_r, current_epoch_var, summary_op]
        outputs, runtime_metadata = run_train_op(output_ops, runtime_trace=collect_runtime_data)
        loss_value, step_learning_rate, current_epoch, summary_str = outputs
        if collect_runtime_data:
            # Attach the per-step runtime trace so TensorBoard can display it.
            summary_writer.add_run_metadata(runtime_metadata, 'step{}'.format(step))

        # It may be useful to know how long running one step takes
        duration = time.time() - start_time

        # Print an overview fairly often.
        if step % PRINT_EVERY == 0:
            print('%s: Step %d: lr %.8f, loss = %.3f, epoch: %d (%.3f sec)' %
                (dataset_name, step, step_learning_rate, loss_value, current_epoch, duration))
        if step % SAVE_EVERY == 0:
            pass # You will likely want to save every so often.
finally:
    pass
    # You might want to save here so that you keep your weights if an error occurs.