import numpy as np
import matplotlib.pyplot as plt
import time
import NMF as nmf

n = 1000  # Number of data points
d = 87  # Dimension of the data; equals the length of each spectrum from read_spectral_data
r = 4  # Rank of the model (number of archetypes: the four spectra stacked into H0)
n_f = 20  # Number of data points on the faces
# Distribution of the support size of the weights for data points on the faces.
# NOTE(review): presumably entry k is the probability of a support of size k + 1;
# confirm against NMF.generate_weights.
deg_prob = np.array((0, 0.6, 0.4))

alpha = 1  # Dirichlet parameter
sigma = 0.001  # Noise magnitude (scale of the additive standard-normal noise on X)

# generate_weights is not defined in this script and was never imported by name;
# it is taken from the imported NMF module.
# NOTE(review): assumes NMF exposes generate_weights with this signature.
W0 = nmf.generate_weights(n, r, alpha, n_f, deg_prob)  # Generating weights

def _read_numeric_tokens(path):
    """Read *path* and return the numeric fields of every line as one flat array.

    The first whitespace-separated field of each line is discarded (presumably
    an axis value such as a wavenumber -- TODO confirm against the data files).
    The remaining fields are parsed as floats and concatenated in file order.
    Blank lines are skipped.
    """
    values = []
    # `with` guarantees the file is closed even if a field fails to parse.
    with open(path, "r") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                continue
            values.extend(float(tok) for tok in fields[1:])
    # Build the array once instead of calling np.append per line (quadratic).
    return np.asarray(values, dtype=float)


def _absorbance_every(samples, stride):
    """Keep the first of every *stride* samples and map x -> 2 - log10(100 * x).

    NOTE(review): 2 - log10(100 * x) looks like the transmittance-to-absorbance
    conversion A = 2 - log10(%T), with x a fractional transmittance; confirm.

    Raises:
        ValueError: if len(samples) is not a multiple of *stride*.
    """
    # reshape(-1, stride) avoids the float row count that len(x) / stride gives on Python 3.
    kept = np.reshape(samples, (-1, stride))[:, 0]
    return 2 - np.log10(100 * kept)


def _normalize(spectrum):
    """Return *spectrum* scaled so that its entries sum to 1."""
    return spectrum / np.sum(spectrum)


def read_spectral_data():
    """Read the four reference spectra used as archetypes.

    Reads CAFFEINE.txt, Sucrose.txt, Lactose.txt and Trioctanoin.txt from the
    current working directory. Each spectrum is cropped to a fixed window of
    samples and normalised so that it sums to 1.

    Returns:
        [C, S, L, T]: four 1-D float arrays of 87 samples each (caffeine,
        sucrose, lactose, trioctanoin).
    """
    # Caffeine: samples 184..270 are used as-is.
    C = _normalize(_read_numeric_tokens("CAFFEINE.txt")[184:271])

    # Sucrose: samples 783..1130, keeping every 4th value.
    S = _normalize(
        _absorbance_every(_read_numeric_tokens("Sucrose.txt")[783:1131], 4))

    # Lactose and trioctanoin are reversed before subsampling. Presumably their
    # files list the axis in the opposite direction; confirm.
    L = _normalize(
        _absorbance_every(_read_numeric_tokens("Lactose.txt")[2463:2811][::-1], 4))
    T = _normalize(
        _absorbance_every(_read_numeric_tokens("Trioctanoin.txt")[656:917][::-1], 3))

    return [C, S, L, T]


# Ground-truth archetypes: one normalised reference spectrum per row
# (caffeine, sucrose, lactose, trioctanoin).
[C, S, L, T] = read_spectral_data()
H0 = np.vstack([C, S, L, T])  # Setting archetypes

# Data points are mixtures of the archetypes weighted by W0, plus additive
# Gaussian noise of scale sigma. The noise takes X0's own shape, so it always
# matches the data even if d and the spectrum length ever disagree.
X0 = np.dot(W0, H0)  # Generating data points
X = X0 + sigma * np.random.normal(0, 1, X0.shape)

# Alternative (non-accelerated) solver, kept for reference:
# W, H, L, Err = nmf.palm_nmf(X, r=r, l=0.001, maxiter=5000, epsilon=1e-6, threshold=1e-8, c1=1, c2=1, verbose=False, proj_low_dim=False, oracle=True, H0=H0)

# The solver lives in the imported NMF module; the bare name acc_palm_nmf was
# never defined in this script.
# NOTE(review): assumes NMF exposes acc_palm_nmf with these keyword arguments.
W_acc, H_acc, L_acc, Err_acc = nmf.acc_palm_nmf(
    X, r=r, maxiter=1000, delta=1e-5, epsilon=1e-6, threshold=1e-8,
    c1=1, c2=1, verbose=False, oracle=True, H0=H0)

# Match estimated archetypes to the ground truth before plotting. The
# factorisation W H does not fix the order of the rows of H, so reorder H_acc
# so that its row k is the estimated row closest (Euclidean norm) to H0[k].
# NOTE(review): assumes H_acc has the same number of columns as H0.
# dist_H_acc[i, j] = ||H0[i, :] - H_acc[j, :]|| over the first r rows of each.
dist_H_acc = np.linalg.norm(
    H0[:r, np.newaxis, :] - H_acc[np.newaxis, :r, :], axis=2)
# Matching is nearest-neighbour per ground-truth row, so two rows of H0 can map
# to the same estimated row. The result is not guaranteed to be a permutation.
order_H_acc = np.argmin(dist_H_acc, axis=1)

H_acc = H_acc[order_H_acc, :]

# One figure per archetype: ground truth H0[k] overlaid with the matched
# estimate H_acc[k]. plt.show() is called after each figure is built.
figs_h_acc = []
for k in range(r):
    fig = plt.figure()
    plt.plot(H0[k, :], label="ground truth")
    plt.plot(H_acc[k, :], label="acc_palm_nmf estimate")
    plt.title("Archetype %d" % k)
    plt.legend()
    plt.show()
    figs_h_acc.append(fig)
