// Choleski.java - translated from ncor c++ choleski
// quickly tested, seems ok but beware
// for a small run on mac, 1.7 seconds versus 2.3 for DoubleMatrix.solve

package ZS;

import VisualNumerics.*;
import zlib.*;


/** Cholesky decomposition and linear solve for symmetric positive definite systems.
    <P>
    For a symmetric, positive definite matrix A, the Cholesky decomposition
    is a lower triangular matrix L such that A = L*L'.
    <P>
    All methods are static. {@link #decomp} overwrites the strict lower
    triangle of A with the off-diagonal part of L and stores the diagonal of
    L in a separate array; {@link #solve} factors A that way and then solves
    A*x = b. Only the diagonal and upper triangle of A are read as input, so
    symmetry is assumed rather than checked. If A is not positive definite,
    decomp and solve return false and A is left partially overwritten.
*/
public class Choleski
{

  /** Computes the Cholesky factor L of A in place.
      <P>
      Reads only the diagonal and upper triangle of A (symmetry is assumed,
      not checked). On success the strict lower triangle of A holds the
      off-diagonal entries of L and p holds the diagonal of L; the diagonal
      and upper triangle of A are left unchanged.
      @param A  n x n matrix; its strict lower triangle is overwritten.
      @param p  output array of length at least n receiving diag(L).
      @return   true on success; false if A is not positive definite
                (a pivot was non-positive or NaN), in which case A and p
                are partially overwritten.
  */
  public static boolean decomp(double[][] A, double[] p)
  {
    int n = A.length;
    for (int i = 0; i < n; i++) {
      // Only j >= i: for j < i the entries belong to columns of L that are
      // already finished, and p[i] would not have been computed yet.
      for (int j = i; j < n; j++) {
        double sum = A[i][j];
        for (int k = i - 1; k >= 0; k--) {
          sum -= A[i][k] * A[j][k];
        }
        if (i == j) {
          // !(sum > 0) also rejects NaN, which "sum <= 0" would let through.
          if (!(sum > 0.)) {
            return false;   // not positive definite
          }
          p[i] = Math.sqrt(sum);
        }
        else {
          A[j][i] = sum / p[i];
        }
      }
    }
    return true;
  } //decomp


  /** Solves L*L'*x = b using the factor produced by {@link #decomp}.
      <P>
      Forward substitution with L, then back substitution with L'.
      Only the strict lower triangle of A and the diagonal p are read.
      b and x may be the same array.
      @param A  matrix whose strict lower triangle holds L (from decomp).
      @param p  diagonal of L (from decomp).
      @param b  right-hand side, length at least A.length.
      @param x  output solution, length at least A.length.
  */
  static void solve1(double[][] A, double[] p, double[] b, double[] x)
  {
    int n = A.length;

    // Forward: L*y = b, y stored in x.
    for (int i = 0; i < n; i++) {
      double sum = b[i];
      for (int k = i - 1; k >= 0; k--) {
        sum -= A[i][k] * x[k];
      }
      x[i] = sum / p[i];
    }

    // Backward: L'*x = y.
    for (int i = n - 1; i >= 0; i--) {
      double sum = x[i];
      for (int k = i + 1; k < n; k++) {
        sum -= A[k][i] * x[k];
      }
      x[i] = sum / p[i];
    }
  } //solve1


  /** Solves A*x = b for a symmetric positive definite matrix A.
      <P>
      A is factored in place by {@link #decomp}: its strict lower triangle
      is overwritten with the Cholesky factor, so pass a copy if the
      original matrix is still needed. An empty (0 x 0) system is
      trivially solved.
      @param A  n x n matrix (diagonal and upper triangle are read);
                A[0].length must equal A.length.
      @param x  output solution, length at least n.
      @param b  right-hand side, length at least n.
      @return   true on success, false if A is not positive definite.
  */
  public static boolean solve(double[][] A, double[] x, double[] b)
  {
    int n = A.length;
    if (n == 0) {
      return true;   // nothing to solve; avoids indexing A[0]
    }
    zliberror._assert(n == A[0].length);
    double[] p = new double[n];
    if (!decomp(A, p)) {
      return false;
    }
    solve1(A, p, b, x);
    return true;
  }

/****************************************************************
  // this based on jama, did not work yet
  // NOTE(review): this is C++ kept for reference only; it is never compiled.
  // A visible reason it "did not work": solve1 below sets x[i] = 0. and never
  // reads B, so the result is always zero (x should be seeded from B).
  // Also, setup computes dpos but then takes sqrt(d) rather than sqrt(dpos).
  /** Cholesky algorithm for symmetric and positive definite matrix.
      @param  A   Square, symmetric matrix.
      @return     Structure to access L and isspd flag.
  +/
  static void setup(const cordouble2d A, cordouble2d& L)
  {
    // Initialize.
    int n = A.yres();
    bool isspd = (A.xres()==n);

    // Main loop.
    for (int j = 0; j < n; j++) {
      //double[] Lrowj = L[j];
      double d = 0.0;
      for (int k = 0; k < j; k++) {
        //double[] Lrowk = L[k];
        double s = 0.0;
        for (int i = 0; i < k; i++) {
          //s += Lrowk[i]*Lrowj[i];
          s += L(k,i)*L(j,i);
        }
        L(j,k) = s = (A(j,k) - s)/L(k,k);
        d = d + s*s;
        isspd = isspd & (A(k,j) == A(j,k)); 
      }
      d = A(j,j) - d;
      isspd = isspd & (d > 0.0);
      double dpos = (d > 0.0) ? d : 0.0;
      L(j,j) = sqrt(d);
      for (int k = j+1; k < n; k++) {
        L(j,k) = 0.0;
      }
    }
    assert(isspd);
  }

  /** Solve A*X = B
      @param  B   A Matrix with as many rows as A and any number of columns.
      @return     X so that L*L'*X = B
      @exception  IllegalArgumentException  Matrix row dimensions must agree.
      @exception  RuntimeException  Matrix is not symmetric positive definite.
  +/

  static void solve1(cordouble2d& L, double *x, const double *B, int n)
  {
    for( int i = 0; i < n; i++ ) x[i] = 0.;

    // Solve L*Y = B;
    int k;
    for (k = 0; k < n; k++) {
      for (int i = k+1; i < n; i++) {
        x[i] -= x[k]*L(i,k);
      }
      x[k] /= L(k,k);
    }

    // Solve L'*X = Y;
    for (k = n-1; k >= 0; k--) {
      x[k] /= L(k,k);
      for (int i = 0; i < k; i++) {
        x[i] -= x[k]*L(k,i);
      }
    }
  } //solve1

  static void solve(cordouble2d& M, cordouble2d& Ltmp,
                    double *x, const double *b, int len)
  {
    assert(len == M.yres());
    assert(M.xres() == M.yres());
    setup(M, Ltmp);
    solve1(Ltmp, x, b, len);
  } //solve

****************************************************************/

//----------------------------------------------------------------

  /** Benchmark driver: repeatedly solves a fixed 6x6 symmetric system to
      compare this Cholesky solve against DoubleMatrix.solve.
      Set useCholeski to false to time DoubleMatrix.solve instead.
  */
  public static void main(String[] cmdline)
  {
    System.out.println("choleski main");

    double[][] M = new double[6][6];
    double[] b = new double[6];
    double[] x = new double[6];

    for (int i = 0; i < 6; i++) {
      b[i] = 0.1 * (i + 1);
    }

    // symmetric test matrix: fill the lower triangle and mirror it
    for (int r = 0; r < 6; r++) {
      for (int c = 0; c <= r; c++) {
        M[r][c] = 0.1 * r + 0.1 * c;
        if (r == c) {
          M[r][c] += ((r + 1) + 0.1 * r);
        }
        else {
          M[c][r] = M[r][c];
        }
      }
    }
    matrix.print("M", M);
    matrix.print("b=", b);

    final boolean useCholeski = true;   // benchmark toggle

    for (int t = 0; t < 100000; t++) {

      // fresh copy each iteration: the solvers may overwrite their matrix
      // argument (this also keeps the work from being optimized away)
      double[][] Ltmp = array.clone(M);

      if (useCholeski) {
        boolean ok = Choleski.solve(Ltmp, x, b);
        zliberror._assert(ok);
      }
      else {
        try {
          x = VisualNumerics.math.DoubleMatrix.solve(Ltmp, b);
        }
        catch (Exception ex) {
          System.err.println(ex);
          //System.exit(1);
        }
      }

    }
    matrix.print("x=", x);

  } //main

} //Choleski

