/* 
 * Stephen Becker, 11/10/08
 * Computes A(omega), where A = U*V' is never explicitly computed 
 * srbecker@caltech.edu
 * */

#include "mex.h"

/* BLAS symbol names differ by platform: on Windows they have no trailing
 * underscore (ddot, dgemm), whereas on Linux they do (ddot_, dgemm_).
 * So, for Windows, either uncomment the #define below or pass -DWINDOWS:
 *   Windows:  mex -L(location) -lmwblas -DWINDOWS
 *   Linux:    mex -UWINDOWS
 */
/* #define WINDOWS */

/* to use "mwSize", need matrix.h, but doesn't work for R2006a */
/* #include "matrix.h" */

/* So, use the following definitions instead: */
/* Fallback definitions, each installed only when no macro of that name
 * already exists at this point. */

/* Size and index types used for matrix dimensions and loop counters. */
#if !defined(mwSize)
#  define mwSize int
#endif
#if !defined(mwIndex)
#  define mwIndex int
#endif

/* Boolean constants. */
#if !defined(true)
#  define true 1
#endif
#if !defined(false)
#  define false 0
#endif

/* Prototypes for the Fortran BLAS routines used below.  Every argument is
 * passed by reference, as Fortran requires.
 *
 * ddot  (level 1) is a FUNCTION returning the dot product of two strided
 *       vectors of length K.
 * dgemm (level 3) is a SUBROUTINE computing C = alpha*op(A)*op(B) + beta*C;
 *       it returns nothing, so it must be declared void (declaring it as
 *       returning double makes every call undefined behaviour in C).
 *       It is used when omega = 1:M*N.
 *
 * NOTE(review): these prototypes use plain int for the BLAS integers; some
 * BLAS builds use a wider integer type -- confirm against the BLAS library
 * this file is linked with. */
#ifdef WINDOWS
double ddot( int *K, double *x, int *incx, double *y, int *incy );
void dgemm( char *transA, char *transB, int *M, int *N, int *K, double *alpha, double *A, int *LDA, double *B, int *LDB, double *beta, double *C, int *LDC );
#else
double ddot_( int *K, double *x, int *incx, double *y, int *incy );
void dgemm_( char *transA, char *transB, int *M, int *N, int *K, double *alpha, double *A, int *LDA, double *B, int *LDB, double *beta, double *C, int *LDC );
#endif

/* Print the help text for the three supported calling conventions
 * (linear indices, subscript pairs, sparse pattern) via mexPrintf. */
void printUsage() {
    mexPrintf(
        "XonOmega.c: usage is\n"
        "\t b = XonOmega(U,V,omega)\n" );
    mexPrintf(
        "where A = U*V' and b = A(omega)\n" );
    mexPrintf(
        "\nAlternative usage is\n"
        "\t b = XonOmega(U,V,omegaI,omegaJ),\n"
        " where [omegaI,omegaJ] = ind2sub(...,omega)\n" );
    mexPrintf(
        "\nAnother alternative usage is\n"
        "\t b = XonOmega(U,V, OMEGA),\n"
        " where OMEGA is a sparse matrix with nonzeros on omega.\n"
        "This will agree with the other forms of the command if omega is sorted\n"
        "\n" );
}

/* Inner product of row i of U with row j of V, i.e. entry (i,j) of U*V'.
 * U is M-by-K and V is N-by-K, both stored column-major, so successive
 * elements of one row are M (resp. N) doubles apart; those are the strides
 * handed to ddot.  Arguments are taken by value so that their addresses can
 * be passed to the Fortran-style (pass-by-reference) BLAS routine. */
static double rowDot( double *U, double *Vt, int i, int j, int M, int N, int K )
{
#ifdef WINDOWS
    return ddot( &K, U+i, &M, Vt+j, &N );
#else
    return ddot_( &K, U+i, &M, Vt+j, &N );
#endif
}

/* MEX gateway:  b = XonOmega(U,V,omega)
 *               b = XonOmega(U,V,omegaI,omegaJ)
 *               b = XonOmega(U,V,OMEGA)
 *
 * Computes b = A(omega) for A = U*V' without forming A (unless omega is
 * exactly 1:M*N, in which case a single dgemm call forms all of A).
 *
 *   prhs[0]  U, full double matrix, M-by-K
 *   prhs[1]  V, full double matrix, N-by-K
 *   prhs[2]  either a double vector of 1-based linear indices into A,
 *            or a double vector of 1-based row subscripts (with prhs[3]),
 *            or an M-by-N sparse matrix whose nonzero pattern is omega
 *   prhs[3]  (optional) double vector of 1-based column subscripts
 *   plhs[0]  the requested entries; same shape as omega for the linear
 *            index form, a column vector otherwise
 *
 * Errors (via mexErrMsgTxt) on a wrong argument count or type, mismatched
 * dimensions, or any index outside the bounds of A. */
void mexFunction(
         int nlhs,       mxArray *plhs[],
         int nrhs, const mxArray *prhs[]
         )
{
    mwSize M, N, K, K2;
    mwSize nOmega1, nOmega2, nOmega;
    mwIndex i, j, k;
    double *U, *Vt, *output;
    double numelA;              /* M*N as a double: cannot overflow int */
    int SPARSE = false;

    /* Check for proper number of input and output arguments */
    if ( (nrhs < 3) || (nrhs > 4) ) {
        printUsage();
        mexErrMsgTxt("Three (or four) input argument required.");
    }
    if ( nlhs > 1 ) {
        printUsage();
        mexErrMsgTxt("Too many output arguments.");
    }

    /* Check data type of U and V */
    if ( !mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ) {
        printUsage();
        mexErrMsgTxt("Input arguments wrong data-type (must be doubles).");
    }
    /* U and V are read below as dense M*K and N*K arrays; a sparse input
     * only stores its nonzeros, so reading it that way would run off the
     * end of its buffer. */
    if ( mxIsSparse(prhs[0]) || mxIsSparse(prhs[1]) ) {
        printUsage();
        mexErrMsgTxt("U and V must be full (non-sparse) matrices.");
    }

    /* Get the sizes */
    M  = mxGetM(prhs[0]);
    K  = mxGetN(prhs[0]);
    N  = mxGetM(prhs[1]);
    K2 = mxGetN(prhs[1]);
    if ( K != K2 ) {
        printUsage();
        mexErrMsgTxt("Inner dimension of U and V' must agree.");
    }
    numelA = (double)M * (double)N;

    nOmega1 = mxGetM( prhs[2] );
    nOmega2 = mxGetN( prhs[2] );
    if ( (nOmega1 != 1) && (nOmega2 != 1) ) {
        /* Not a vector: the only other accepted form is an M-by-N sparse
         * matrix whose nonzero pattern defines omega. */
        if ( ( nOmega1 != M ) || ( nOmega2 != N ) || (!mxIsSparse( prhs[2] )) ){
            printUsage();
            mexErrMsgTxt("Omega must be a vector or a sparse matrix");
        }
        /* nzmax is the allocated capacity, an upper bound on the number of
         * nonzeros; the output is trimmed to the true count further down. */
        nOmega = mxGetNzmax( prhs[2] );
        SPARSE = true;
    } else {
        /* The indices are read through mxGetPr, so they must be doubles. */
        if ( !mxIsDouble(prhs[2]) ) {
            printUsage();
            mexErrMsgTxt("Omega must be of type double.");
        }
        /* One dimension is 1, so the product is the vector length; unlike
         * max(rows,cols) it is 0 for an empty 0x1 or 1x0 vector. */
        nOmega = nOmega1 * nOmega2;
        if ( (double)nOmega > numelA ) {
            printUsage();
            mexErrMsgTxt("Omega must have M*N or fewer entries");
        }
    }

    U  = mxGetPr(prhs[0]);
    Vt = mxGetPr(prhs[1]);
    plhs[0] = mxCreateDoubleMatrix(nOmega, 1, mxREAL);
    output = mxGetPr(plhs[0]);

    if ( SPARSE ) {
        /* Compressed-column storage: Ir holds the row index of every
         * nonzero in column-major order, and Jc (N+1 entries) delimits the
         * columns, i.e. Ir[ Jc[j] .. Jc[j+1]-1 ] are the rows of the
         * nonzeros in column j.  The result therefore comes out in
         * column-major order, which matches A(omega) only when omega is
         * sorted.
         * NOTE(review): mxGetIr/mxGetJc return mwIndex pointers, which may
         * be wider than int in large-array-dimension builds -- confirm the
         * build uses int-sized indices. */
        int *omegaI = mxGetIr( prhs[2] );
        int *omegaJ = mxGetJc( prhs[2] );

        for ( j = 0 ; j < N ; j++ ) {
            for ( k = omegaJ[j] ; k < omegaJ[j+1] ; k++ ) {
                i = (int) omegaI[k];
                output[k] = rowDot( U, Vt, (int)i, (int)j, (int)M, (int)N, (int)K );
            }
        }
        /* Jc[N] is the actual number of nonzeros; drop the unused tail of
         * the output when the matrix was allocated with spare capacity. */
        if ( (omegaJ[N] >= 0) && (omegaJ[N] < nOmega) ) {
            mxSetM( plhs[0], omegaJ[N] );
        }

    } else if ( nrhs < 4 ) {
        /* omega is a vector of linear indices */
        double *omega = mxGetPr(prhs[2]);
        /* Level-3 BLAS computes all of A in column-major order, which is
         * A(omega) only if omega is exactly 1,2,...,M*N.  Merely being
         * sorted is not enough: a full-length omega with repeated entries
         * is sorted too. */
        int USE_BLAS = ( nOmega > 0 ) && ( (double)nOmega == numelA );

        for ( k = 0 ; USE_BLAS && (k < nOmega) ; k++ ) {
            if ( omega[k] != (double)k + 1.0 ) {
                USE_BLAS = false;
            }
        }

        /* make the output the same shape (i.e. row- or column-vector) as
         * the input "omega", whichever path computes it */
        mxSetM( plhs[0], nOmega1 );
        mxSetN( plhs[0], nOmega2 );

        if ( USE_BLAS ) {
            /* we need A itself: output = 1.0 * U * V' + 0.0 * output */
            char transA = 'N', transB = 'T';
            double alpha = 1.0, beta = 0.0;
            int m = (int)M, n = (int)N, kk = (int)K;
            int LDA = m, LDB = n, LDC = m;
#ifdef WINDOWS
            dgemm( &transA, &transB, &m, &n, &kk, &alpha, U, &LDA, Vt, &LDB, &beta, output, &LDC );
#else
            dgemm_( &transA, &transB, &m, &n, &kk, &alpha, U, &LDA, Vt, &LDB, &beta, output, &LDC );
#endif
        } else {
            for ( k = 0 ; k < nOmega ; k++ ) {
                double lin = omega[k];
                double rem;

                /* written so that NaN is rejected as well */
                if ( !( (lin >= 1.0) && (lin <= numelA) ) ) {
                    mexErrMsgTxt("Omega contains an index outside 1..M*N.");
                }
                /* Matlab is 1-indexed, C is 0-indexed.  Split the linear
                 * index in double arithmetic: it may exceed INT_MAX even
                 * though the row and column do not. */
                lin -= 1.0;
                j = (mwIndex)( lin / (double)M );
                rem = lin - (double)j * (double)M;
                /* guard against the quotient having been rounded across an
                 * integer boundary */
                if ( rem < 0.0 ) {
                    j--;
                    rem += (double)M;
                } else if ( rem >= (double)M ) {
                    j++;
                    rem -= (double)M;
                }
                i = (mwIndex) rem;
                output[k] = rowDot( U, Vt, (int)i, (int)j, (int)M, (int)N, (int)K );
            }
        }

    } else {
        /* we have omegaX and omegaY, the row and column subscripts */
        double *omegaX, *omegaY;

        nOmega1 = mxGetM( prhs[3] );
        nOmega2 = mxGetN( prhs[3] );
        if ( (nOmega1 != 1) && (nOmega2 != 1) ) {
            printUsage();
            mexErrMsgTxt("OmegaY must be a vector");
        }
        if ( !mxIsDouble(prhs[3]) ) {
            printUsage();
            mexErrMsgTxt("OmegaY must be of type double.");
        }
        if ( nOmega1 * nOmega2 != nOmega ) {
            printUsage();
            mexErrMsgTxt("In subscript index format, subscript vectors must be same length.");
        }
        omegaX = mxGetPr(prhs[2]);
        omegaY = mxGetPr(prhs[3]);

        for ( k = 0 ; k < nOmega ; k++ ) {
            /* written so that NaN is rejected as well */
            if ( !( (omegaX[k] >= 1.0) && (omegaX[k] <= (double)M) ) ||
                 !( (omegaY[k] >= 1.0) && (omegaY[k] <= (double)N) ) ) {
                mexErrMsgTxt("Subscript out of range: need 1 <= omegaI <= M and 1 <= omegaJ <= N.");
            }
            i = (int) omegaX[k] - 1;
            j = (int) omegaY[k] - 1;
            output[k] = rowDot( U, Vt, (int)i, (int)j, (int)M, (int)N, (int)K );
        }
    }

}
