/******************************************************************************
* WAVELET.C                                                                   *
*                                                                             *
* Necessary functions for wavelet analysis.                                   *
*                                                                             *
* James Holliday                                                              *
* University of California - Davis                                            *
*                                                                             *
******************************************************************************/


/* Standard include files */
#include <math.h>
#include <complex.h>

/* Custom include files */
#include "fftw3.h"
#include "wavelet.h"

/* Custom functions and macros */
#define cabs2(Z) (cabs(Z)*cabs(Z))

/* Factorial function (for Paul wavelet).  Any N <= 1 yields 1. */
double factoral(int N)
{
  double product = 1.0;  /* running product 1*2*...*N */
  int    term    = 0;    /* current multiplier        */

  /* Multiply upwards from 2; the loop body never runs when N <= 1 */
  for (term=2; term<=N; term++)  product = term * product;

  return product;
}

/* Gamma function (for DOG wavelet).  Returns a log_e value. */
double gammaln(double xx)
{
  double x,y,tmp,ser;
  double cof[6]={76.18009172947146,-86.50532032941677,24.01409824083091,
		 -1.231739572450155,0.1208650973866179e-2,-0.5395239384953e-5};
  int j;
  y=x=xx;
  tmp=x+5.5;
  tmp -= (x+0.5)*log(tmp);
  ser=1.000000000190015;
  for (j=0;j<=5;j++) ser += cof[j]/++y;
  return -tmp+log(2.5066282746310005*ser/x);
}


/*****************************************************************************\
*                                                                             *
* CREATE_WAVELET                                                              *
*                                                                             *
*  This function creates a daughter wavelet (in frequency space) for a        *
*  specified wavelet function.  On completion, the wavelet vectors are        *
*  filled and the "Fourier Period" and "e-Folding Factor" are calculated.     *
*                                                                             *
*  Parameters:                                                                *
*    cdouble  wavelet[N] - Vector space for daughter wavelet.                 *
*    int      N          - Length of vector.                                  *
*    int      mother     - Type of mother wavelet to use:                     *
*                             0 = Morlet                                      *
*                             1 = Paul                                        *
*                             2 = DOG (derivative of Gaussian)                *
*    double   param      - Mother wavelet parameter.  If <0 default is used:  *
*                             Morlet = w0 (wave number)     [6]               *
*                             Paul   = m  (order)           [4]               *
*                             DOG    = m  (mth derivative)  [2]               *
*    double   scale      - Wavelet scale for constructing daughter wavelet.   *
*    double   dt         - Sampling interval (time between points).           *
*    double*  period     - Fourier period (calculated in units of dt).        *
*    double*  folding    - e-folding factor.                                  *
*                                                                             *
*  James Holliday                                                             *
*  University of California - Davis                                           *
*                                                                             *
\*****************************************************************************/

void create_wavelet(cdouble  wavelet[],
		    int      N,
		    int      mother,
		    double   param,
		    double   scale,
		    double   dt,
		    double*  period,
		    double*  folding)
{
  double  w    = 0.0;  /* angular frequency of bin k            */
  cdouble norm = 0.0;  /* wavelet normalization                 */
  int     npos = 0;    /* number of non-negative frequency bins */
  int     k    = 0;    /* iterating variable                    */

  /* Fills wavelet[0..N-1] with the frequency-space daughter wavelet and
     writes the Fourier period and e-folding factor through the (non-NULL)
     pointers "period" and "folding".  If "mother" matches none of MORLET,
     PAUL or DOG, nothing is written at all. */

  /* Bins 0..N/2 are treated as the non-negative frequencies.  The Nyquist
     bin N/2 is included so that every element of wavelet[] gets a defined
     value; nothing is written when N <= 0. */
  npos = (N > 0) ? (N/2 + 1) : 0;

  /* Morlet Wavelet -------------------------------------------------------- */
  if ( mother == MORLET )
  {
    /* Check the input parameter */
    if ( param < 0.0 )  param = 6.0;

    /* Calculate normalization factor */
    norm = sqrt(2.0 * M_PI * scale / dt) * pow(M_PI,-0.25);

    /* Morlet transform is only for non-negative frequencies */
    for (k=0; k<npos; k++)
    {
      /* Calculate wave number w(k) */
      w = 2.0 * M_PI * k / (N * dt);

      /* Calculate daughter wavelet */
      wavelet[k] = norm * exp(-0.5 * (scale*w - param) * (scale*w - param) );
    }

    /* Zero out the negative frequencies */
    for (k=npos; k<N; k++)  wavelet[k] = 0.0;

    /* Calculate the Fourier period */
    *period = 4.0 * M_PI * scale / (param + sqrt(2.0 + (param*param)));

    /* Calculate the e-Folding factor */
    *folding = 4.0 * M_PI / (param + sqrt(2.0 + (param*param))) / sqrt(2.0);
  }

  /* Paul Wavelet ---------------------------------------------------------- */
  else if ( mother == PAUL )
  {
    /* Check the input parameter */
    if ( param < 0.0 )  param = 4.0;

    /* Calculate normalization factor.  factoral() takes an int, so the
       order is truncated towards zero here (made explicit by the cast). */
    norm = sqrt(2.0 * M_PI * scale / dt) *
           pow(2.0,param) / sqrt(param * factoral((int)(2*param - 1)));

    /* Paul transform is only for non-negative frequencies */
    for (k=0; k<npos; k++)
    {
      /* Calculate wave number w(k) */
      w = 2.0 * M_PI * k / (N * dt);

      /* Calculate daughter wavelet */
      wavelet[k] = norm * pow(scale*w , param) * exp(-scale * w);
    }

    /* Zero out the negative frequencies */
    for (k=npos; k<N; k++)  wavelet[k] = 0.0;

    /* Calculate the Fourier period */
    *period = 4.0 * M_PI * scale / (2.0*param + 1.0);

    /* Calculate the e-Folding factor */
    *folding = 4.0 * M_PI *sqrt(2.0) / (2.0*param + 1.0);
  }

  /* DOG Wavelet ----------------------------------------------------------- */
  else if ( mother == DOG )
  {
    /* Check the input parameter */
    if ( param < 0.0 )  param = 2.0;

    /* Calculate normalization factor */
    norm = -sqrt(2.0 * M_PI * scale / dt) *
           cpow(I,param) / sqrt(exp(gammaln(param+0.5)));

    /* DOG transform is over all frequencies */
    for (k=0; k<N; k++)
    {
      /* Calculate wave number w(k); bins above N/2 are negative frequencies */
      if (k <= N/2)  w =  2.0 * M_PI *    k  / (N * dt);
      else           w = -2.0 * M_PI * (N-k) / (N * dt);

      /* Calculate daughter wavelet */
      wavelet[k] = norm * pow(scale*w , param) * exp(-scale*scale*w*w/2.0);
    }

    /* Calculate the Fourier period */
    *period = 2.0 * M_PI * scale * sqrt(2.0 / (2.0*param + 1.0));

    /* Calculate the e-Folding factor */
    *folding = 2.0 * M_PI * sqrt(2.0 / (2.0*param + 1.0)) / sqrt(2.0);
  }
}


/*****************************************************************************\
*                                                                             *
* WAVELET_TRANSFORM                                                           *
*                                                                             *
*  This function performs a wavelet analysis on an input (real or complex)    *
*  time series based on the given parameters.  Input NULL vectors indicate    *
*  not to save data.  Note that the output matrix vectors POWER and PHASE     *
*  are filled in row-major format:  M[i][j] -> M[j + (Nj * i)].  The power    *
*  spectrum is normalized by the variance of the time series.                 *
*                                                                             *
*  Parameters:                                                                *
*    cdouble input[N]    - Input time series.                                 *
*    int     length      - Length of time series vectors (N in this list).    *
*    double  dt          - Amount of time between points (sampling time).     *
*    int     mother      - Type of mother wavelet to use:                     *
*                             0 = Morlet                                      *
*                             1 = Paul                                        *
*                             2 = DOG (derivative of Gaussian)                *
*    double  param       - Mother wavelet parameter.  If <0 default is used:  *
*                             Morlet = w0 (wave number)     [6]               *
*                             Paul   = m  (order)           [4]               *
*                             DOG    = m  (mth derivative)  [2]               *
*    int     J           - Number of wavelet scales to calculate.             *
*    double  s0          - Starting wavelet scale.                            *
*    double  ds          - Spacing between discrete wavelet scales.           *
*    double  power[N][J] - Storage space for calculated wavelet power.        *
*    double  phase[N][J] - Storage space for calculated wavelet phase info.   *
*    double  spectrum[J] - Storage space for calculated wavelet spectrum.     *
*    double  sap[N]      - Storage space for calculated scale-average power.  *
*    double  period[J]   - Storage space for calculated Fourier period.       *
*    double  coi[N]      - Storage space for calculated cone-of-influence.    *
*                                                                             *
*  James Holliday                                                             *
*  University of California - Davis                                           *
*                                                                             *
\*****************************************************************************/

void wavelet_transform(cdouble input[],
		       int     length,
		       double  dt,
		       int     mother,
		       double  param,
		       int     J,
		       double  s0,
		       double  ds,
		       double  power[],
		       double  phase[],
		       double  spectrum[],
		       double  sap[],
		       double  period[],
		       double  coi[])
{
  int      N       = 1;     /* Length of vectors with padding  */
  cdouble  sum     = 0.0;   /* Summation variable for data     */
  cdouble  mean    = 0.0;   /* Mean of the input time series   */
  double   var     = 0.0;   /* Variance of mean-removed input  */
  double   scale   = 0.0;   /* Wavelet scale                   */
  double   mag2    = 0.0;   /* Squared magnitude of one sample */
  cdouble* wavelet = NULL;  /* Storage for daughter wavelet    */
  double   fp      = 0.0;   /* Fourier period                  */
  double   folding = 0.0;   /* e-folding factor                */
  int      i       = 0;     /* Iterating variable              */
  int      j       = 0;     /* Iterating variable              */

  /* FFTW variables */
  fftw_complex *data = NULL;
  fftw_complex *transform = NULL;
  fftw_complex *convolution = NULL;
  fftw_plan forward = NULL;
  fftw_plan reverse = NULL;

  /* Behaviour notes:
       - Any NULL output vector is simply skipped.
       - spectrum[] and sap[] are zeroed here before being accumulated.
       - If length <= 0, or an FFTW allocation/plan fails, the function
         returns early (outputs are then left unset or only zeroed).
       - A zero-variance series (including input == NULL) makes the
         variance-normalized outputs non-finite. */

  /* Nothing can be computed for an empty series */
  if ( length <= 0 )  return;

  /* Pad to the smallest power of two strictly greater than length */
  while (N <= length)  N *= 2;

  /* Allocate memory */
  wavelet = fftw_malloc(sizeof(fftw_complex) * N);
  data = fftw_malloc(sizeof(fftw_complex) * N);
  transform = fftw_malloc(sizeof(fftw_complex) * N);
  convolution = fftw_malloc(sizeof(fftw_complex) * N);

  if ( wavelet == NULL || data == NULL ||
       transform == NULL || convolution == NULL )  goto cleanup;

  /* Initialize the arrays (this also provides the zero padding) */
  for (i=0; i<N; i++)
  {
    data[i] = 0.0;
    transform[i] = 0.0;
    convolution[i] = 0.0;
  }

  /* Create FFTW plan files */
  forward = fftw_plan_dft_1d(N,data,transform,FFTW_FORWARD,FFTW_ESTIMATE);
  reverse = fftw_plan_dft_1d(N,convolution,data,FFTW_BACKWARD,FFTW_ESTIMATE);

  if ( forward == NULL || reverse == NULL )  goto cleanup;

  /* Remove the mean from the input time series and calculate the variance. */
  if ( input != NULL )
  {
    for (i=0; i<length; i++)  sum += input[i];

    mean = sum / length;

    for (i=0; i<length; i++)  data[i] = input[i] - mean;
  }

  for (i=0; i<length; i++)  var += cabs2(data[i]);

  var = (var / length);

  /* Start the accumulated outputs from zero */
  if ( spectrum != NULL )
  {
    for (j=0; j<J; j++)  spectrum[j] = 0.0;
  }

  if ( sap != NULL )
  {
    for (i=0; i<length; i++)  sap[i] = 0.0;
  }

  /* Forward transform the input data */
  fftw_execute(forward);

  /* Main wavelet loop ----------------------------------------------------- */
  for (j=0; j<J; j++)
  {
    /* Calculate the wavelet scale */
    scale = s0 * pow(2.0 , j*ds);

    /* Create the daughter wavelet */
    create_wavelet(wavelet,N,mother,param,scale,dt,&fp,&folding);

    /* Convolute and reverse transform (1/N normalizes the FFT round trip) */
    for (i=0; i<N; i++)  convolution[i] = transform[i] * wavelet[i] / N;
    fftw_execute(reverse);

    /* Save the power and phase data of the transform */
    /* Also calculate the global wavelet spectrum */
    for (i=0; i<length; i++)
    {
      mag2 = cabs2(data[i]);

      if ( power != NULL )
        power[j + J*i] = mag2 / var;

      /* carg() is the phase angle in [-pi, pi]; catan() would be the
         complex arctangent of the coefficient, which is not its phase. */
      if ( phase != NULL )
        phase[j + J*i] = carg(data[i]);

      if ( spectrum != NULL )
        spectrum[j] += mag2 / (length*var);

      if ( sap != NULL )
        sap[i] += mag2 / (J*var);
    }

    /* Save the Fourier period associated with wavelet scale */
    if ( period != NULL )  period[j] = fp;
  }

  /* Calculate the Cone-Of-Influence.  Each point gets its distance (in
     samples) to the nearer end of the series, times dt and the e-folding
     factor from the last daughter wavelet.  The two halves mirror onto
     indices i and length-1-i, so all of coi[0..length-1] is written and
     nothing past the end is touched. */
  if ( coi != NULL )
  {
    for (i=0; i<=(length-1)/2; i++)
    {
      coi[         i] = i * dt * folding;
      coi[length-1-i] = i * dt * folding;
    }
  }

  /* Deallocate memory */
cleanup:
  if ( forward != NULL )      fftw_destroy_plan(forward);
  if ( reverse != NULL )      fftw_destroy_plan(reverse);
  if ( wavelet != NULL )      fftw_free(wavelet);
  if ( data != NULL )         fftw_free(data);
  if ( transform != NULL )    fftw_free(transform);
  if ( convolution != NULL )  fftw_free(convolution);

  return;
}


/*****************************************************************************\
*                                                                             *
* SAVE_DATA                                                                   *
*                                                                             *
*  This function takes in data vectors and creates a tab-delimited output     *
*  file.  The DATA matrix vector is assumed to be in row-major format and     *
*  can be one dimensional (in which case Ny = 1).  Setting Xscale and Yscale  *
*  to NULL results in the matrix index being used for the (x,y) location.     *
*  A value of 0 is returned if the file was opened correctly, -1 is returned  *
*  otherwise.                                                                 *
*                                                                             *
*  Parameters:                                                                *
*    char*  name        - Name of output file to create.                      *
*    double data[Nx*Ny] - Data matrix vector to output to file.               *
*    int    Nx          - Length of rows in matrix vector.                    *
*    int    Ny          - Length of columns in matrix vector (1 for 1D data). *
*    double Xscale[Nx]  - X scale values.                                     *
*    double Yscale[Ny]  - Y scale values.                                     *
*                                                                             *
*  James Holliday                                                             *
*  University of California - Davis                                           *
*                                                                             *
\*****************************************************************************/

int save_data(char*  name,
              double data[],
              int    Nx,
              int    Ny,
	      double Xscale[],
	      double Yscale[])
{
  FILE*  fp  = fopen(name,"w");  /* Destination file, truncated on open */
  int    row = 0;                /* Index along the x direction         */
  int    col = 0;                /* Index along the y direction         */

  /* Could not open the file for writing */
  if ( fp == NULL )  return -1;

  for (row=0; row<Nx; row++)
  {
    /* x coordinate: supplied scale value, or the row index itself */
    const double xval = ( Xscale == NULL ) ? row : Xscale[row];

    /* 1D data: one "x <tab> value" line per row */
    if ( Ny == 1 )
    {
      fprintf(fp,"%lf\t%lf\n",xval,data[row]);
      continue;
    }

    /* 2D data: one "x <tab> y <tab> value" line per element (row-major) */
    for (col=0; col<Ny; col++)
    {
      /* y coordinate: supplied scale value, or the column index itself */
      const double yval = ( Yscale == NULL ) ? col : Yscale[col];

      fprintf(fp,"%lf\t%lf\t%lf\n",xval,yval,data[col + Ny*row]);
    }
  }

  fclose(fp);

  /* File was opened and written */
  return 0;
}
