import numpy
import mmh3
from sortedcontainers import SortedList

# Memory counters, measured in tracked elements rather than bytes:
#   mem_curr -- elements currently held by tracked values/containers
#   mem_used -- highest value mem_curr has reached (zeroed by reset_mem)
mem_curr, mem_used = 0, 0

def reset_mem():
    """Zero both the current and the peak memory counters."""
    global mem_curr
    global mem_used
    mem_used = 0
    mem_curr = 0
    
def report_mem():
    """Return the peak counter ``mem_used`` (not the current ``mem_curr``)."""
    # Reading a module global needs no ``global`` declaration.
    peak = mem_used
    return peak
    
# Helpers that "allocate" values and containers while counting how many
# tracked elements are live (in elements, not bytes), so the peak can be
# reported by report_mem().

# A sorted list container that supports log time insertion and deletion
class tracked_sortedlist(SortedList):
    """A SortedList that charges one memory unit per stored element.

    Elements inserted through the constructor or ``add`` increase the
    module-level ``mem_curr`` counter and, if it reaches a new maximum,
    ``mem_used``. ``clear`` and ``discard`` give units back.

    Only ``__init__``, ``clear``, ``discard`` and ``add`` are overridden here.
    NOTE(review): inherited mutators such as ``remove``, ``pop`` or
    ``update`` presumably bypass these overrides and so do not touch the
    counters -- verify against the installed sortedcontainers before relying
    on them.
    """

    def __init__(self, init=()):
        global mem_curr, mem_used
        # Materialize once so a generator can be both inserted and counted.
        values = list(init)
        # Insert first: if the values are not mutually comparable, SortedList
        # raises and the counters are left untouched.
        SortedList.__init__(self, values)
        mem_curr += len(values)
        mem_used = max(mem_used, mem_curr)

    def clear(self):
        """Remove every element and release all of their memory units."""
        global mem_curr
        mem_curr -= len(self)
        SortedList.clear(self)

    def discard(self, value):
        """Remove ``value`` if present; an absent value is a silent no-op.

        The counter is reduced by the number of elements actually removed,
        so discarding a missing value no longer under-counts memory.
        """
        global mem_curr
        size_before = len(self)
        SortedList.discard(self, value)
        mem_curr -= size_before - len(self)

    def add(self, item):
        """Insert ``item`` in sorted position and charge one memory unit."""
        global mem_curr, mem_used
        # Insert first so a failed comparison does not leave a phantom unit.
        SortedList.add(self, item)
        mem_curr += 1
        mem_used = max(mem_used, mem_curr)
        
# A basic list that tracks the amount of memory it uses as we add and remove 
# elements
class tracked_list(list):
    """A list that charges one memory unit per stored element.

    The constructor and ``append`` increase the module-level ``mem_curr``
    counter (and ``mem_used`` when a new peak is reached). ``pop`` and
    ``clear`` give units back. Other list mutators (``extend``, ``insert``,
    ``remove``, ``+=``, slice assignment) are inherited unchanged and do not
    update the counters.
    """

    def __init__(self, init=()):
        global mem_curr, mem_used
        # Materialize once so a generator can be both stored and counted.
        values = list(init)
        list.__init__(self, values)
        mem_curr += len(values)
        mem_used = max(mem_used, mem_curr)

    def clear(self):
        """Remove every element and release all of their memory units."""
        global mem_curr
        mem_curr -= len(self)
        del self[:]

    def pop(self, index=-1):
        """Remove and return the item at ``index`` (default: the last one).

        Raises:
            IndexError: if the list is empty or ``index`` is out of range;
                the counter is left unchanged in that case.
        """
        global mem_curr
        # Pop first so a failing pop does not decrement the counter.
        item = list.pop(self, index)
        mem_curr -= 1
        return item

    def append(self, item):
        """Append ``item`` and charge one memory unit."""
        global mem_curr, mem_used
        list.append(self, item)
        mem_curr += 1
        mem_used = max(mem_used, mem_curr)

# To "allocate" an int, use this function
def tracked_int(x=0):
    """Charge one memory unit, then return ``int(x)``.

    The unit is charged before the conversion, so a failing ``int(x)`` still
    leaves the counter incremented.
    """
    global mem_curr, mem_used
    mem_curr = mem_curr + 1
    if mem_curr > mem_used:
        mem_used = mem_curr
    return int(x)
    
# To "allocate" a double, use this function
def tracked_double(x=0):
    """Charge one memory unit, then return ``x`` as a ``numpy.double``.

    The unit is charged before the conversion, so a failing conversion still
    leaves the counter incremented.
    """
    global mem_curr, mem_used
    mem_curr = mem_curr + 1
    if mem_curr > mem_used:
        mem_used = mem_curr
    return numpy.double(x)
 
# Hash helpers: each seed value selects a different hash function.
def hash(elem, seed=42):
    """Return the 32-bit MurmurHash3 (``mmh3.hash``) of ``str(elem)``.

    Different ``seed`` values select different hash functions. Note that this
    definition shadows the builtin ``hash`` within this module.
    """
    key = str(elem)
    return mmh3.hash(key, seed)

# Hashes elem with 128-bit MurmurHash3 (mmh3.hash64, not hash() above) and
# scales the result to a pseudo-random float between 0 and 1.
def hash_double(elem, seed=42):
    """Map ``elem`` to a pseudo-random float in the half-open range [0, 1).

    ``str(elem)`` is hashed with ``mmh3.hash64`` (two unsigned 64-bit
    halves), the halves are combined into one 128-bit integer, and that
    integer is scaled by ``2**-128``. Different ``seed`` values select
    different hash functions.
    """
    mod = 2**128
    low, high = mmh3.hash64(str(elem), seed=seed, signed=False)
    # Both halves are unsigned 64-bit, so this is already in [0, 2**128);
    # the original "% mod" was a no-op and has been dropped.
    value = low + (high << 64)
    result = numpy.double(value) / numpy.double(mod)
    # int -> double conversion rounds to nearest, so 128-bit values very close
    # to 2**128 become exactly 2**128 and the ratio would be 1.0. Clamp to the
    # largest double below 1 so the result stays in [0, 1).
    upper = numpy.nextafter(numpy.double(1.0), numpy.double(0.0))
    return min(result, upper)
    