/**********************************************************************
stats.c
A program to generate:
     - confidence intervals
     - average
     - min, max
     - standard deviation
     - linear regression fits for two fielded inputs.
It takes a stream of numbers from standard input.  You can specify
what field you want from the incoming stream.

Note: In order to speed up processing, the standard deviation is
calculated "on the fly."  This may cause the sum of the squares to
overflow the type double. To be sure this does not happen, we check
for an overflow.

To do:

- Make it to read in environment variable for command line flags.

Copyright Notice
================
This software, stats, is copyright 1994, 1998 by Mark Claypool.

Permission to use, copy, and distribute stats in its entirety, for 
non-commercial purposes, is hereby granted without fee, provided that
the copyright notice appear in all copies.

The software may be modified for your own purposes, but modified
versions may NOT be distributed without prior consent of the author.

This software is provided 'as-is', without any express or implied
warranty.  In no event will the author be held liable for any damages
arising from the use of this software.

If you would like to do something with stats that this copyright
prohibits (such as distributing it with a commercial product, 
using portions of the source in some other program, etc.), please
contact the author (preferably via email).  Arrangements can
probably be worked out.

The author may be contacted via email at claypool@cs.wpi.edu.

***********************************************************************/
/* Include Files */
#include <stdio.h>
#include <stdarg.h>
#include <stdlib.h>
#include <math.h>   /* need to -lm link in math library for log() */
#include <string.h> /* for strcpy() */

/********************************************************************/
/* Program Information */

/* Printed by Usage() with two decimals. */
#define VERSION 3.00  /* See Changes file for updates for each version */
/* Contact address printed by Usage(). */
#define EMAIL "claypool@cs.wpi.edu"

/* Boolean values used for the int flags throughout the file. */
#define TRUE 1
#define FALSE 0

/* Command line flags */
/* GetFlag() matches these as prefixes of an argument, so a flag's value
 * follows it directly (e.g. the batch size is read from the characters
 * after "-b").  Because of the prefix match "-h" also matches "-help". */
#define HELP_FLAG "-h"
#define HELP2_FLAG "-help"
#define FIELD_FLAG "-f"              /* field number; may be given twice */
#define BATCH_FLAG "-b"              /* batch size, must be >= 1 */
#define EXPONENTIAL_FLAG "-e"        /* exponential fit; needs a 2nd field */
#define CONFIDENCE_FLAG "-c"         /* confidence level */
#define LINES_FLAG "-l"              /* not referenced by the code in this file */
#define BOTH_FLAG "-xy"              /* shorthand for fields 1 and 2 */
#define GNUPLOT_FLAG "-gp"           /* sets the global gnuplot */
#define GNUPLOT_ONLY_FLAG "-GP"      /* sets the global gnuplot_only */
#define VERBOSE_FLAG "-v"            /* also print the raw sums */
#define FUTURE_SAMPLES_FLAG "-s"     /* number of future samples (0 = inf.) */

/* default values */
/* A confidence of 0.0 disables the confidence interval computation
 * (the code only computes intervals when confidence > 0). */
#define CONF_DEFAULT 0.0
#define F_DEFAULT 1
/* main() stores 1/(number of future samples); 1.0 is one sample. */
#define FUTURE_SAMPLES_DEFAULT 1.0

/* for parsing input */
/* TAB and SPACE separate fields; NEWLINE and EOL end the last field. */
#define TAB '\t'
#define NEWLINE '\n'
#define EOL '\0'
#define SPACE ' '
/* Size of the line buffer passed to fgets() and of the word buffer
 * that receives a single field. */
#define MAX_LINE_SIZE 160
/* A line whose first character is COMMENT_CHAR is skipped. */
#define COMMENT_CHAR '#'

/* return values for GetField() */
#define NO_FIELD -1   /* the line has fewer fields than requested */
#define COMMENT -2    /* the line starts with COMMENT_CHAR */
#define OK 1          /* the field was copied out */

/********************************************************************/
/* Function Prototypes */
void error(char *fmt, ...);
char *GetFlag(int *argc, char *argv[], char *flag);
void Usage(char *name);
int GetField(char line[], char word[], int field);
int GetData(char *line, int field, double *num, double *sum_sq, 
	    double *old_sum_sq, 
	    double *sum, double *min, double *max, int *first_time, 
	    double count, int do_exponential);
void PrintResults(int field, double count, double mean, double var, 
		  double std, double sum, double min, double max, 
		  double c1, double c2, double);
void CalculateData(double confidence, double sum, double sum_sq, double count, 
	      double *mean, double *var, double *std, double *c1, double *c2); 
double t_table(double count, double confidence);

/* Global Variables */
/* Set to TRUE by GetData() when its sum-of-squares check trips; main()
 * prints a warning at the end of the run when it is set. */
int warn = FALSE;		/* used to warn if sum_sq rolls over */
/* gnuplot: print gnuplot lines in addition to the normal report.
 * gnuplot_only: print only the gnuplot lines.  Both are set in main(). */
int gnuplot, gnuplot_only;
int exponential = FALSE;        /* if doing an exponential fit */
/* 0 (FALSE) when batch mode is off, otherwise the batch size given on
 * the command line; read by GetData(), CalculateData(), PrintResults(). */
int batch_mode;			/* batch mode for samples */

/********************************************************************/
/* main- read numbers from standard input and print summary statistics.
 *
 * Command-line flags are pulled out of argv one at a time with GetFlag();
 * anything left over is reported as an unknown option.  Each input line
 * is handed to GetData() for field x and, when a second field was
 * requested, for field y.  After the last line the mean, variance,
 * standard deviation and confidence interval of each field are printed
 * and, for (x,y) input, a least-squares line y = Ax + B is fitted.
 *
 * The fit and everything derived from it (exponential form, gnuplot
 * lines) is printed only when it was actually computed, i.e. when both
 * fields produced the same number of accepted values.  The gnuplot
 * confidence band also needs more than two points; with fewer, only the
 * best-fit line is printed.
 *
 * The confidence flag accepts a percentage ("-c95") or a fraction
 * ("-c.95", "-c0.95"); values above 1 are taken as percentages.
 *
 * Returns 0 on success.  Usage() and error() end the program with exit
 * status 1.
 */
int 
main(int argc, char *argv[]) 
{
   int i;			/* loop index */
   double z;			/* t table value */
   double future_samples;	/* 1/num of future samples for xy intrvls. */
   double samples;		/* value given with the future samples flag */
   double coeff_of_determination = 0.0;
   double ssr = 0.0, sst = 0.0;
   double A = 0.0, B = 0.0;	/* for linear regression fits */
   double correlation = 0.0;
   double sum_x, sum_y, sum_xy, sum_esq = 0.0;
   double sum_xsq, sum_ysq; 
   double c1_x, c2_x;
   double c1_y, c2_y;
   double c1_A = 0.0, c2_A = 0.0;
   double c1_B = 0.0, c2_B = 0.0;
   double std_A, std_B, std_e = 0.0;
   double mean_x, mean_y, mean_esq;
   double var_x = 0.0, var_y = 0.0;
   double std_x = 0.0, std_y = 0.0;
   double min_x = 0.0, min_y = 0.0, max_x = 0.0, max_y = 0.0; 
   double old_sum_xsq, old_sum_ysq; /* To compare with sumsq, for overflow. */
   double x_temp, y_temp;	/* the input from the last parse */
   double confidence;		/* confidence level */
   double count_x, count_y;	/* count of the lines */
   int field_x;			/* field to take stats on */
   int field_y;			/* 2nd field if parsing (x,y) pairs */
   char line[MAX_LINE_SIZE];
   int first_time_x;		/* true to set min and max */
   int first_time_y;		/* true to set min and max */
   int verbose;
   int have_fit = FALSE;	/* TRUE once A and B have been computed */
   int have_fit_errors = FALSE;	/* TRUE once std_e etc. have been computed */
   char *f;			/* for parsing command line args */

   /* initialize system variables */
   confidence = CONF_DEFAULT;
   field_x = F_DEFAULT;
   field_y = FALSE;

   /* Parse command line.  GetFlag() removes what it finds, so whatever
      is left in argv afterwards is an unknown option. */
   if (GetFlag(&argc, argv, HELP_FLAG) != NULL)
      Usage(argv[0]);
   if (GetFlag(&argc, argv, HELP2_FLAG) != NULL)
      Usage(argv[0]);
   verbose = (GetFlag(&argc, argv, VERBOSE_FLAG) != NULL) ? TRUE : FALSE;
   gnuplot = (GetFlag(&argc, argv, GNUPLOT_FLAG) != NULL) ? TRUE : FALSE;
   if ((f = GetFlag(&argc, argv, BATCH_FLAG)) != NULL) {
      batch_mode = atoi(f+2);
      if (batch_mode < 1) {
         fprintf(stderr,"batch number must be 1 or greater.\n");
         Usage(argv[0]);
      }
   } else {
      batch_mode = FALSE;
   }
   gnuplot_only = 
      (GetFlag(&argc, argv, GNUPLOT_ONLY_FLAG) != NULL) ? TRUE : FALSE;
   if ((f = GetFlag(&argc, argv, CONFIDENCE_FLAG)) != NULL) {
      /* Accept both a percentage (95) and a fraction (.95 or 0.95). */
      confidence = strtod(f+2, NULL);
      if (confidence > 1.0)
         confidence = confidence / 100.0;
      /* Written as a negated range test so that NaN is rejected too. */
      if (!(confidence > 0.5 && confidence <= 1.0))
         Usage(argv[0]);
   }
   if ((f = GetFlag(&argc, argv, FUTURE_SAMPLES_FLAG)) != NULL) {
      if (!*(f+2))
         Usage(argv[0]);
      samples = atof(f+2);
      /* 0 means an infinite number of future samples, i.e. 1/inf. */
      if (samples == 0.0)
         future_samples = 0.0;
      else
         future_samples = 1/samples;
   } else {
      future_samples = FUTURE_SAMPLES_DEFAULT;
   }
   /* The field flag may appear twice: first x, then y. */
   if ((f = GetFlag(&argc, argv, FIELD_FLAG)) != NULL) {
      field_x = atoi(f+2);
      if (field_x <= 0)
         Usage(argv[0]);
   }
   if ((f = GetFlag(&argc, argv, FIELD_FLAG)) != NULL) {
      field_y = atoi(f+2);
      if (field_y <= 0)
         Usage(argv[0]);
   }
   if (GetFlag(&argc, argv, BOTH_FLAG) != NULL) {
      field_x = 1;
      field_y = 2;
   }
   if (GetFlag(&argc, argv, EXPONENTIAL_FLAG) != NULL) {
      if (field_y == FALSE) {
         fprintf(stderr,"Must specify 2nd field for exponential fit.\n");
         Usage(argv[0]);
      }
      exponential = TRUE;
   }

   if (argc > 1) {
      fprintf(stderr,"Unknown options: ");
      for (i=1; i<argc; i++)
         fprintf(stderr,"%s ", argv[i]);
      fprintf(stderr,"\n");
      Usage(argv[0]);
   }

   /* initialize variables */
   mean_x = mean_y = 0.0;
   count_x = count_y = 0.0;
   sum_xsq = sum_ysq = 0.0;
   old_sum_xsq = old_sum_ysq = 0.0;
   sum_x = sum_y = sum_xy = 0.0;
   c1_x = c2_x = c1_y = c2_y = 0.0;
   first_time_x = first_time_y = TRUE;

   if (exponential)  {
      fprintf(stderr, "Scaling y data to logscale for regression.\n");
   }

   /* Loop until end of the input stream.  fgets() hands back at most
      MAX_LINE_SIZE-1 characters per call, so a longer input line is
      processed as several lines. */
   while (fgets(line, MAX_LINE_SIZE, stdin) != NULL) {

      /* do the processing for each field */
      count_x += GetData(line, field_x, &x_temp, &sum_xsq, &old_sum_xsq, 
                         &sum_x, &min_x, &max_x, &first_time_x, count_x, 
                         FALSE);
      if (field_y) {
         count_y += GetData(line, field_y, &y_temp, &sum_ysq, &old_sum_ysq, 
                            &sum_y, &min_y, &max_y, &first_time_y, count_y, 
                            exponential);
         /* GetData() leaves 0.0 in the value when the field was not
            accepted, so such lines add nothing here. */
         sum_xy += x_temp * y_temp;
      }

   }   

   if (count_x <= 0.0) 
      error("No input lines in field %d.", field_x);
   else 
      CalculateData(confidence, sum_x, sum_xsq, count_x, &mean_x, 
                    &var_x, &std_x, &c1_x, &c2_x);
   if (field_y) {
      if (count_y <= 0.0) 
         error("No input lines in field %d.", field_y);
      else
         CalculateData(confidence, sum_y, sum_ysq, count_y, &mean_y, 
                       &var_y, &std_y, &c1_y, &c2_y);
   }

   /* print out the results */
   PrintResults(field_x, count_x, mean_x, var_x, std_x, sum_x, 
                min_x, max_x, c1_x, c2_x, confidence);
   if (field_y) {
      if (!gnuplot_only) 
         printf("\n");
      PrintResults(field_y, count_y, mean_y, var_y, std_y, 
                   sum_y, min_y, max_y, c1_y, c2_y, confidence);
   }

   if (field_y) {
      if (count_x != count_y) {
         if (!gnuplot_only) {
            printf("Error.  Lines in field %d to not equal lines in %d.\n",
                   field_x, field_y);
            printf("Not doing linear regression and correlation.\n");
         }
      } else {

         /* linear regression: y = Ax + B */
         A = (sum_xy - count_x*mean_x*mean_y) / 
             (sum_xsq - count_x*(mean_x*mean_x));
         B = mean_y - A*mean_x;
         have_fit = TRUE;

         if (count_y <= 2.0) {
            fprintf(stderr,"Only %.0f lines of input.\n", count_y);
            fprintf(stderr,"Unable to calculate errors on line fit.\n");
         } else {	    
            sum_esq = sum_ysq - B*sum_y - A*sum_xy;
            sst = sum_ysq - (count_y * mean_y * mean_y);
            ssr = sst - sum_esq;
            coeff_of_determination = ssr / sst;
            mean_esq = sum_esq / (count_y - 2);
            std_e = sqrt(mean_esq);
            have_fit_errors = TRUE;

            /* confidence intervals */
            if (confidence > 0) {
               std_A = std_e / sqrt(sum_xsq - count_x*mean_x*mean_x);
               std_B = std_e * sqrt(1/count_x + (mean_x*mean_x) / 
                                    (sum_xsq - count_x*mean_x*mean_x));
               z = t_table(count_x - 2, confidence);
               c1_A = A - z*std_A;
               c2_A = A + z*std_A;
               c1_B = B - z*std_B;
               c2_B = B + z*std_B;
            }

            /* correlation */
            correlation = (sum_xy - (sum_x*sum_y / count_x)) /
                          (sqrt(sum_xsq - (sum_x*sum_x)/count_x) *
                           sqrt(sum_ysq - (sum_y*sum_y)/count_y));
         }

         /* print the results */
         if (!gnuplot_only) {
            printf("\n           line:  y = Ax + B\n");
            printf("              A:  %.12f\n", A);
            printf("              B:  %.12f\n", B);
            if (have_fit_errors) {
               if (confidence > 0) {
                  printf("         left A:  %.12f\n", c1_A);
                  printf("        right A:  %.12f\n", c2_A);
                  printf("         left B:  %.12f\n", c1_B);
                  printf("        right B:  %.12f\n", c2_B);
               }
               printf("  error squared:  %.12f\n", sum_esq);
               printf("            SSR:  %.12f\n", ssr);
               printf("            SST:  %.12f\n", sst);
               printf(" coeff. of det.:  %.12f\n", coeff_of_determination);
               printf("    correlation:  %.12f\n", correlation);
            }
         }
      }

      /* A and B only mean something when the fit was computed. */
      if (exponential && have_fit) {
         printf("\n");
         printf("exponential fit:  y = K r^x\n");
         printf("              K:  %.12f\n", pow(10, B));
         printf("              r:  %.12f\n", pow(10, A));
      }

      /* predicted response function with confidence intervals */
      /* first, assume a large number of observations with mean */
      if ((gnuplot || gnuplot_only) && have_fit) {

         if (!gnuplot_only)
            printf("        gnuplot:\n");

         /* The band needs std_e and count_y - 2 > 0 degrees of
            freedom; without them fall back to the bare line. */
         if (confidence > 0 && have_fit_errors) {
            z = t_table(count_y - 2, confidence);

            /* print the min and max lines with the line fit */
            printf("(%.12f*x + %.12f) + ((%.12f*sqrt(%.12f + %.12f+(x-%.12f)*(x-%.12f)/(%.12f-%d*%.12f))))*%.12f title 'max fit' lines 2, \\\n",
                   A, B, std_e, future_samples, 1/count_y, mean_x, mean_x,
                   sum_xsq, (int) count_y, mean_x*mean_x, z);
            printf("%.12f*x + %.12f title 'best fit', \\\n", A, B);
            printf("(%.12f*x + %.12f) - ((%.12f*sqrt(%.12f + %.12f+(x-%.12f)*(x-%.12f)/(%.12f-%d*%.12f))))*%.12f title 'min fit' lines 2\n",
                   A, B, std_e, future_samples, 1/count_y, mean_x, mean_x,
                   sum_xsq, (int) count_y, mean_x*mean_x, z);

         } else {
            /* print only the line fit */
            printf("%.12f*x + %.12f title 'best fit' \n", A, B);
         }	    

      }

   }

   /* print out verbose information if it's needed */
   if (verbose == TRUE) {
      printf("        sum_xsq:  ");
      printf("%.12f\n", sum_xsq);
      if (field_y) {
         printf("        sum_ysq:  ");
         printf("%.12f\n", sum_ysq);
         printf("         sum_xy:  ");
         printf("%.12f\n", sum_xy);
      } 
   }

   /* Was there a problem with the size of the data? */
   if (warn) {
      fprintf(stderr,"Warning!  The data sample overflowed sum_sq .\n");
      fprintf(stderr,"  The standard deviation and confidence intervals\n");
      fprintf(stderr,"  may be inaccurate.\n");
   }

   /* all done */
   return 0;
}

/********************************************************************/
/* GetField- copy the field-th whitespace-separated word of line into word.
 *
 * line   NUL-terminated input line; fields are separated by SPACE or TAB
 *        and the last one ends at NEWLINE or the terminating NUL.
 * word   receives the field as a NUL-terminated string.  No length is
 *        passed in, so the caller must supply a buffer at least as large
 *        as line.
 * field  1-based number of the field wanted.
 *
 * Returns COMMENT when the first character of line is COMMENT_CHAR,
 * NO_FIELD when the line has fewer than field words (this includes blank
 * lines and lines where only whitespace follows the previous word), and
 * OK when a non-empty word was copied.
 */
int 
GetField(char line[], char word[], int field)
{
   int current = 1;		/* number of the field that starts at pos */
   int pos = 0;			/* read position in line */
   int len = 0;			/* characters copied into word */

   /* see if this line should be ignored */
   if (line[0] == COMMENT_CHAR) 
      return COMMENT;

   /* Step over the fields in front of the one we want. */
   while (current < field) {

      /* skip leading whitespace */
      while (line[pos]==SPACE || line[pos]==TAB) 
         pos++;

      /* skip the word itself */
      while (line[pos]!=SPACE && line[pos]!=TAB &&
             line[pos]!=NEWLINE && line[pos]!=EOL)
         pos++;

      /* the line ended before we reached the wanted field */
      if (line[pos] == NEWLINE || line[pos] == EOL)
         return NO_FIELD;

      current++;
   }

   /* skip whitespace in front of the wanted field */
   while (line[pos]==SPACE || line[pos]==TAB) 
      pos++;

   /* Copy in word. */
   while (line[pos]!=SPACE && line[pos]!=TAB &&
          line[pos]!=NEWLINE && line[pos]!=EOL) {
      word[len] = line[pos];
      len++;
      pos++;
   }
   word[len] = EOL;

   /* Only whitespace was left: the field is missing, not empty. */
   if (len == 0)
      return NO_FIELD;

   return OK;
}


/********************************************************************/
/* GetFlag- look for a command-line argument that starts with flag.
 *
 * argc  in/out: number of entries in argv; decremented when a match is
 *       removed.
 * argv  argument vector; argv[0] is never examined.  The first matching
 *       entry is removed by shifting the following entries down one slot.
 * flag  prefix to look for, e.g. "-f".
 *
 * Only the first match is removed per call, so a flag that may be given
 * twice is collected with two calls.  The comparison is done in place on
 * argv[i], so arguments of any length are safe.
 *
 * Returns the matching argument (flag text included), or NULL when no
 * argument starts with flag.
 */
char *
GetFlag(int *argc, char *argv[], char *flag)
{
   size_t flag_len = strlen(flag);
   char *found;
   int i, j;

   for (i = 1; i < *argc; i++) {
      if (strncmp(argv[i], flag, flag_len) != 0)
         continue;

      found = argv[i];

      /* close the gap left by the removed argument */
      for (j = i; j < *argc - 1; j++)
         argv[j] = argv[j+1];
      *argc = *argc - 1;

      return found;
   }

   return NULL;
}

/********************************************************************/
/* error- report a fatal problem and terminate.
 *
 * Writes "error: ", then the printf-style message built from fmt and the
 * variable arguments, then a newline, all to stderr.  Never returns: the
 * program exits with status 1.
 */
void 
error(char *fmt, ...)
{
   va_list ap;

   fprintf(stderr, "error: ");

   /* format the caller's message straight onto stderr */
   va_start(ap, fmt);
   vfprintf(stderr, fmt, ap);
   va_end(ap);

   fprintf(stderr, "\n");
   exit(1);
}

/********************************************************************/
/* GetData- read one field of line and fold it into the running totals.
 *
 * line            the input line.
 * field           1-based field to read (see GetField()).
 * num             out: the value read, or 0.0 when nothing was accepted.
 *                 With do_exponential set this is log10 of the input.
 * sum_sq          in/out: running sum of squares.  In batch mode each
 *                 value is multiplied by batch_mode before squaring.
 * old_sum_sq      in/out: sum_sq as it was after the previous value.
 * sum             in/out: running sum.
 * min, max        in/out: smallest and largest value so far.
 * first_time      in/out: TRUE until the first value has been accepted;
 *                 min and max are initialised from that value.
 * count           number of values accepted so far; only used as the
 *                 "line" number in diagnostics.
 * do_exponential  non-zero to take log10 of the value first.  Values
 *                 that are not greater than zero have no logarithm and
 *                 are rejected with a message.
 *
 * Sets the global warn when the sum of squares stops being a finite
 * number.  A double does not wrap around on overflow, it becomes
 * infinite, so the sum is compared against HUGE_VAL.
 *
 * Returns 1 when a value was accepted, 0 otherwise (missing field,
 * comment line, unparsable or unusable value).
 */
int
GetData(char *line, int field, double *num, double *sum_sq, 
	double *old_sum_sq, double *sum, double *min, double *max, 
	int *first_time, double count, int do_exponential) 
{
   int status;			/* result of GetField() */
   double scaled;		/* value as it enters the sum of squares */
   char word[MAX_LINE_SIZE];

   *num = 0.0;

   /* read the right field in. */
   status = GetField(line, word, field);
   if (status == NO_FIELD) {
      fprintf(stderr, "Warning! Field %d does not exist in line %.0f.\n", 
              field, count);
      return 0;
   }
   if (status == COMMENT) {
      fprintf(stderr, "Line %.0f is a comment\n", count);
      return 0;
   }

   /* the field has to start with a number */
   if (sscanf(word, "%lf", num) != 1) {
      *num = 0.0;
      fprintf(stderr, "Invalid field \"%s\" in line %.0f.\n", word, count);
      return 0;
   }

   /* If exponential, scale by log. */
   if (do_exponential) {
      /* negated test so that NaN is rejected as well */
      if (!(*num > 0.0)) {
         fprintf(stderr, "Invalid field \"%s\" in line %.0f: "
                 "exponential fit needs values greater than 0.\n", 
                 word, count);
         *num = 0.0;
         return 0;
      }
      *num = log10(*num);
   }

   /* Add the number to the running total for stddev and avg */
   scaled = batch_mode ? (*num) * batch_mode : *num;
   *sum_sq += scaled * scaled;
   *sum += *num;

   /* Check if we should report a warning if the sumsq overflowed:
      it is then +infinity (or NaN, which is not equal to itself). */
   if (*sum_sq >= HUGE_VAL || *sum_sq != *sum_sq)
      warn = TRUE;
   *old_sum_sq = *sum_sq;

   /* Compute min and max.  The first time through, initialize. */
   if (*first_time) {
      *min = *num;
      *max = *num;
      *first_time = FALSE;
   }
   if (*num > *max) 
      *max = *num;
   if (*num < *min)
      *min = *num;

   return 1;
}

/********************************************************************/
/* CalculateData- derive the summary statistics from the running sums.
 *
 * confidence  confidence level; no interval is computed (c1 and c2 are
 *             left untouched) unless it is greater than 0.
 * sum         sum of the values.
 * sum_sq      sum of the squared values.
 * count       number of values; the caller must pass a value > 0,
 *             since it is used as a divisor.
 * mean        out: the mean.
 * var         out: the sample variance, 0 when count <= 1.  A negative
 *             result can only come from rounding and is reported as 0.
 * std         out: square root of the variance.
 * c1, c2      out: left and right endpoint of the confidence interval
 *             for the mean; both 0 when count <= 1.
 *
 * In batch mode (global batch_mode non-zero) the mean and the variance
 * are additionally scaled by the batch size.
 *
 * Prints a warning to stderr when an interval is computed from fewer
 * than 30 values at a confidence other than 0.95.
 */
void 
CalculateData(double confidence, double sum, double sum_sq, double count, 
	      double *mean, double *var, double *std, double *c1, double *c2) 
{
   double z;			/* t table value */

   /* Calculate mean and variance  */
   if (batch_mode) {
      *mean = (sum / count) / batch_mode;
      if (count > 1)
         *var = (sum_sq - ((sum * sum) / count)) / 
            ((batch_mode * count - 1) * batch_mode * batch_mode); 
      else
         *var = 0;
   } else {
      if (count > 1)
         *var = (sum_sq - ((sum * sum) / count)) / (count - 1); 
      else
         *var = 0;
      *mean = sum / count;
   }   

   /* sum_sq - sum*sum/count can come out a hair below zero for nearly
      constant data; a variance cannot be negative. */
   if (*var < 0)
      *var = 0;

   /* Calculate standard deviation */
   if (count > 1 && *var > 0)
      *std = sqrt(*var);
   else
      *std = 0.0;

   /* Calculate confidence intervals */
   if (confidence > 0.0) { 

      if (count > 1) {
         z = t_table(count  - 1, confidence);
         *c1 = *mean - z * (*std / sqrt(count));
         *c2 = *mean + z * (*std / sqrt(count));
         if (count < 30.0 && confidence != 0.95) {
            fprintf(stderr,"Warning! 30 samples are needed for accurate ");
            fprintf(stderr,"confidence intervals.\n");
         }
      } else {
         *c1 = 0;
         *c2 = 0;
      }
   }

}

/********************************************************************/
/* PrintResults- write the statistics of one field to standard output.
 *
 * Unless the global gnuplot_only is set, a labelled report is printed:
 * field number, line count, mean, variance, standard deviation and the
 * standard error of the mean (std / sqrt(count)).  Sum, min and max are
 * left out in batch mode.  The confidence level and the interval
 * endpoints c1/c2 are printed only when confidence > 0.
 *
 * When gnuplot or gnuplot_only is set, one line of numbers follows: the
 * mean, plus c1 and c2 when confidence > 0.
 */
void
PrintResults(int field, double count, double mean, double var, double std, 
	     double sum, double min, double max, double c1, double c2, 
	     double confidence)
{
   int want_gnuplot = (gnuplot || gnuplot_only);

   if (!gnuplot_only) {
      /* label and value are adjacent literals, so each piece of text is
         exactly what ends up on the output */
      printf("          Field:  " "%d\n", field);
      printf("          lines:  " "%.0f\n", count);
      printf("           mean:  " "%.12f\n", mean);
      printf("       variance:  " "%.12f\n", var);
      printf("        std dev:  " "%.12f\n", std);
      printf("std err of mean:  " "%.12f\n", std/sqrt(count));

      if (!batch_mode) {
         printf("            sum:  " "%.12f\n", sum);
         printf("            min:  " "%.12f\n", min);
         printf("            max:  " "%.12f\n", max);
      }

      if (confidence > 0) {
         printf("     confidence:  " "%.0f%%\n", confidence*100.0);
         printf("  left endpoint:  " "%.12f\n", c1);
         printf(" right endpoint:  " "%.12f\n", c2);
      }
   }

   /* nothing more to do without gnuplot output */
   if (!want_gnuplot)
      return;

   if (!gnuplot_only)
      printf("        gnuplot:\n");

   if (confidence > 0)
      printf("%12f  %12f  %12f\n",  mean, c1, c2);
   else
      printf("%12f\n",  mean);
}

/********************************************************************/
/* output t table value for given confidence and degrees of freedom */
double 
t_table(double freedom, double confidence)
{
   double num, denom, n;	/* misc variables for computation */
   double z;			/* t table value computed */
   static double t[] = {	/* .95 for when num samples < 30 */
      6.314, 2.920, 2.353, 2.132, 2.015, 1.943, 1.895, 1.860, 
      1.833, 1.812, 1.796, 1.782, 1.771, 1.761, 1.753, 1.746, 
      1.740, 1.734, 1.729, 1.725, 1.721, 1.717, 1.714, 1.711,
      1.708, 1.706, 1.703, 1.701, 1.699, 1.697};

   /* Compute z value. We can use an approximation formula if we
      have more than 30 samples.  If not, we have a table typed
      in for 95% confidence intervals.  If not, we should warn. */
   if (freedom < 30.0 && confidence == 0.95) {
      z = t[(int) freedom];
   } else {
      n = sqrt(log(1/((1-confidence)*(1-confidence)))); 
      num = (2.515517 + 0.802853 * n + 0.010328 * n*n);
      denom = (1 + 1.432788 * n + 0.189269 *n*n + 0.001308 * n*n*n);
      z = n - num/denom;
   }
   return z;
}

/********************************************************************/
/* Usage- print the help text and quit.
 *
 * name  program name shown in the banner and on the usage line.
 *
 * Everything is written to stderr: a banner with version and contact
 * address, one line per command-line flag, and the default values.
 * Never returns; the program exits with status 1.
 */
void 
Usage(char *name)
{
   FILE *out = stderr;

   /* banner */
   fprintf(out,"%s - a summary statistics program\n", name);
   fprintf(out,"version %.2f, by Mark Claypool\n", VERSION);
   fprintf(out,"Send bugs, suggestions to %s\n", EMAIL);

   /* one line per flag */
   fprintf(out,"usage: %s <flags>, where flags are:\n", name);
   fprintf(out,"\t%s or %s\tthis message\n", HELP_FLAG, HELP2_FLAG);
   fprintf(out,"\t%s#\tprobability confidence 0.5 < c <= 1.0\n", 
           CONFIDENCE_FLAG);
   fprintf(out,"\t%s#\tfield (use twice for (x,y) pairs)\n", FIELD_FLAG);
   fprintf(out,"\t%s\tread (x,y) from field 1 and field 2\n", BOTH_FLAG);
   fprintf(out,"\t%s\tcompute exponential regresstion (requires 2 fields)\n", EXPONENTIAL_FLAG);
   fprintf(out,"\t%s#\tbatch mode (many small samples in one large)\n", 
           BATCH_FLAG);
   fprintf(out,"\t%s\toutput lines for gnuplot\n", GNUPLOT_FLAG);
   fprintf(out,"\t%s\toutput only lines for gnuplot\n", GNUPLOT_ONLY_FLAG);
   fprintf(out,"\t%s#\tnumber of future samples (0 for inf.). Only w/gnuplot output\n", FUTURE_SAMPLES_FLAG);
   fprintf(out,"\t%s\tprint out extra information\n", VERBOSE_FLAG);

   /* defaults */
   fprintf(out,"\tDefaults are: %s%.2f %s%d %s%.0f\n", CONFIDENCE_FLAG, 
           CONF_DEFAULT, FIELD_FLAG, F_DEFAULT, FUTURE_SAMPLES_FLAG, 
           FUTURE_SAMPLES_DEFAULT);

   exit(1);
}

/********************************************************************
"Beware of bugs in the above code; I have only proved it correct, not
tried it."
	-- Donald E. Knuth
********************************************************************/
