//
// WaveletTransformFilter.java
//
// This filter takes an ordinary image and computes
//  its simple wavelet transform
//
// Copyright (c) 1997, Benjamin Nason Lipchak
//

package benj.awt.image;

import java.awt.image.*;

/**
 * An ImageFilter that buffers an image and, on completion, replaces it with a
 * simple wavelet transform built from repeated 2x2 averaging.
 * <p>
 * Each step takes the current approximation (initially the whole image),
 * averages every 2x2 block channel by channel (alpha included) into the
 * top-left quadrant, and copies each block's upper-right, lower-left and
 * lower-right pixels into the other three quadrants.
 * <p>
 * The transform needs a square image whose side is a power of two. Any other
 * image is passed through unchanged.
 */
public class WaveletTransformFilter extends ImageFilter {
    private static final ColorModel defaultRGB = ColorModel.getRGBdefault();
    private int width, height;
    /** Pixels received from the producer, converted to default RGB. */
    private int[] rasterOrig;
    /** Transform output, stored with a row stride of {@code width}. */
    private int[] rasterSend;
    /** Scratch copy of rasterOrig, overwritten with running averages; null when no transform is applied. */
    private int[] rasterWork;
    /** Requested level: -1 for the full transform, otherwise the number of averaging steps. */
    private int jLevel;
    /** Number of averaging steps actually applied to the current image (set by setDimensions). */
    private int steps;
    /** Dimensions reported to the consumer for the current image. */
    private int outWidth, outHeight;

    /**
     * Creates a filter that outputs the full transform: every averaging step
     * is applied and the whole coefficient image, at the original size, is sent.
     */
    public WaveletTransformFilter() {
        jLevel = -1;
    }

    /**
     * Creates a filter that applies {@code level} averaging steps and sends
     * only the resulting approximation, which is 2^level times smaller in
     * each dimension. A level of 0 passes the image through unchanged. Levels
     * larger than the image allows are clamped. Any negative level selects the
     * full transform, like the no-argument constructor.
     *
     * @param level number of averaging steps, or negative for the full transform
     */
    public WaveletTransformFilter(int level) {
        jLevel = (level >= 0) ? level : -1;
    }

    /**
     * Records the source size, decides how many averaging steps to apply,
     * allocates the buffers, and forwards the output size to the consumer.
     */
    public void setDimensions(int width, int height) {
        this.width = width;
        this.height = height;

        // The transform is only defined for square images with a side of 2^n.
        int maxSteps = (width == height) ? logBase2(width) : -1;
        if (maxSteps < 0) {
            // Not transformable: pass the image through at its own size.
            steps = 0;
            outWidth = width;
            outHeight = height;
        } else if (jLevel < 0) {
            // Full transform: the coefficient image has the input's size.
            steps = maxSteps;
            outWidth = width;
            outHeight = height;
        } else {
            // Approximation only: shrink by 2^steps in each dimension.
            steps = Math.min(jLevel, maxSteps);
            outWidth = width >> steps;
            outHeight = height >> steps;
        }

        rasterOrig = new int[width * height];
        rasterSend = new int[width * height];
        rasterWork = (steps > 0) ? new int[width * height] : null;

        consumer.setDimensions(outWidth, outHeight);
    }

    /** Output is always delivered in the default RGB color model. */
    public void setColorModel(ColorModel model) {
        consumer.setColorModel(defaultRGB);
    }

    /**
     * Tells the consumer that pixels arrive top-down, in complete scanlines,
     * in a single pass, which is how imageComplete delivers them. The
     * producer's SINGLEFRAME flag is forwarded.
     */
    public void setHints(int hintflags) {
        consumer.setHints(TOPDOWNLEFTRIGHT
            | COMPLETESCANLINES
            | SINGLEPASS
            | (hintflags & SINGLEFRAME));
    }

    /** Buffers byte-indexed pixels, converting them to default RGB. */
    public void setPixels(int x, int y, int w, int h, ColorModel model,
                          byte[] pixels, int off, int scansize) {
        int srcoff = off;
        int dstoff = y * width + x;
        for (int yc = 0; yc < h; yc++) {
            for (int xc = 0; xc < w; xc++) {
                rasterOrig[dstoff++] = model.getRGB(pixels[srcoff++] & 0xff);
            }
            srcoff += (scansize - w);
            dstoff += (width - w);
        }
    }

    /**
     * Buffers int pixels. Default-RGB rows are copied directly; other models
     * are converted pixel by pixel.
     */
    public void setPixels(int x, int y, int w, int h, ColorModel model,
                          int[] pixels, int off, int scansize) {
        int srcoff = off;
        int dstoff = y * width + x;
        if (model == defaultRGB) {
            for (int yc = 0; yc < h; yc++) {
                System.arraycopy(pixels, srcoff, rasterOrig, dstoff, w);
                srcoff += scansize;
                dstoff += width;
            }
        } else {
            for (int yc = 0; yc < h; yc++) {
                for (int xc = 0; xc < w; xc++) {
                    rasterOrig[dstoff++] = model.getRGB(pixels[srcoff++]);
                }
                srcoff += (scansize - w);
                dstoff += (width - w);
            }
        }
    }

    /**
     * Runs the transform on the buffered image and sends the top-left
     * outWidth x outHeight region to the consumer, one scanline at a time.
     * Error and abort statuses are forwarded without sending pixels.
     */
    public void imageComplete(int status) {
        if (status == IMAGEERROR || status == IMAGEABORTED || rasterOrig == null) {
            consumer.imageComplete(status);
            return;
        }

        int[] source = rasterOrig;
        if (steps > 0) {
            transform();
            source = rasterSend;
        }

        int[] row = new int[outWidth];
        for (int dy = 0; dy < outHeight; dy++) {
            System.arraycopy(source, dy * width, row, 0, outWidth);
            consumer.setPixels(0, dy, outWidth, 1, defaultRGB, row, 0, outWidth);
        }

        consumer.imageComplete(status);
    }

    /**
     * Applies {@code steps} averaging passes to a scratch copy of rasterOrig
     * and writes the coefficient layout into rasterSend. rasterOrig itself is
     * left untouched, so repeated calls (e.g. multi-frame producers) see the
     * original pixels. Relies on setDimensions having ensured
     * width == height == 2^k with k >= steps.
     */
    private void transform() {
        int[] work = rasterWork;
        System.arraycopy(rasterOrig, 0, work, 0, work.length);

        for (int j = 0; j < steps; j++) {
            int extent = width >> j;   // side of the approximation being split
            int half = extent >> 1;
            for (int dy = 0; dy < extent; dy += 2) {
                int top = dy * width;
                int bottom = top + width;
                int upper = (dy >> 1) * width;     // output row in the top half
                int lower = upper + half * width;  // matching row in the bottom half
                for (int dx = 0; dx < extent; dx += 2) {
                    int pixelA = work[top + dx];
                    int pixelB = work[top + dx + 1];
                    int pixelC = work[bottom + dx];
                    int pixelD = work[bottom + dx + 1];
                    int col = dx >> 1;

                    int average = averageColor(pixelA, pixelB, pixelC, pixelD);
                    // Writing the average back in place is safe: this slot's
                    // value has already been consumed earlier in this pass.
                    work[upper + col] = average;
                    rasterSend[upper + col] = average;
                    rasterSend[upper + col + half] = pixelB;
                    rasterSend[lower + col] = pixelC;
                    rasterSend[lower + col + half] = pixelD;
                }
            }
        }
    }

    /**
     * Returns 2 raised to {@code degree}, or -1 for a negative degree.
     * Results for degree >= 31 wrap as ordinary int arithmetic would.
     */
    int pow2(int degree) {
        if (degree < 0)
            return -1;
        return (degree < 32) ? (1 << degree) : 0;
    }

    /**
     * Returns n such that 2^n == number, or -1 if number is not a positive
     * power of two.
     */
    int logBase2(int number) {
        if (number <= 0 || (number & (number - 1)) != 0)
            return -1;
        int result = 0;
        while ((number >>= 1) != 0)
            result++;
        return result;
    }

    /**
     * Averages four ARGB colors channel by channel (alpha, red, green, blue).
     * Each channel average is rounded down.
     */
    int averageColor(int colorA, int colorB, int colorC, int colorD) {
        int alpha = averageChannel(colorA, colorB, colorC, colorD, 24);
        int red   = averageChannel(colorA, colorB, colorC, colorD, 16);
        int green = averageChannel(colorA, colorB, colorC, colorD, 8);
        int blue  = averageChannel(colorA, colorB, colorC, colorD, 0);
        return (alpha << 24) | (red << 16) | (green << 8) | blue;
    }

    /** Averages the 8-bit channel found at bit offset {@code shift} of four colors. */
    private static int averageChannel(int a, int b, int c, int d, int shift) {
        // Unsigned shift plus mask so the alpha byte is never sign-extended.
        return (((a >>> shift) & 0xff) + ((b >>> shift) & 0xff)
            + ((c >>> shift) & 0xff) + ((d >>> shift) & 0xff)) / 4;
    }
}