#include "TePDIWaveletAtrous.hpp"

#include <TeAgnostic.h>
#include "TePDIUtils.hpp"
#include "TePDITypes.hpp"
#include "TePDIStatistic.hpp"
#include "TePDIPrincipalComponents.hpp"
#include "TeRasterRemap.h"
#include <TeMatrix.h>
#include <TeUtils.h>
#include <math.h>
#include <queue>

/* Constructor: nothing needs to be set up in this file. */
TePDIWaveletAtrous::TePDIWaveletAtrous() {}

/* Destructor: nothing is acquired in this file that needs releasing. */
TePDIWaveletAtrous::~TePDIWaveletAtrous() {}

/* No cached state to clear: each Run* method re-reads all of its
   parameters from params_ every time it executes. */
void TePDIWaveletAtrous::ResetState(const TePDIParameters& /* params */)
{
}

/*
 * Validates the parameter set for either direction before a run.
 *
 * DECOMPOSE requires: input_raster (active, ready), band in
 * [0, nBands), levels >= 0, a filterType that filterBank() supports,
 * multi_raster (active and ready when fit_histogram is set, because it
 * is then used as the histogram reference), fit_histogram, and
 * output_wavelets with exactly levels + 1 active, ready rasters.
 *
 * RECOMPOSE requires: a non-empty input_rasters_wavelets whose entries each
 * hold rasters_levels + 1 active, ready rasters; reference_raster_wavelets
 * with reference_levels + 1 rasters per entry; channel_min_level,
 * channel_max_level; and output_rasters. Recompose indexes
 * reference_raster_wavelets and output_rasters by the same position as
 * input_rasters_wavelets, so both must be at least as long.
 *
 * Returns false (with a message) on the first failed check.
 */
bool TePDIWaveletAtrous::CheckParameters(const TePDIParameters& parameters) const
{
	TeWaveletAtrousDirection direction;
	TEAGN_TRUE_OR_RETURN(parameters.GetParameter("direction", direction), "Missing parameter: direction");

	if (direction == DECOMPOSE)
	{
		TePDITypes::TePDIRasterPtrType input_raster;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("input_raster", input_raster), "Missing parameter: input_raster");
		TEAGN_TRUE_OR_RETURN(input_raster.isActive(), "Invalid parameter: input_raster inactive");
		TEAGN_TRUE_OR_RETURN(input_raster->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: input_raster not ready");
	
		int band;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("band", band), "Missing parameter: band");
		// Negative band indices were previously accepted.
		TEAGN_TRUE_OR_RETURN((band >= 0) && (band < input_raster->nBands()), "Invalid parameter: band number");
	
		int levels;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("levels", levels), "Missing parameter: levels");
		TEAGN_TRUE_OR_RETURN(levels >= 0, "Invalid parameter: levels must be non-negative");
	
		TeFilterBanks filterType;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("filterType", filterType), "Missing parameter: filterType");
		// Only these two banks are implemented by filterBank(); anything else
		// would otherwise only fail later, with a throw, inside the decomposition.
		TEAGN_TRUE_OR_RETURN((filterType == TePDIWaveletAtrous::B3SPLINE) || (filterType == TePDIWaveletAtrous::SMALLB3SPLINE), "Invalid parameter: filterType");
	
		TePDITypes::TePDIRasterPtrType multi_raster;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("multi_raster", multi_raster), "Missing parameter: multi_raster");

		bool fit_histogram;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("fit_histogram", fit_histogram), "Missing parameter: fit_histogram");

		// multi_raster is only read (as the statistics reference) when fitting.
		if (fit_histogram)
		{
			TEAGN_TRUE_OR_RETURN(multi_raster.isActive(), "Invalid parameter: multi_raster inactive");
			TEAGN_TRUE_OR_RETURN(multi_raster->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: multi_raster not ready");
		}

		TePDITypes::TePDIRasterVectorType output_wavelets;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("output_wavelets", output_wavelets), "Missing parameter: output_wavelets");
		TEAGN_TRUE_OR_RETURN((int)output_wavelets.size() == (levels + 1), "Invalid output rasters number");
	
		for(unsigned int b = 0; b < output_wavelets.size(); b++)
		{
			TEAGN_TRUE_OR_RETURN(output_wavelets[b].isActive(), "Invalid parameter: output_wavelets " + Te2String(b) + " inactive");
			TEAGN_TRUE_OR_RETURN(output_wavelets[b]->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: output_wavelets " + Te2String(b) + " not ready");
		}
	}
	else if (direction == RECOMPOSE)
	{
		std::vector<TePDITypes::TePDIRasterVectorType> input_rasters_wavelets;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("input_rasters_wavelets", input_rasters_wavelets), "Missing parameter: input_rasters_wavelets");
		// Recompose reads input_rasters_wavelets[0][0] unconditionally.
		TEAGN_TRUE_OR_RETURN(!input_rasters_wavelets.empty(), "Invalid parameter: input_rasters_wavelets is empty");
		for(unsigned int w = 0; w < input_rasters_wavelets.size(); w++)
			for(unsigned int b = 0; b < input_rasters_wavelets[w].size(); b++)
			{
				TEAGN_TRUE_OR_RETURN(input_rasters_wavelets[w][b].isActive(), "Invalid parameter: input_rasters_wavelets inactive");
				TEAGN_TRUE_OR_RETURN(input_rasters_wavelets[w][b]->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: input_rasters_wavelets not ready");
			}
	
		int rasters_levels;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("rasters_levels", rasters_levels), "Missing parameter: rasters_levels");
		// A negative value would let empty level vectors pass the size check below.
		TEAGN_TRUE_OR_RETURN(rasters_levels >= 0, "Invalid parameter: rasters_levels must be non-negative");
		for(unsigned int w = 0; w < input_rasters_wavelets.size(); w++)
			TEAGN_TRUE_OR_RETURN((rasters_levels + 1) == (int)input_rasters_wavelets[w].size(), "Invalid parameter: rasters_levels not ready");
	
		std::vector<TePDITypes::TePDIRasterVectorType> reference_raster_wavelets;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("reference_raster_wavelets", reference_raster_wavelets), "Missing parameter: reference_raster_wavelets");
		TEAGN_TRUE_OR_RETURN(reference_raster_wavelets.size() >= input_rasters_wavelets.size(), "Invalid parameter: fewer reference_raster_wavelets than input_rasters_wavelets");
		for(unsigned int w = 0; w < reference_raster_wavelets.size(); w++)
			for(unsigned int b = 0; b < reference_raster_wavelets[w].size(); b++)
			{
				TEAGN_TRUE_OR_RETURN(reference_raster_wavelets[w][b].isActive(), "Invalid parameter: reference_raster_wavelets inactive");
				TEAGN_TRUE_OR_RETURN(reference_raster_wavelets[w][b]->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: reference_raster_wavelets not ready");
			}
	
		int reference_levels;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("reference_levels", reference_levels), "Missing parameter: reference_levels");
		TEAGN_TRUE_OR_RETURN(reference_levels >= 0, "Invalid parameter: reference_levels must be non-negative");
		for(unsigned int w = 0; w < reference_raster_wavelets.size(); w++)
			TEAGN_TRUE_OR_RETURN((reference_levels + 1) == (int)reference_raster_wavelets[w].size(), "Invalid parameter: reference_levels not ready");
	
		double channel_min_level;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("channel_min_level", channel_min_level), "Missing parameter: channel_min_level");
	
		double channel_max_level;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("channel_max_level", channel_max_level), "Missing parameter: channel_max_level");
	
		TePDITypes::TePDIRasterVectorType output_rasters;
		TEAGN_TRUE_OR_RETURN(parameters.GetParameter("output_rasters", output_rasters), "Missing parameter: output_rasters");
		TEAGN_TRUE_OR_RETURN(output_rasters.size() >= input_rasters_wavelets.size(), "Invalid parameter: fewer output_rasters than input_rasters_wavelets");
	
		for(unsigned int b = 0; b < output_rasters.size(); b++)
		{
			TEAGN_TRUE_OR_RETURN(output_rasters[b].isActive(), "Invalid parameter: output_rasters " + Te2String(b) + " inactive");
			TEAGN_TRUE_OR_RETURN(output_rasters[b]->params().status_ != TeRasterParams::TeNotReady, "Invalid parameter: output_rasters " + Te2String(b) + " not ready");
		}

	}
	else
		return false;

	return true;
}

/*
 * Fills "filter" with the square smoothing mask of the requested filter bank.
 *
 * Both masks are separable: entry (row, col) equals kernel[row] * kernel[col]
 * for a 1-D kernel whose weights sum to 1, so each 2-D mask also sums to 1.
 *   TePDIWaveletAtrous::B3SPLINE      -> 5x5 mask, kernel [1 4 6 4 1] / 16
 *   TePDIWaveletAtrous::SMALLB3SPLINE -> 3x3 mask, kernel [1 2 1] / 4
 * Every weight is a dyadic fraction, so the products are exact and equal the
 * tabulated values (1/256, 1/64, 3/128, ...) bit for bit.
 *
 * Returns false, leaving "filter" untouched, for any other filterBankType.
 */
bool filterBank(int filterBankType,TeMatrix &filter)
{
	static const double b3Spline[5] = { 1.0/16.0, 1.0/4.0, 3.0/8.0, 1.0/4.0, 1.0/16.0 };
	static const double smallB3Spline[3] = { 1.0/4.0, 1.0/2.0, 1.0/4.0 };

	const double* kernel;
	int size;
	if (filterBankType == TePDIWaveletAtrous::B3SPLINE)
	{
		kernel = b3Spline;
		size = 5;
	}
	else if (filterBankType == TePDIWaveletAtrous::SMALLB3SPLINE)
	{
		kernel = smallB3Spline;
		size = 3;
	}
	else
	{
		return false;
	}

	filter.Init(size, size, 0.0);
	for (int row = 0; row < size; ++row)
		for (int col = 0; col < size; ++col)
			filter(row, col) = kernel[row] * kernel[col];

	return true;
}

bool TePDIWaveletAtrous::RunImplementation_decompose()
{
	TePDITypes::TePDIRasterPtrType input_raster;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("input_raster", input_raster), "Missing parameter: input_raster");

	int band;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("band", band), "Missing parameter: band");

	int levels;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("levels", levels), "Missing parameter: levels");

	TeFilterBanks filterType;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("filterType", filterType), "Missing parameter: filterType");

	TePDITypes::TePDIRasterPtrType multi_raster;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("multi_raster", multi_raster), "Missing parameter: multi_raster");

	bool fit_histogram;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("fit_histogram", fit_histogram), "Missing parameter: fit_histogram");

	TePDITypes::TePDIRasterVectorType output_wavelets;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("output_wavelets", output_wavelets), "Missing parameter: output_wavelets");

	TeMatrix filter;
	TEAGN_TRUE_OR_THROW(filterBank(filterType, filter), "Filter Bank generation failed");

	/*for(unsigned int l = 1; l <= levels; l++)
	{
		TeRasterParams input_raster_params = input_raster->params();
		input_raster_params.nBands(waveletPlanes);
		if (input_raster_params.projection() != 0)
		{
			TeSharedPtr<TeProjection> proj(TeProjectionFactory::make(input_raster_params.projection()->params())); 
			input_raster_params.projection(proj.nakedPointer());
		}
		input_raster_params.boxResolution(input_raster_params.box().x1(), input_raster_params.box().y1(), input_raster_params.box().x2(), input_raster_params.box().y2(), input_raster_params.resx_, input_raster_params.resy_);
		input_raster_params.setPhotometric(TeRasterParams::TeMultiBand, -1);
		input_raster_params.setDataType(TeFLOAT, -1);
		TEAGN_TRUE_OR_RETURN(output_wavelets[l]->init(input_raster_params), "Output wavelets reset error " + Te2String(l));
	}*/

	int	l,
		multi,
		x,
		y,
		i,
		j,
		filter_dim = filter.Nrow(),
		offset = (int)(filter_dim / 2),
		k,
		m,
		lines = input_raster->params().nlines_,
		columns = input_raster->params().ncols_;
	double	fit_gain = 1.0,
		fit_offset = 0.0,
		p_ori,
		p_ant,
		p_new;
	TeMatrix	inputImageMatrix,
			std_matrix,
			mean_matrix;
	inputImageMatrix.Init(lines, columns, 0.0);

/* Computing statistics to fit the histograms */
	
	if (fit_histogram)
	{
		TePDIStatistic stat1;
		TePDIParameters stat1_pars;
		TePDITypes::TePDIRasterVectorType stat1_rasters;
		std::vector<int> stat1_bands;

		stat1_rasters.push_back(multi_raster);
		stat1_pars.SetParameter("rasters", stat1_rasters);
		stat1_bands.push_back(0);
		stat1_pars.SetParameter("bands", stat1_bands);
		TEAGN_TRUE_OR_RETURN(stat1.Reset(stat1_pars), "Unable to inialize the statistc module");
		stat1.ToggleProgInt(false);
		double	std1 = stat1.getStdDev(0),
			mean1 = stat1.getMean(0);
	
		TePDIStatistic stat2;
		TePDIParameters stat2_pars;
		TePDITypes::TePDIRasterVectorType stat2_rasters;
		std::vector<int> stat2_bands;

		stat2_rasters.push_back(input_raster);
		stat2_pars.SetParameter("rasters", stat2_rasters);
		stat2_bands.push_back(band);
		stat2_pars.SetParameter("bands", stat2_bands);
		TEAGN_TRUE_OR_RETURN(stat2.Reset(stat2_pars), "Unable to inialize the statistc module");
		stat2.ToggleProgInt(false);
		double	std2 = stat2.getStdDev(0),
			mean2 = stat2.getMean(0);

		fit_gain = std1 / std2;
		fit_offset = mean1 - (fit_gain * mean2);
	}

	double p_fit;
	for (j = 0; j < lines; j++)
	{
		for (i = 0; i < columns; i++)
		{
			input_raster->getElement(i, j, p_ant, band);
			p_fit = fit_gain * p_ant + fit_offset;
			inputImageMatrix(j, i) = p_fit;
			input_raster->setElement(i, j, p_fit, band);
		}
	}

	for(l = 1; l <= levels; l++)
	{
		multi = (int)pow(2., l-1);
		TePDIPIManager progress("Decomposing Wavelets", (int)(input_raster->params().nlines_), progress_enabled_);
		for (j = 0; j < lines; j++)
		{
			for (i = 0; i < columns; i++)
 			{
				p_ori = p_ant = p_new = 0.0;
				for (k = 0; k < filter_dim; k++)
				{
					for (m = 0; m < filter_dim; m++)
					{
						x = i+(k-offset)*multi;
						y = j+(m-offset)*multi;
						if (x < 0)
							x = columns + x;
						else if (x >= columns)
							x = x - columns;
						if (y < 0)
							y = lines + y;
						else if (y >= lines)
							y = y - lines;
						p_new += filter(k, m) * inputImageMatrix(y, x);
					}
				}
				output_wavelets[l]->setElement(i, j, p_new, 0);
				output_wavelets[l-1]->getElement(i, j, p_ori, band);
				output_wavelets[l]->setElement(i, j, p_ori-p_new, 1);
			}
			progress.Increment();
		}
	}

	inputImageMatrix.Clear();

	return true;
}

bool TePDIWaveletAtrous::RunImplementation_recompose()
{
	std::vector<TePDITypes::TePDIRasterVectorType> input_rasters_wavelets;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("input_rasters_wavelets", input_rasters_wavelets), "Missing parameter: input_rasters_wavelets");

	int rasters_levels;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("rasters_levels", rasters_levels), "Missing parameter: rasters_levels");

	std::vector<TePDITypes::TePDIRasterVectorType> reference_raster_wavelets;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("reference_raster_wavelets", reference_raster_wavelets), "Missing parameter: reference_raster_wavelets");

	int reference_levels;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("reference_levels", reference_levels), "Missing parameter: reference_levels");

	double channel_min_level;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("channel_min_level", channel_min_level), "Missing parameter: channel_min_level");

	double channel_max_level;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("channel_max_level", channel_max_level), "Missing parameter: channel_max_level");

	TePDITypes::TePDIRasterVectorType output_rasters;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("output_rasters", output_rasters), "Missing parameter: output_rasters");

	unsigned int b;
	int	i,
		j,
		l;
	double	pixel,
			p_pan;

	TePDIPIManager progress("Recomposing wavelets", input_rasters_wavelets.size()*input_rasters_wavelets[0][0]->params().nlines_, progress_enabled_);
	for(b = 0; b < input_rasters_wavelets.size(); b++)
	{
		int	lines = input_rasters_wavelets[b][0]->params().nlines_,
			columns = input_rasters_wavelets[b][0]->params().ncols_;
		for (j = 0; j < lines; j++)
		{
			for (i = 0; i < columns; i++)
			{
				input_rasters_wavelets[b][rasters_levels]->getElement(i, j, pixel, 0);
				for(l = 1; l <= reference_levels; l++)
				{
					reference_raster_wavelets[b][l]->getElement(i, j, p_pan, 1);
					pixel += p_pan;
				}
				pixel = (pixel > channel_max_level?channel_max_level:(pixel < channel_min_level?channel_min_level:pixel));
				output_rasters[b]->setElement(i, j, pixel, 0);
			}
			
			progress.Increment();
		}
	}

	return true;
}

bool TePDIWaveletAtrous::RunImplementation()
{
/* Getting parameters */

	int direction;
	TEAGN_TRUE_OR_RETURN(params_.GetParameter("direction", direction), "Missing parameter: direction");

	if (direction == DECOMPOSE)
		return RunImplementation_decompose();
	else if (direction == RECOMPOSE)
		return RunImplementation_recompose();

	return false;
}
