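/*****************************************************************************
* mexschur.c : C MEX file that accumulates entries of the Schur complement
*              matrix for one block of constraint matrices A1, A2, ...
*
*    [nzschur,P] = mexschur(blk,Avec,nzlistA1,nzlistA2,permA,U,V,colend,type,schur,Pin);
*
*    schur(I,J) = schur(I,J) + Trace(Ai U Aj V),
*    where I = permA[i], J = permA[j],   1<=i<=j<=colend,
*    and schur(J,I) is then set equal to schur(I,J).
*    NOTE: schur is modified in place.
* 
*   input: blk  = 1x2 cell array describing the block structure of A;
*                 blk{2} holds the list of diagonal block sizes.
*          Avec = sparse matrix; its k-th stored value is the entry of
*                 the constraint matrix described by row k of nzlistA2.
*          nzlistA1 = start indices: the entries of Ai are rows
*                 nzlistA1(i)+1 .. nzlistA1(i+1) of nzlistA2 (and of Avec).
*          nzlistA2 = two-column list of 1-based (row,col) positions of
*                 those entries, upper triangle only (row <= col).
*          permA = a 1-based permutation vector.  
*          U,V  = real symmetric matrices, either both dense or both sparse.
*          colend = number of leading columns (in permA order) to update.
*          type = 0, compute Trace(Ai*(U Aj V + V Aj U)/2) = Trace(Ai*(U Aj V))
*               = 1, compute Trace(Ai*(U Aj U)).
*          Pin  = (optional) sparse pattern; when given, only the entries
*                 stored in Pin are computed.
*
*   output: nzschur = number of nonzero entries computed (upper triangle).
*           P       = (optional) sparse pattern of those entries.
*
* SDPT3: version 3.0
* Copyright (c) 1997 by
* K.C. Toh, M.J. Todd, R.H. Tutuncu
* Last Modified: 2 Feb 01   
****************************************************************************/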

#include <mex.h>
#include <math.h>
#include <matrix.h>

/* MEX API versions older than 0x07030000 (or builds where MX_API_VER is
   not defined at all) presumably do not provide mwIndex/mwSize; fall back
   to plain int so the rest of the file compiles unchanged. */
#if !defined(MX_API_VER) || ( MX_API_VER < 0x07030000 )
typedef int mwIndex;
typedef int mwSize;
#endif

/* Larger of two values.  Macro: each argument may be evaluated twice. */
#if !defined(MAX)
#define  MAX(A, B)   ((A) > (B) ? (A) : (B))
#endif

/* sqrt(2) and 1/sqrt(2): weights applied by the schurij* kernels to
   off-diagonal (row < col) contributions. */
#if !defined(r2)
#define  r2   1.41421356237309504880      /* sqrt(2) */
#endif

#if !defined(ir2)
#define  ir2  0.70710678118654752440      /* 1/sqrt(2) */
#endif

/*********************************************************
* setvec: store alpha in x[0], x[1], ..., x[n].
*
* The upper bound is INCLUSIVE, so n+1 entries are written;
* x must have room for them.  Nothing is written when n < 0.
********************************************************/
void setvec(int n, double *x, double alpha)
{
   int pos = n;

   while (pos >= 0) {
      x[pos] = alpha;
      pos--;
   }
}

/**********************************************************
*  compute  Trace(B U*A*U)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U  is assumed to be real,dense,symmetric. 
**********************************************************/
void schurij1( int n, 
               double *Avec, int *idxstart, int *nzlistAi, int *nzlistAj,
               double *U, int col, double *schurcol)

{ int    i, ra, ca, rb, cb, rbn, cbn, l, k, kstart, kend, lstart, lend; 
  double tmp1, tmp2, tmp3, tmp4; 

  lstart = idxstart[col]; lend = idxstart[col+1]; 

for (i=0; i<=col; i++) {
   if (schurcol[i] != 0) {
     kstart = idxstart[i]; kend = idxstart[i+1]; 
     tmp1 = 0; tmp2 = 0;  
     for (l=lstart; l<lend; ++l) { 
        rb = nzlistAi[l];    
        cb = nzlistAj[l];
        if (rb > cb) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
        rbn = rb*n; cbn = cb*n;   
        tmp3 = 0; tmp4 = 0;
        for (k=kstart; k<kend; ++k) { 
           ra = nzlistAi[k];
           ca = nzlistAj[k];             
           if (ra<ca) {  
              tmp3 += Avec[k] * (U[ra+rbn]*U[ca+cbn]+U[ra+cbn]*U[ca+rbn]); }
           else { 
              tmp4 += Avec[k] * (U[ra+rbn]*U[ca+cbn]); }
	}
        if (rb<cb) { tmp1 += Avec[l]*(ir2*tmp3 + tmp4); }
        else       { tmp2 += Avec[l]*(ir2*tmp3 + tmp4); } 
     }
     schurcol[i] = r2*tmp1+tmp2; 
   }
}
return;
}
/**********************************************************
*  schurij3: dense U,V, type 0
*     --  Trace(B (U*A*V + V*A*U)/2) = Trace(B U*A*V)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U,V are assumed to be real,dense,symmetric, each stored
*  as an n x n column-major array.
*
*  B is the matrix of column `col`; A runs over the matrices
*  i = 0..col (layout of Avec/idxstart/nzlistAi/nzlistAj as in
*  schurij1).  For each i with schurcol[i] != 0, schurcol[i]
*  is overwritten with the accumulated trace; entries equal to
*  0 are skipped.  Aborts via mexErrMsgTxt if an entry of B has
*  row > col.
**********************************************************/
void schurij3(int n,  
              double *Avec, int *idxstart, int *nzlistAi, int *nzlistAj,
              double *U, double *V, int col, double *schurcol)
{
  const int jbeg = idxstart[col];
  const int jend = idxstart[col+1];
  int       i;

  for (i = 0; i <= col; i++) {
     int    p, q, ibeg, iend;
     double offsum, diagsum;

     if (schurcol[i] == 0) { continue; }   /* entry not requested */

     ibeg    = idxstart[i];
     iend    = idxstart[i+1];
     offsum  = 0;
     diagsum = 0;
     for (q = jbeg; q < jend; q++) {
        const int     brow = nzlistAi[q];
        const int     bcol = nzlistAj[q];
        const double *Ub, *Uc, *Vb, *Vc;
        double        aoff = 0, adiag = 0;

        if (brow > bcol) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
        Ub = U + brow*n;  Uc = U + bcol*n;   /* columns brow, bcol of U */
        Vb = V + brow*n;  Vc = V + bcol*n;   /* columns brow, bcol of V */
        for (p = ibeg; p < iend; p++) {
           const int arow = nzlistAi[p];
           const int acol = nzlistAj[p];
           if (arow < acol) {
              /* symmetrized product for an off-diagonal entry of A */
              aoff  += Avec[p] *(Ub[arow]*Vc[acol]+Uc[acol]*Vb[arow]
                                +Uc[arow]*Vb[acol]+Ub[acol]*Vc[arow]);
           } else {
              adiag += Avec[p] * (Ub[arow]*Vc[acol]+Uc[acol]*Vb[arow]);
           }
        }
        if (brow < bcol) { offsum  += Avec[q]*(ir2*aoff+adiag); }
        else             { diagsum += Avec[q]*(ir2*aoff+adiag); }
     }
     schurcol[i] = ir2*offsum+diagsum/2;
  }
}
/**********************************************************
* vec: stack the diagonal blocks of a sparse block-diagonal
* matrix (values A, row indices irA, column pointers jcA)
* into one long dense column vector B.
*
* Block l covers rows/columns cumblksize[l] .. cumblksize[l+1]-1
* and is written column-major starting at B[blknnz[l]].  Only
* positions holding a stored nonzero of A are written, so B must
* be zeroed by the caller (mexFunction passes an mxCalloc'ed
* buffer).  Row indices are assumed to fall inside the block of
* their column; entries outside it would land at wrong offsets.
**********************************************************/
void vec(int numblk, int *cumblksize, int *blknnz, 
         double *A, mwIndex *irA, mwIndex *jcA, double *B) 
{
   int blk;

   for (blk = 0; blk < numblk; blk++) {
      const int first = cumblksize[blk];
      const int last  = cumblksize[blk+1];
      const int width = last - first;
      int       j;

      for (j = first; j < last; j++) {
         /* offset such that B[base + r] is entry (r, j) of this block */
         const int base = blknnz[blk] + (j - first)*width - first;
         const int pend = (int)jcA[j+1];
         int       p;

         for (p = (int)jcA[j]; p < pend; p++) {
            B[base + (int)irA[p]] = A[p];
         }
      }
   }
}
/**********************************************************
*  schurij2: sparse (block-diagonal) U, type 1  --  Trace(B U*A*U)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U  is assumed to be real,sparse,symmetric, supplied as Utmp:
*  the dense block-by-block copy built by vec().
*
*  B is the matrix of column `col`; A runs over the matrices
*  i = 0..col (layout of Avec/idxstart/nzlistAi/nzlistAj as in
*  schurij1).  For entry l of B, Utmp[r + nzlistAr[l]] and
*  Utmp[r + nzlistAc[l]] are U(r,rb) and U(r,cb) for r in the
*  block of cb (assumes rb lies in that block too, i.e. the
*  matrices are block-diagonal).  blkidx[j] is the block of index j.
*
*  Only entries of A lying in the block of cb contribute.  The scan
*  of A stops at the first entry in a later block and, for the next
*  l, resumes at the first entry found in the current block.  This
*  assumes the entries of each matrix are ordered by nondecreasing
*  block of their column index -- TODO confirm against the code
*  that builds nzlistA.
*
*  For each i with schurcol[i] != 0, schurcol[i] is overwritten with
*  the accumulated trace; entries equal to 0 are skipped.  Like
*  schurij1, aborts via mexErrMsgTxt if an entry of B has row > col.
*  cumblksize is unused; it is kept so the signature is unchanged.
**********************************************************/
void schurij2( double *Avec, 
               int *idxstart, int *nzlistAi, int *nzlistAj, double *Utmp,
               int *nzlistAr, int *nzlistAc, int *cumblksize, 
               int *blkidx, int col, double *schurcol)

{ int    ra, ca, rb, cb, l, k, kstart, kend, kstartnew, lstart, lend;
  int    idxrb, idxcb, idx1, idx2, idx3, idx4;
  int    i, cblk, calk, firstime; 
  double tmp0, tmp1, tmp2, tmp3, tmp4; 

  (void)cumblksize;   /* not needed here */
  lstart = idxstart[col]; lend = idxstart[col+1]; 

for (i=0; i<=col; i++) { 
   if (schurcol[i] != 0) {   /* skip entries not requested by the caller */
      kstart = idxstart[i]; kend = idxstart[i+1]; 
      kstartnew = kstart;
      tmp1 = 0; tmp2 = 0; 
      for (l=lstart; l<lend; ++l) { 
          rb = nzlistAi[l];    
          cb = nzlistAj[l];
          /* same validation as the dense kernels: B is stored upper-triangular */
          if (rb > cb) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
          cblk = blkidx[cb];  
          idxcb = nzlistAc[l]; 
          idxrb = nzlistAr[l];
          tmp3 = 0; tmp4 = 0; firstime = 1; 
          for (k=kstart; k<kend; ++k) { 
              ca = nzlistAj[k];
              calk = blkidx[ca]; 
              if (calk==cblk) {
                 ra = nzlistAi[k];
                 idx1 = ra+idxrb; idx2 = ca+idxcb; 
                 if (ra<ca) {  
                    idx3 = ra+idxcb; idx4 = ca+idxrb; 
                    tmp3 += Avec[k] * (Utmp[idx1]*Utmp[idx2]+Utmp[idx3]*Utmp[idx4]); }
                 else {
                    tmp4 += Avec[k] * (Utmp[idx1]*Utmp[idx2]); }
                 /* remember where this block starts for the next l */
                 if (firstime) { kstartnew = k; firstime = 0; } 
              }
              else if (calk > cblk) {
                 break;   /* past the block of cb: nothing further can match */
              }
          }
          kstart = kstartnew; 
          if (rb<cb) { tmp1 += Avec[l]*(ir2*tmp3 + tmp4); }
          else       { tmp2 += Avec[l]*(ir2*tmp3 + tmp4); } 
      }
      tmp0 = r2*tmp1+tmp2; 
      schurcol[i] = tmp0; 
   }
}
return;
}
/**********************************************************
*  schurij4: sparse (block-diagonal) U,V, type 0
*     --  Trace(B (U*A*V + V*A*U)/2) = Trace(B U*A*V)
*
*  A,B are assumed to be real,sparse,symmetric.
*  U,V are assumed to be real,sparse,symmetric, supplied as Utmp and
*  Vtmp: the dense block-by-block copies built by vec().
*
*  Indexing (nzlistAr, nzlistAc, blkidx) and the block-window scan
*  of A are the same as in schurij2, with the same ordering
*  assumption on the entries of each matrix -- TODO confirm.
*
*  For each i with schurcol[i] != 0, schurcol[i] is overwritten with
*  the accumulated trace; entries equal to 0 are skipped.  Like
*  schurij3, aborts via mexErrMsgTxt if an entry of B has row > col.
*  cumblksize is unused; it is kept so the signature is unchanged.
**********************************************************/
void schurij4( double *Avec, 
               int *idxstart, int *nzlistAi, int *nzlistAj,
               double *Utmp, double *Vtmp, 
               int *nzlistAr, int *nzlistAc, int *cumblksize, 
               int *blkidx, int col, double *schurcol)

{ int    ra, ca, rb, cb, l, k, kstart, kend, kstartnew, lstart, lend;
  int    idxrb, idxcb, idx1, idx2, idx3, idx4; 
  int    i, cblk, calk, firstime;
  double tmp0, tmp1, tmp2, tmp3, tmp4;
  double hlf=0.5;  

  (void)cumblksize;   /* not needed here */
  lstart = idxstart[col]; lend = idxstart[col+1];

for (i=0; i<=col; i++) {
   if (schurcol[i] != 0) {   /* skip entries not requested by the caller */
      kstart = idxstart[i]; kend = idxstart[i+1]; 
      kstartnew = kstart;
      tmp1 = 0; tmp2 = 0;  
      for (l=lstart; l<lend; ++l) { 
          rb = nzlistAi[l];    
          cb = nzlistAj[l];
          /* same validation as the dense kernels: B is stored upper-triangular */
          if (rb > cb) { mexErrMsgTxt("mexschur: nzlistA2 is incorrect"); }
          cblk = blkidx[cb];  
          idxcb = nzlistAc[l]; 
          idxrb = nzlistAr[l];
          tmp3 = 0; tmp4 = 0; firstime = 1; 
          for (k=kstart; k<kend; ++k) { 
              ca = nzlistAj[k];
              calk = blkidx[ca]; 
              if (calk == cblk) { 
                 ra = nzlistAi[k];
                 idx1 = ra+idxrb; idx2 = ca+idxcb; 
                 if (ra<ca) {
                    idx3 = ra+idxcb; idx4 = ca+idxrb; 
                    tmp3 += Avec[k] * (Utmp[idx1]*Vtmp[idx2] +Utmp[idx2]*Vtmp[idx1] \
                                   +Utmp[idx3]*Vtmp[idx4] +Utmp[idx4]*Vtmp[idx3]); 
                 } else {
                    tmp4 += Avec[k] * (Utmp[idx1]*Vtmp[idx2] +Utmp[idx2]*Vtmp[idx1]); 
                 }
                 /* remember where this block starts for the next l */
                 if (firstime) { kstartnew = k; firstime = 0; }  
              }
              else if (calk > cblk) {
                 break;   /* past the block of cb: nothing further can match */
              }
          }
          kstart = kstartnew; 
          if (rb<cb) { tmp1 += Avec[l]*(ir2*tmp3 + tmp4); }
          else       { tmp2 += Avec[l]*(ir2*tmp3 + tmp4); } 
      }
      tmp0 = ir2*tmp1+hlf*tmp2;
      schurcol[i] = tmp0;
   }
}
return; 
}
/**********************************************************
* MEX gateway.
*
*   [nzschur,P] = mexschur(blk,Avec,nzlistA1,nzlistA2,permA,U,V, ...
*                          colend,type,schur,Pin)
*
* For each column col = 0..colend-1 and row = 0..col (0-based), the
* kernel selected by (type, sparsity of U) fills schurcol[row]; every
* nonzero value is added to schur(permA[row],permA[col]) and the sum
* is copied to schur(permA[col],permA[row]).
*
*   type 1 -> schurij1 (dense U)   / schurij2 (sparse U)
*   type 0 -> schurij3 (dense U,V) / schurij4 (sparse U,V)
*
* NOTE: schur (prhs[9]) is modified IN PLACE.  Writing into an input
* bypasses MATLAB's copy-on-write; presumably callers pass an unshared
* array -- confirm.
*
* Pin (optional 11th input): sparse matrix with at least colend
*   columns; when given, only the pairs (row,col) stored in column col
*   of Pin with row <= col are evaluated.
* Outputs: nzschur = number of evaluated pairs whose value is nonzero;
*   P (optional, only when exactly 2 outputs are requested) = m x colend
*   sparse pattern (values 1) of those pairs.  P grows as needed, so it
*   always holds every such pair.
**********************************************************/
void mexFunction(int nlhs,   mxArray  *plhs[], 
                 int nrhs,   const mxArray  *prhs[] )

{    mxArray  *blk_cell_pr;  
     double   *Avec, *idxstarttmp, *nzlistAtmp, *permAtmp, *U, *V, *schur;
     double   *blksizetmp, *schurcol, *nzschur;
     double   *Utmp = NULL, *Vtmp = NULL, *P = NULL;
     mwIndex  *irP = NULL, *jcP = NULL;          /* output pattern P      */
     mwIndex  *irPin = NULL, *jcPin = NULL;      /* optional input Pin    */
     mwIndex  *irU = NULL, *jcU = NULL, *irV = NULL, *jcV = NULL;
     int      *idxstart, *colm, *permA, *nzlistAi, *nzlistAj, *blksize;
     int      *nzlistAr = NULL, *nzlistAc = NULL, *cumblksize = NULL;
     int      *blknnz = NULL, *blkidx = NULL;

     mwIndex  subs[2];
     mwSize   nsubs=2; 
     int      index, colend, type, isspU, isspV, numblk, nzP, existP, wantP; 
     int      len, lenidx, row, col, nU, nV, n, m, m1, idx1, idx2, l, k, nsub, n1, n2, opt;
     int      kstart, kend, rb, cb, cblk, colcb, count; 
     double   maxcount, newnz; 

/* CHECK THE DIMENSIONS */

    if (nrhs < 10) {
       mexErrMsgTxt(" mexschur: must have at least 10 inputs"); }
    if (!mxIsCell(prhs[0])) {
       mexErrMsgTxt("mexschur: 1ST input must be the cell array blk"); }  
    if (mxGetM(prhs[0])>1) {
       mexErrMsgTxt("mexschur: blk can have only 1 row"); }  
    if (mxGetN(prhs[0])<2) {
       mexErrMsgTxt("mexschur: blk must have at least 2 columns"); }
    subs[0] = 0; 
    subs[1] = 1;
    index = mxCalcSingleSubscript(prhs[0],nsubs,subs); 
    blk_cell_pr = mxGetCell(prhs[0],index); 
    if (blk_cell_pr == NULL) {
       mexErrMsgTxt("mexschur: blk{2} must not be empty"); }
    numblk  = mxGetN(blk_cell_pr);
    blksizetmp = mxGetPr(blk_cell_pr); 
    blksize = mxCalloc(numblk,sizeof(int)); 
    for (k=0; k<numblk; k++) { 
        blksize[k] = (int)blksizetmp[k];
    }
/**** get pointers ****/    

    Avec = mxGetPr(prhs[1]); 
    if (!mxIsSparse(prhs[1])) { 
       mexErrMsgTxt("mexschur: Avec must be sparse"); }
    idxstarttmp = mxGetPr(prhs[2]);  
    lenidx = MAX(mxGetM(prhs[2]),mxGetN(prhs[2])); 
    idxstart = mxCalloc(lenidx,sizeof(int)); 
    for (k=0; k<lenidx; k++) { 
        idxstart[k] = (int)idxstarttmp[k]; 
    }
    nzlistAtmp = mxGetPr(prhs[3]); 
    len = mxGetM(prhs[3]);
    if (len > 0 && mxGetN(prhs[3]) < 2) {
       mexErrMsgTxt("mexschur: nzlistA2 must have 2 columns"); }
    nzlistAi = mxCalloc(len,sizeof(int)); 
    nzlistAj = mxCalloc(len,sizeof(int)); 
    for (k=0; k<len; k++) { 
        nzlistAi[k] = (int)nzlistAtmp[k] -1; /* -1 to adjust for matlab index */
        nzlistAj[k] = (int)nzlistAtmp[k+len] -1; 
    }
    permAtmp = mxGetPr(prhs[4]); 
    m1 = mxGetN(prhs[4]); 
    permA = mxCalloc(m1,sizeof(int)); 
    for (k=0; k<m1; k++) {
        permA[k] = (int)permAtmp[k]-1; /* -1 to adjust for matlab index */
    }
    U = mxGetPr(prhs[5]);  nU = mxGetM(prhs[5]); 
    isspU = mxIsSparse(prhs[5]); 
    if (isspU) { irU = mxGetIr(prhs[5]); jcU = mxGetJc(prhs[5]); }
    V = mxGetPr(prhs[6]);  nV = mxGetM(prhs[6]); 
    isspV = mxIsSparse(prhs[6]);
    if (isspV) { irV = mxGetIr(prhs[6]); jcV = mxGetJc(prhs[6]); }
    if (isspU != isspV) { 
       mexErrMsgTxt("mexschur: U,V must be both dense or both sparse"); 
    }
    colend = (int)*mxGetPr(prhs[7]); 
    type   = (int)*mxGetPr(prhs[8]); 
    /* any other value would leave no kernel selected */
    if (type != 0 && type != 1) {
       mexErrMsgTxt("mexschur: type must be 0 or 1"); }
    /* permA[0..colend-1] and idxstart[0..colend] are read below */
    if (colend < 0 || colend > m1) {
       mexErrMsgTxt("mexschur: colend must lie between 0 and length(permA)"); }
    if (colend > 0 && lenidx < colend+1) {
       mexErrMsgTxt("mexschur: nzlistA1 is too short for colend"); }
    /* dense kernels index U (and V for type 0) as n x n arrays */
    if (!isspU) {
       if ((int)mxGetN(prhs[5]) != nU) {
          mexErrMsgTxt("mexschur: U must be square"); }
       if (type == 0 && (nV != nU || (int)mxGetN(prhs[6]) != nV)) {
          mexErrMsgTxt("mexschur: V must be square and the same size as U"); }
    }

    schur = mxGetPr(prhs[9]); 
    m = mxGetM(prhs[9]);    
    if (m!= m1) {
       mexErrMsgTxt("mexschur: schur and permA are not compatible"); }
    /* keep the input pattern separate from the output pattern: both may
       be present at the same time */
    existP = (nrhs == 11);
    if (existP) {
       if (!mxIsSparse(prhs[10])) {
          mexErrMsgTxt("mexschur: P must be sparse"); }
       if ((int)mxGetN(prhs[10]) < colend) {
          mexErrMsgTxt("mexschur: P must have at least colend columns"); }
       irPin = mxGetIr(prhs[10]); jcPin = mxGetJc(prhs[10]);
    }
/************************************
* output 
************************************/

    plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL); 
    nzschur = mxGetPr(plhs[0]); 
    wantP = (nlhs == 2);
    nzP = 0;
    /* at most col+1 entries per column => colend*(colend+1)/2 in total */
    maxcount = 0.5*(double)colend*(double)(colend+1);
    if (wantP) {
       newnz = 0.2*(double)m*(double)m + 5;     /* initial size heuristic */
       if (newnz > maxcount+1) { newnz = maxcount+1; }
       nzP = (int)newnz;
       plhs[1] = mxCreateSparse(m,colend,nzP,mxREAL); 
       P=mxGetPr(plhs[1]); irP=mxGetIr(plhs[1]); jcP=mxGetJc(plhs[1]); 
       jcP[0] = 0; 
    }
/************************************
* initialization 
************************************/
    if (isspU && isspV) { 
       cumblksize = mxCalloc(numblk+1,sizeof(int)); 
       blknnz = mxCalloc(numblk+1,sizeof(int)); 
       cumblksize[0] = 0; blknnz[0] = 0; 
       n1 = 0; n2 = 0; 
       for (k=0; k<numblk; ++k) {
           nsub = blksize[k];
           n1 += nsub;  
           n2 += nsub*nsub;  
           cumblksize[k+1] = n1; 
           blknnz[k+1] = n2;  }
       if (nU != n1 || nV != n1) { 
          mexErrMsgTxt("mexschur: blk and dimension of U not compatible"); }
       Utmp = mxCalloc(n2,sizeof(double)); 
       vec(numblk,cumblksize,blknnz,U,irU,jcU,Utmp); 
       Vtmp = mxCalloc(n2,sizeof(double)); 
       vec(numblk,cumblksize,blknnz,V,irV,jcV,Vtmp); 
       blkidx = mxCalloc(nU,sizeof(int));
       for (l=0; l<numblk; l++) {  
           kstart=cumblksize[l]; kend=cumblksize[l+1];
           for (k=kstart; k<kend; k++) { blkidx[k] = l; }           
       }
       /* per-entry offsets of columns rb / cb inside the stacked Utmp/Vtmp */
       nzlistAc = mxCalloc(len,sizeof(int)); 
       nzlistAr = mxCalloc(len,sizeof(int)); 
       for (k=0; k<len; k++) {
          rb = nzlistAi[k]; 
          cb = nzlistAj[k]; 
          cblk = blkidx[cb]; colcb = cumblksize[cblk];             
          nzlistAc[k] = blknnz[cblk]+(cb-colcb)*blksize[cblk]-colcb;
          nzlistAr[k] = blknnz[cblk]+(rb-colcb)*blksize[cblk]-colcb;  
       }
    }
/************************************
* compute schur(i,j)
************************************/

    colm = mxCalloc(colend,sizeof(int));     
    for (k=0; k<colend; k++) { colm[k] = permA[k]*m; } 

    n = nU; 
    if (isspU) { opt = (type==1) ? 2 : 4; }
    else       { opt = (type==1) ? 1 : 3; }

    /*************************************/
    schurcol = mxCalloc(colend,sizeof(double)); 
    count = 0;
 
    for (col=0; col<colend; col++) { 
        /* mark which rows 0..col are to be evaluated (nonzero = yes) */
        if (existP) {
           setvec(col,schurcol,0.0); 
           for (k=(int)jcPin[col]; k<(int)jcPin[col+1]; k++) {
               row = (int)irPin[k];
               /* rows > col are never read for this column; skipping them
                  also keeps the write inside schurcol (length colend) */
               if (row <= col) { schurcol[row] = 1.0; }
           }
        } else {
           setvec(col,schurcol,1.0); 
        }
        if (opt==1) { 
           schurij1(n,Avec,idxstart,nzlistAi,nzlistAj,U,col,schurcol); 
        } else if (opt==3) { 
           schurij3(n,Avec,idxstart,nzlistAi,nzlistAj,U,V,col,schurcol);
        } else if (opt==2) {
           schurij2(Avec,idxstart,nzlistAi,nzlistAj,Utmp,
                    nzlistAr,nzlistAc,cumblksize,blkidx,col,schurcol); 
        } else {
           schurij4(Avec,idxstart,nzlistAi,nzlistAj,Utmp,Vtmp,
                    nzlistAr,nzlistAc,cumblksize,blkidx,col,schurcol);  
        }
        for (row=0; row<=col; row++) {
            if (schurcol[row] != 0) {
               if (wantP) {
                  if (count >= nzP) {
                     /* out of room: grow geometrically (capped at the
                        maximum possible count) instead of dropping entries */
                     newnz = 2.0*(double)nzP;
                     if (newnz > maxcount)         { newnz = maxcount; }
                     if (newnz < (double)count+1) { newnz = (double)count+1; }
                     nzP = (int)newnz;
                     P   = mxRealloc(P,  (size_t)nzP*sizeof(double));
                     irP = mxRealloc(irP,(size_t)nzP*sizeof(mwIndex));
                     mxSetPr(plhs[1],P); 
                     mxSetIr(plhs[1],irP); 
                     mxSetNzmax(plhs[1],nzP);
                  }
                  irP[count] = row; P[count] = 1; 
               }
               count++; 
               idx1 = permA[row]+colm[col]; 
               idx2 = permA[col]+colm[row]; 
               schur[idx1] += schurcol[row];
               schur[idx2] = schur[idx1]; 
            }
        }
        /* close the column even when it is empty, keeping jc monotone */
        if (wantP) { jcP[col+1] = count; }
    }
    
    nzschur[0] = count;

    mxFree(blksize); mxFree(nzlistAi); mxFree(nzlistAj); 
    mxFree(permA);   mxFree(idxstart); mxFree(schurcol); 
    mxFree(colm);
    if (isspU) { 
       mxFree(Utmp);     mxFree(Vtmp); 
       mxFree(nzlistAc); mxFree(nzlistAr); 
       mxFree(blknnz); mxFree(cumblksize); mxFree(blkidx); 
    } 
return;
}
/**********************************************************/



