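/* dft.c
   Michael A. Gennert
   14 Oct 2002

   Take the 2-D discrete Fourier transform (DFT) of a greyscale image
   whose width and height are both even.

   Input parameters:
   input = input image file
   output = output file for the magnitude of the DFT
   log_output = output file for the log-compressed DFT magnitude
   orig_output = output file for the real part of the inverse DFT of the
                 DFT, which should reproduce the input

   Compile String:
     cc dft.c img.o complex.o -o dft -O -lm
*/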

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "img.h"
#include "cimg.h"


/* Sign of the exponent in the DFT kernel exp(sign * 2*pi*sqrt(-1)*u*x/N)
   that dft_both() applies: negative for the forward transform, positive
   for the inverse. */
#define DFT_FORWARD (-1.0)
#define DFT_INVERSE (+1.0)

/* M_PI is a POSIX/XSI extension, not ISO C; provide it when <math.h>
   does not (e.g. when compiling with a strict -std=c11). */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif


// Function prototypes
void print_help_and_exit (void);
cimg *checkerboard (char *prog_name, cimg *in);
cimg *dft (char *prog_name, cimg *in);
cimg *idft (char *prog_name, cimg *in);
dimg *log_img (char *prog_name, dimg *in);
double min_dimg (dimg *in);
double max_dimg (dimg *in);
unsigned char max_img (img *in);
unsigned char min_img (img *in);
cimg *dft_both (char *prog_name, cimg *in, double scale, double for_inv);


/* Read a greyscale image (argv[1]) and take its centred DFT.  Write three
   images: the DFT magnitude (argv[2]), the log-compressed magnitude
   (argv[3]), and the real part of the inverse DFT (argv[4]), which should
   reproduce the input.  Value ranges of the intermediate images are
   printed to stdout as the program runs. */
int main (int argc, char** argv) {
  char *prog_name = argv[0];
  img *input_img, *output_img, *log_output_img, *orig_output_img;
  cimg *cinput_img, *dft_img, *coutput_img;
  dimg *mag_dft_img, *log_mag_dft_img;

  if (argc != 5) { print_help_and_exit (); }

  /* read input */
  input_img = read_img (prog_name, argv[1]);

  /* Validate the input.  An empty image has nothing to transform.  Odd
     sizes are rejected because the (-1)^(x+y) modulation applied by
     checkerboard() shifts the spectrum by exactly half its size only when
     both sizes are even. */
  if (input_img->cols <= 0 || input_img->rows <= 0) {
    fprintf (stderr, "%s: image %s has no pixels (%d columns, %d rows)\n",
             prog_name, argv[1], input_img->cols, input_img->rows);
    exit (EXIT_FAILURE);
  }
  if ( (input_img->cols % 2) != 0 ) {
    fprintf (stderr, "%s: image %s columns %d not an even number\n",
             prog_name, argv[1], input_img->cols);
    exit (EXIT_FAILURE);
  }
  if ( (input_img->rows % 2) != 0 ) {
    fprintf (stderr, "%s: image %s rows %d not an even number\n",
             prog_name, argv[1], input_img->rows);
    exit (EXIT_FAILURE);
  }
  if (input_img->colors != 1) {
    fprintf (stderr, "%s: image %s has %d colors, must be greyscale\n",
             prog_name, argv[1], input_img->colors);
    exit (EXIT_FAILURE);
  }

  printf ("inp min = %d, max = %d\n",
          min_img (input_img), max_img (input_img));

  /* Make complex input */
  cinput_img = img2cimg (prog_name, input_img);

  /* Take DFT */
  dft_img = dft (prog_name, cinput_img);

  /* Create output & log output from the DFT magnitude */
  mag_dft_img = mg_part (prog_name, dft_img);

  printf ("mag min = %f, max = %f\n",
          min_dimg (mag_dft_img), max_dimg (mag_dft_img));

  output_img = dimg2img_scale (prog_name, mag_dft_img);

  printf ("mag img min = %d, max = %d\n",
          min_img (output_img), max_img (output_img));

  log_mag_dft_img = log_img (prog_name, mag_dft_img);

  printf ("log min = %f, max = %f\n",
          min_dimg (log_mag_dft_img), max_dimg (log_mag_dft_img));

  log_output_img = dimg2img (prog_name, log_mag_dft_img);

  printf ("log img min = %d, max = %d\n",
          min_img (log_output_img), max_img (log_output_img));


  /* Compute IDFT and take real part (should reproduce the input) */
  coutput_img = idft (prog_name, dft_img);
  orig_output_img = dimg2img (prog_name, re_part (prog_name, coutput_img));

  /* write outputs */
  write_img (prog_name, output_img, argv[2]);
  write_img (prog_name, log_output_img, argv[3]);
  write_img (prog_name, orig_output_img, argv[4]);

  return EXIT_SUCCESS;
} // END main


/* Print the command-line usage to stderr and terminate with a failure
   status.  main() calls this when it is not given exactly four file
   arguments, so this is an error exit, not a successful one. */
void print_help_and_exit (void) {
  fprintf (stderr, "Usage: dft input output log_output orig_output\n");
  exit (EXIT_FAILURE);
} // END print_help_and_exit


/* Forward 2-D DFT of `in`, normalised by 1/(rows*cols).

   The input is first modulated by (-1)^(x+y) via checkerboard(), which
   places the zero-frequency term at the centre of the result.  A new image
   is returned; `in` itself is not modified (checkerboard works on a copy). */
cimg *dft (char *prog_name, cimg *in) {
  cimg *centred = checkerboard (prog_name, in);
  double pixel_count = (double) centred->rows * centred->cols;

  return dft_both (prog_name, centred, 1.0 / pixel_count, DFT_FORWARD);
} // END dft


/* Inverse 2-D DFT of `in`, undoing dft().

   No normalisation is applied here (scale 1.0) because dft() already
   divided by rows*cols.  The (-1)^(x+y) modulation is re-applied
   afterwards to cancel the one dft() applied before transforming. */
cimg *idft (char *prog_name, cimg *in) {
  return checkerboard (prog_name,
                       dft_both (prog_name, in, 1.0, DFT_INVERSE));
} // END idft


cimg *dft_both (char *prog_name, cimg *in, double scale, double for_inv) {
  long i, j, u, v;
  double re, im;

  long rows = in->rows;
  long cols = in->cols;
  complex *in_data = in->image;
  complex *in_pixel;

  double two_pi_u_over_m, two_pi_v_over_n;
  double cs, sn;

  cimg *temp_img = create_cimg (prog_name, cols, rows);
  complex *temp_data = temp_img->image;

  cimg *result = create_cimg (prog_name, cols, rows);
  complex *result_data = result->image;

  printf ("ready to xform cols\n");

  /* transform along the column dimension */
  for (j = 0; j < rows; j++) {
    for (u = 0; u < cols; u++) {
      re = 0.0;
      im = 0.0;
      two_pi_u_over_m = for_inv * 2.0 * M_PI * u / cols;
      for (i = 0; i < cols; i++) {
	in_pixel = in_data + j * cols + i;
	cs = cos (two_pi_u_over_m * i);
	sn = sin (two_pi_u_over_m * i);
	re += in_pixel->r * cs - in_pixel->i * sn;
	im += in_pixel->i * cs + in_pixel->r * sn;
      }
      temp_data[j * cols + u].r = re;
      temp_data[j * cols + u].i = im;
    }
  }

  printf ("ready to xform rows\n");

  /* transform along the row dimension */
  for (u = 0; u < cols; u++) {
    for (v = 0; v < rows; v++) {
      re = 0.0;
      im = 0.0;
      two_pi_v_over_n = for_inv * 2.0 * M_PI * v / rows;
      for (j = 0; j < rows; j++) {
	in_pixel = temp_data + j * cols + u;
	cs = cos (two_pi_v_over_n * j);
	sn = sin (two_pi_v_over_n * j);
	re += in_pixel->r * cs - in_pixel->i * sn;
	im += in_pixel->i * cs + in_pixel->r * sn;
      }
      result_data[v * cols + u].r = scale * re;
      result_data[v * cols + u].i = scale * im;
    }
  }
  printf ("xforms done\n");

  return result;
} // END dft_both


/* Return a copy of `in` in which every pixel (row r, column c) with r + c
   odd has its real and imaginary parts negated.  That is, the copy is
   `in` multiplied by (-1)^(r+c).  This is the standard modulation that
   moves the zero-frequency term of a DFT to the centre of the spectrum. */
cimg *checkerboard (char *prog_name, cimg *in) {
  cimg *out = copy_cimg (prog_name, in);
  complex *row_start = out->image;
  long r, c;

  for (r = 0; r < out->rows; r++, row_start += out->cols) {
    /* r + c is odd exactly when c has the opposite parity to r, so start
       at column 1 on even rows and column 0 on odd rows. */
    for (c = (r % 2 == 0) ? 1 : 0; c < out->cols; c += 2) {
      row_start[c].r = -row_start[c].r;
      row_start[c].i = -row_start[c].i;
    }
  }

  return out;
} // END checkerboard


/* Log-compress an image for display.

   Each output pixel is floor(255 * log(1 + x) / log(256)) for input pixel
   x.  An input of 0 maps to 0, and inputs up to 255 map into 0..255.
   The value is clamped to [0, 255] before truncation:
   converting an out-of-range (or NaN) double to unsigned char is
   undefined behaviour.  Inputs above 255 therefore saturate at 255, and
   inputs below 0 (including those where log yields -inf or NaN) give 0.
   Returns a newly created image of the same size as `in`. */
dimg *log_img (char *prog_name, dimg *in) {
  dimg *out;
  double *pin, *pout;
  double cons;
  double value;
  long pixel;

  out = create_dimg (prog_name, in->cols, in->rows);

  pin = in->image;
  pout = out->image;

  cons = 255.0 / log (256.0);

  for (pixel = 0; pixel < in->size; pixel++) {
    value = cons * log (*(pin++) + 1.0);
    if (!(value >= 0.0)) {          /* negative, -inf, or NaN */
      value = 0.0;
    } else if (value > 255.0) {
      value = 255.0;
    }
    /* floor matches the truncation of a cast for in-range, non-negative
       values */
    *(pout++) = floor (value);
  }

  return out;
} // END log_img


/* Return the smallest of the in->size values in in->image, or 0.0 when
   the image has no values.  Reading image[0] of an empty image would be
   out of bounds. */
double min_dimg (dimg *in) {
  long i;
  double *pin = in->image;
  double min;

  if (in->size <= 0) { return 0.0; }

  min = pin[0];
  for (i = 1; i < in->size; i++) {
    if (pin[i] < min) { min = pin[i]; }
  }

  return min;
} // END min_dimg


/* Return the largest of the in->size values in in->image, or 0.0 when
   the image has no values.  Reading image[0] of an empty image would be
   out of bounds. */
double max_dimg (dimg *in) {
  long i;
  double *pin = in->image;
  double max;

  if (in->size <= 0) { return 0.0; }

  max = pin[0];
  for (i = 1; i < in->size; i++) {
    if (pin[i] > max) { max = pin[i]; }
  }

  return max;
} // END max_dimg


/* Return the smallest of the in->size bytes in in->image, or 0 when the
   image has no bytes.  Reading image[0] of an empty image would be out of
   bounds.  The parameter is named `in` so it does not shadow the `img`
   type. */
unsigned char min_img (img *in) {
  long i;
  unsigned char *pin = in->image;
  unsigned char min;

  if (in->size <= 0) { return 0; }

  min = pin[0];
  for (i = 1; i < in->size; i++) {
    if (pin[i] < min) { min = pin[i]; }
  }

  return min;
} // END min_img


/* Return the largest of the in->size bytes in in->image, or 0 when the
   image has no bytes.  Reading image[0] of an empty image would be out of
   bounds.  The parameter is named `in` so it does not shadow the `img`
   type. */
unsigned char max_img (img *in) {
  long i;
  unsigned char *pin = in->image;
  unsigned char max;

  if (in->size <= 0) { return 0; }

  max = pin[0];
  for (i = 1; i < in->size; i++) {
    if (pin[i] > max) { max = pin[i]; }
  }

  return max;
} // END max_img


