// RbfThinPlate.java jplewis

// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Library General Public
// License as published by the Free Software Foundation; either
// version 2 of the License, or (at your option) any later version.
// 
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Library General Public License for more details.
// 
// You should have received a copy of the GNU Library General Public
// License along with this library; if not, write to the
// Free Software Foundation, Inc., 59 Temple Place - Suite 330,
// Boston, MA  02111-1307, USA.
//
// contact info:  zilla@computer.org

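// Implemented following F. L. Bookstein, "Principal Warps: Thin-Plate Splines
// and the Decomposition of Deformations", IEEE Trans. PAMI 11(6), 1989.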

package ZS.Grid;

import java.io.*;

import VisualNumerics.math.DoubleMatrix;

import ZS.matrix;
import zlib.*;

/**
 * <pre>
 * minimize int[ (d^2f/dx^2)^2 + 2(d^2f/dxdy)^2 + (d^2f/dy^2)^2 ] dx dy
 * v=(x,y) warped to v'=(x',y')
 *
 * f(v) = a1 + ax*x + ay*y + sum w_j U(|v-v_j|)
 *
 * U(|v-v_j|) = |v-v_j|^2 log(|v-v_j|^2)
 * 	"solution to biharmonic eq  \Delta^2 u = 0"
 * 
 * ( K  P ) (w) = (v)
 * ( P' 0 ) (a)   (0)
 *
 * K=U(|v_r-v_c|), nxn; its diagonal is zero (U(0)=0), or lambda when regularized
 * P is nx3, contains locations 1, x_k, y_k
 * 0 is 3x3
 * w|a has weights followed by a1, ax, ay
 * v are values to interpolate.
 *
 * (Bookstein's notation) the kth row of D is (1,x_k,y_k), i.e. D = P
 * a=affine part of warp, b = nonaffine,
 *  a=[a1,a2,a3]
 *  b=[b1..bn]
 *
 * Understanding lambda:
 * form a different system where the affine fit has been
 * subtracted out. Now have Kw = d,  d are the deviations
 * from affine, K is the upper nxn part of the matrix.
 * Suppose lambda is so large that K is dominated by its diagonal, e.g. K ~ 100*I.
 *   Then 100I w = d, so w = (1/100) d,
 * i.e. the non-affine weights w become very small and the warp approaches affine.
 */
final public class RbfThinPlate implements ScatterInterpSparse2
{
  int		_verbose = 1;
  PrintWriter	_stdout = new PrintWriter(System.out);

  int		_npts;
  double[][]	_pts;
  double[]	_w;
  double	_a1;
  double	_ax;
  double	_ay;

  /**
   */
  public RbfThinPlate(double[][] pts, double[] values)
  {
    setup(pts, values, 0.);
  } //constructor


  /**
   * @param lambda  regularization, high values cause
   * the warp to approach affine.
   */
  public RbfThinPlate(double[][] pts, double[] values, double lambda)
  {
    setup(pts, values, lambda);
  } //constructor


  /**
   */
  public RbfThinPlate(float[][] fpts, float[] fvalues)
  {
    double[][] pts = zlib.toDouble(fpts);
    double[] values = zlib.toDouble(fvalues);
    setup(pts, values, 0.);
  } //constructor

  /**
   * TODO believe it is better if x,y locations are first mapped
   * to near unity before solving this.
   */
  private void setup(double[][] pts, double[] values, double lambda)
  {
    _npts = pts.length;
    zliberror.assert(_npts == values.length);
    _pts = zlib.cloneArray(pts);
    //_values = (double[])values.clone();

    double[][] L = new double[_npts+3][_npts+3];
    // assume that it is set to zero
    
    for( int r=0; r < _npts; r++ ) {
      for( int c=0; c < _npts; c++ ) {
	if (r != c) {
	  double r2 = dist2(r, c, pts);
	  L[r][c] = r2 * Math.log(r2);
	}
	else
	  L[r][c] = lambda;
      }
    }

    // P
    for( int r=0; r < _npts; r++ ) {
      L[r][_npts] = 1.;
      L[r][_npts+1] = pts[r][0];
      L[r][_npts+2] = pts[r][1];
    }

    // P^T
    for( int c=0; c < _npts; c++ ) {
      L[_npts][c] = 1.;
      L[_npts+1][c] = pts[c][0];
      L[_npts+2][c] = pts[c][1];
    }

    if (_verbose > 1) matrix.print(_stdout, "L=", L);

    double[] rhs = new double[_npts+3];	// assuming that it zeros
    for( int i=0; i < _npts; i++ )  rhs[i] = values[i];

    try {
      // TODO: warn about determinant
      _w = DoubleMatrix.solve(L, rhs);
    }
    catch(Exception x) { zliberror.die(x); }

    _a1 = _w[_npts];
    _ax = _w[_npts+1];
    _ay = _w[_npts+2];

    if (_verbose > 0) {
      for( int i=0; i < _npts; i++ ) {
	double v = interp((float)pts[i][0], (float)pts[i][1]);
	System.out.println("wanted " + values[i] +", got " + v);
      }
    } 
  } //setup

  //----------------------------------------------------------------

  /**
   * r = | pt - pt[k] |
   * rbf = a1 + a2 x + a_3 y + sum w_k*r^2*log(r^2)
   */
  public final float interp(float xf, float yf)
  {
    double sum = 0.f;
    double x = xf;
    double y = yf;
    int npts = _npts;

    for( int i=0; i < npts; i++ ) {
      double dx = x - _pts[i][0];
      double dy = y - _pts[i][1];
      double r2 = dx*dx + dy*dy;
      if (r2 != 0.0) {
	sum += (_w[i] * (r2 * Math.log(r2)));
	//System.out.println(" + w="+_w[i] +" * "+(r2*Math.log(r2)));
      }
    }

    sum += (_a1 + _ax*x + _ay*y);
    //System.out.println(" + "+_a1 +" "+_ax*x+" "+_ay*y);

    return (float)sum;
  } //interp


  //----------------------------------------------------------------


  /** return distance between points a,b
   */
  private static final double dist2(int a, int b, double[][] pts)
  {
    double dx = pts[a][0] - pts[b][0];
    double dy = pts[a][1] - pts[b][1];
    return dx*dx + dy*dy;
  }


  //----------------------------------------------------------------

  // test
  public static void main(String[] cmdline)
  {
    matrix._nf.setMinimumFractionDigits(5);
    matrix._nf.setMaximumFractionDigits(5);

    // 0          1
    //      .5
    // 1          ?

    float[][] pts = new float[][]
    {{0.f,0.f}, {1.f,0.f}, {0.f,1.f}, {0.5f,0.5f}};
    float[] vals = new float[]
    {  0.f,     1.f,      1.f,    0.5f };

    /****************
    double[][] pts = new double[][]
    {{0.,0.}, {1.,0}, {0.,1.}, {0.5,0.5}};
    double[] vals = new double[]
    {  0.,     1.,      1.,    0.5 };
    ****************/

    ScatterInterpSparse2 interp = new RbfThinPlate(pts, vals);

    // walk along the diagonal.
    for( int i=0; i < 21; i++ ) {
      float frac = (float)i / 20.f;
      float x = frac;
      float y = frac;
      float v = interp.interp(x,y);
      System.out.println(v);
    }
    
  } //main
  
} //RbfThinPlate
