#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <rumba/manifoldmatrix.h>
#include <rumba/matrixio.h>
#include <rumba/parse.h>

#include "design.h"

using RUMBA::ManifoldMatrix;
using RUMBA::manifold_generator;
using RUMBA::makeMatrix;
using RUMBA::writeManifoldMatrix;

using std::map;
using std::list;
using std::string;
using std::vector;
using std::ostream_iterator;
using std::cerr;

/*
 * Split `line` on `delim` and append every field, in order, to `result`.
 * Existing contents of `result` are kept. Because std::getline drives the
 * split, a trailing empty field (line ending in `delim`) is not emitted and
 * an empty line yields no fields at all; interior empty fields are kept.
 */
void tokenizeCsvLine( std::vector<std::string>& result, const std::string& line, char delim )
{
	std::istringstream source(line);
	for ( std::string field; std::getline(source, field, delim); )
		result.push_back(field);
}

/*
 * Build a treatment from one level label per factor: v[k] must be one of the
 * labels in F[k].levels, and it is stored as its index within that list.
 * Only the address of F is retained, so F must outlive this Treatment.
 *
 * Throws RUMBA::Exception if v and F differ in length, or if a label is not a
 * known level of its factor (the labels of v are echoed to cerr first).
 */
Treatment::Treatment(const vector<Factor>& F, const vector<string>& v )
	: ptrF(&F)
{
	const vector<Factor>& factors = *ptrF;

	if ( factors.size() != v.size() )
		throw RUMBA::Exception ("Malformed treatment vector:wrong size\n" );

	for ( vector<string>::size_type k = 0; k < factors.size(); ++k )
	{
		const int levelIndex = find_index<std::string> ( factors[k].levels, v[k] );
		if ( levelIndex != -1 )
		{
			T.push_back(levelIndex);
			continue;
		}

		// Dump the offending labels to aid debugging before bailing out.
		copy ( v.begin(), v.end(), ostream_iterator<string>(cerr ,":\n"));
		throw RUMBA::Exception ( "Malformed treatment vector");
	}
}





/*
 * True when at most one component of the treatment has a non-zero level.
 *
 * NOTE(review): the all-zero treatment (A0 B0 C0 ...) also yields true, even
 * though the original inline comment ("don't count A0 B0 C0 .... as
 * primitive") suggests otherwise. The only caller in this file,
 * getDesignMatrixRow, never passes an all-zero treatment, so the behaviour
 * is left unchanged. Confirm what other callers expect before changing it.
 */
bool Treatment::isPrimitive () const
{
	bool seenNonZero = false;
	for ( unsigned int k = 0; k < T.size(); ++k )
	{
		if ( T[k] == 0 )
			continue;
		if ( seenNonZero )
			return false;   // a second non-zero component
		seenNonZero = true;
	}
	return true;
}

	
// True when every component of the treatment is level 0 (vacuously true for
// a treatment with no components).
bool Treatment::isZero () const
{
	unsigned int k = 0;
	while ( k < T.size() && T[k] == 0 )
		++k;
	return k == T.size();
}

int Treatment::toColumn () const
{
	int scale = 1;
	int result = 0;
	int j;

	if ( ptrF->size() != T.size() )
		throw RUMBA::Exception ( "Malformed treatment vector: wrong size\n" );

	for ( unsigned int i = 0; i < ptrF->size(); ++i )
	{
		j = find_index ( (*ptrF)[i].levels, stringLevel(i)); // need string corr with T[i]
		if ( j == -1 )
			throw RUMBA::Exception ( "Malformed treatment vector\n" );
		result += scale * j;
		scale *= (*ptrF)[i].levels.size();
	}

	return (result - 1);
}

/*
 * Lexicographic ordering of treatments by level index.
 *
 * If the treatments agree on their common prefix, the shorter one sorts
 * first. Without this tie-break, a treatment would be equivalent to every
 * longer treatment it prefixes: (1) ~ (1,2) and (1) ~ (1,3), yet
 * (1,2) < (1,3). That is not a strict weak ordering, so it is unsafe for
 * std::map or std::sort. It also disagrees with operator==, which treats
 * different lengths as unequal. Comparisons between equal-length
 * treatments are unaffected.
 */
bool operator< (const Treatment& left, const Treatment& right )
{
	for ( int i = 0; i < left.length() && i < right.length(); ++i )
	{
		if ( left.level(i) < right.level(i) )
			return true;
		if ( right.level(i) < left.level(i) )
			return false;
	}
	// Equal on the common prefix: the shorter treatment comes first.
	return left.length() < right.length();
}

// Two treatments are equal when they have the same length and the same level
// label (stringLevel) for every factor position.
bool operator== (const Treatment& left, const Treatment& right )
{
	if ( left.length() != right.length() )
		return false;

	int i = 0;
	while ( i < left.length() && left.stringLevel(i) == right.stringLevel(i) )
		++i;
	return i == left.length();
}


int count_treatments ( const vector<Factor>& factors )
{
	int x = 1;
	for ( unsigned int i = 0; i < factors.size(); ++i )
		x *= factors[i].levels.size();
	return x - 1;
}



/*
 * Column of the primitive treatment that sets factor `factor_no` to `level`
 * and leaves every other factor at level 0. It equals `level` times the
 * product of the level counts of factors 0 .. factor_no-1, minus one, which
 * is the same mixed-radix layout used by treatment_to_column.
 */
int primitiveToColumn (int level, int factor_no, const vector<Factor>& factors)
{
	int stride = 1;
	for ( int f = factor_no - 1; f >= 0; --f )
		stride *= factors[f].levels.size();

	return stride * level - 1;
}

/*
 * Build the 1 x count_treatments(factors) design-matrix row for treatment t.
 *
 * Primitive columns (exactly one non-zero factor level) are filled first.
 * For each factor, a non-zero level puts 1 in that level's column, while
 * level 0 puts -1 in every column belonging to that factor. Each remaining,
 * non-primitive column is then the product of the primitive entries for the
 * factor levels it combines.
 */
ManifoldMatrix getDesignMatrixRow ( Treatment t, const vector<Factor>& factors )
{
	const int cols = count_treatments(factors);
	manifold_generator gen;

	ManifoldMatrix M ( gen(1,cols) );
	for ( int i = 0; i < cols; ++i )
		M.element(0,i) = 0;

	// Primitive columns.
	for ( unsigned int i = 0; i < factors.size(); ++i )
	{
		if ( t.level(i) != 0 )
		{
			// Non-zero level: mark that level's own column.
			M.element(0, primitiveToColumn( t.level(i), i, factors )) = 1;
		}
		else
		{
			// Level 0: set every primitive column of this factor to -1.
			for ( unsigned int j = 1; j < factors[i].levels.size(); ++j )
				M.element(0, primitiveToColumn( j, i, factors )) = -1;
		}
	}

	// Non-primitive columns: multiply the primitive entries they combine.
	// Only primitive entries are read here, and all of them were written
	// above, so the order of this loop does not matter.
	for ( int i = 0; i < M.cols(); ++i )
	{
		const Treatment u = column_to_treatment ( i, factors );
		if ( u.isPrimitive() )
			continue;

		int x = 1;
		for ( unsigned int j = 0; j < factors.size(); ++j )
		{
			if ( u.level(j) != 0 )
				x *= static_cast<int> (
					M.element(0, primitiveToColumn( u.level(j), j, factors )) );
		}

		M.element(0,i) = x;
	}

	return M;
}

// Print the treatment's level labels, each followed by a space, then the
// level indices in parentheses, each followed by ", " (e.g. "lo hi (0, 1, )").
std::ostream& operator<< (std::ostream& left, const Treatment& right)
{
	for ( int k = 0; k < right.length(); ++k )
		left << right.stringLevel(k) << ' ';

	left << '(';
	for ( int k = 0; k < right.length(); ++k )
		left << right.level(k) << ", ";

	return left << ')';
}

std::ostream& operator<< (std::ostream& left, const Factor& right)
{
	left << right.name << " ";
	left << "column " << right.col_num << " " ;
	std::copy (right.levels.begin(),right.levels.end(),
		std::ostream_iterator<string>(left,", "));
	return left;
}

/*
 * Which column numbers contain the desired fields?
 *
 * Looks up each name in `headers` among the cells of the tab-separated header
 * line `s`. For each one it appends Factor(name, column index) to `result`,
 * in the order of `headers`.
 *
 * Header cells are passed through RUMBA::strip before matching, as the data
 * cells are in process_row() and readfile(). Without this, a trailing '\r'
 * from a CRLF file, or a padded header cell, made the lookup fail even though
 * the data rows parsed fine. NOTE(review): this assumes RUMBA::strip trims
 * surrounding whitespace; confirm against rumba/parse.h.
 *
 * Throws RUMBA::Exception naming the first field that cannot be found;
 * `result` may already hold the factors found before it.
 */
void process_header
(
 vector<Factor>& result, 
 string s, 
 const vector<string>& headers
 )
{
	vector<string> mylist;
	tokenizeCsvLine(mylist,s,'\t');
	for_each( mylist.begin(), mylist.end(), RUMBA::strip );

	for (vector<string>::const_iterator it = headers.begin(); 
			it != headers.end(); ++it)
	{
		const vector<string>::iterator hit =
			std::find(mylist.begin(), mylist.end(), *it);
		if ( hit == mylist.end() )
			throw RUMBA::Exception(string("Lookup failed for field ")+ *it);
		result.push_back( Factor(*it, hit - mylist.begin()) );
	}
}

/*
 * Scan one tab-separated data row and add to each factor in `result` any
 * level label it has not seen yet. The label is taken from column
 * result[i].col_num after RUMBA::strip has been applied to every cell.
 * New labels are appended in order of first appearance.
 *
 * Throws RUMBA::Exception when the row has no cell at a factor's column.
 * This is an explicit check rather than an assert because the row is file
 * input, and assert is compiled out under NDEBUG, which would leave an
 * out-of-bounds read on short or blank lines.
 */
void process_row
(
 vector<Factor>& result,
 string row
)
{
	vector<string> mylist;
	tokenizeCsvLine(mylist,row,'\t');
	for_each( mylist.begin(), mylist.end(), RUMBA::strip );

	for ( unsigned int i = 0; i < result.size(); ++i )
	{
		Factor& factor = result[i];
		if ( factor.col_num < 0 || factor.col_num >= static_cast<int>(mylist.size()) )
			throw RUMBA::Exception(
				string("Malformed row: no column for field ") + factor.name );

		const string& cell = mylist[factor.col_num];
		if ( std::find(factor.levels.begin(), factor.levels.end(), cell)
				== factor.levels.end() )
			factor.levels.push_back(cell);
	}
}





/*
 * Returns the unique design-matrix column for a given treatment.
 *
 * The level indices are read as mixed-radix digits, with factor 0 least
 * significant and each factor's radix equal to its number of levels. One is
 * then subtracted. A return value of -1 therefore means every factor is at
 * level index 0. That treatment has no column of its own; getDesignMatrixRow
 * codes it with -1 in every primitive column.
 *
 * treatment[i] must be one of the level labels in factors[i].levels.
 *
 * Throws RUMBA::Exception on a size mismatch or an unknown label; in the
 * latter case the treatment labels are echoed to cerr first.
 */
int treatment_to_column 
( const vector<Factor>& factors, const vector<string>& treatment)
{
	if ( factors.size() != treatment.size() )
		throw RUMBA::Exception ( "Malformed treatment vector: wrong size\n" );

	int column = 0;
	int placeValue = 1;
	for ( unsigned int k = 0; k < factors.size(); ++k )
	{
		const int digit = find_index ( factors[k].levels, treatment[k] );
		if ( digit == -1 )
		{
			std::copy(treatment.begin(),treatment.end(),
					ostream_iterator<string>(cerr, ":  :"));
			cerr << std::endl;

			throw RUMBA::Exception ( "Malformed treatment vector\n" );
		}
		column += placeValue * digit;
		placeValue *= factors[k].levels.size();
	}

	return column - 1;
}

/*
 * Inverse of treatment_to_column. It decodes design-matrix column `col` into
 * one level label per factor: the digits of col + 1 in mixed radix, factor 0
 * least significant. The labels are wrapped in a Treatment bound to `factors`.
 */
Treatment column_to_treatment 
( int col, const vector<Factor> & factors)
{
	vector<string> labels;
	int remaining = col + 1;   // undo the -1 offset used by treatment_to_column

	for ( vector<Factor>::const_iterator f = factors.begin(); f != factors.end(); ++f )
	{
		const int radix = f->levels.size();
		labels.push_back( f->levels[remaining % radix] );
		remaining /= radix;
	}

	return Treatment ( factors, labels );
}


/*
 * Read every remaining tab-separated line of `fin` and convert it into a
 * Treatment over `factors`. Factor i's label is taken from column
 * factors[i].col_num, after RUMBA::strip has been applied to every cell.
 *
 * Throws RUMBA::Exception when a line has no cell at a factor's column. This
 * is an explicit check because it is file input and assert is compiled out
 * under NDEBUG. The Treatment constructor's exception for unknown labels is
 * propagated.
 */
std::vector<Treatment> readfile (std::ifstream& fin, vector<Factor>& factors)
{
	vector<string> tokens;
	vector<string> treatment;
	std::vector<Treatment> result;
	std::string row;

	while ( getline ( fin, row ) )
	{
		tokens.clear();
		treatment.clear();
		tokenizeCsvLine(tokens,row,'\t');
		for_each( tokens.begin(), tokens.end(), RUMBA::strip );

		for ( unsigned int i = 0; i < factors.size(); ++i )
		{
			const int col = factors[i].col_num;
			if ( col < 0 || col >= static_cast<int>(tokens.size()) )
				throw RUMBA::Exception(
					string("Malformed row: no column for field ") + factors[i].name );
			treatment.push_back ( tokens[col] );
		}

		result.push_back(Treatment(factors, treatment));
	}
	return result;
}

/*
 * New version of design_matrix(): builds the design matrix from treatments
 * that have already been parsed (e.g. by readfile()).
 *
 * Row i is getDesignMatrixRow(treatments[i], factors), followed by a final
 * column of 1s. The result is therefore
 * treatments.size() x (count_treatments(factors) + 1).
 */
ManifoldMatrix design_matrix2
(const std::vector<Treatment>& treatments, vector<Factor>& factors)
{
	typedef std::vector<Treatment>::const_iterator tvec_iter;

	const int cols = count_treatments(factors);
	ManifoldMatrix result = makeMatrix(treatments.size(), cols + 1);

	// The row count is known up front, so each row goes straight into the
	// result instead of being buffered in a list first.
	int i = 0;
	for ( tvec_iter it = treatments.begin(); it != treatments.end(); ++it, ++i )
	{
		ManifoldMatrix row = getDesignMatrixRow( *it, factors );
		assert( i < result.rows() && row.rows() == 1 && row.cols() <= result.cols() );
		result.put(i, 0, row);
	}

	for ( int r = 0; r < result.rows(); ++r )
		result.element( r, result.cols() - 1 ) = 1;

	return result;
}

/** Build a design matrix from the tab-separated lines remaining in `fin`.
 *
 * Each line becomes one treatment over `factors`. Factor i's label is taken
 * from column factors[i].col_num, after RUMBA::strip has been applied to
 * every cell. Each treatment contributes one getDesignMatrixRow() row.
 *
 * Note that we append a column of 1s as the last column, so the result is
 * (number of lines) x (count_treatments(factors) + 1).
 *
 * Throws RUMBA::Exception when a line has no cell at a factor's column. This
 * is an explicit check because it is file input and assert is compiled out
 * under NDEBUG. The Treatment constructor's exception for unknown labels is
 * propagated.
 */
ManifoldMatrix design_matrix 
(std::ifstream& fin, vector<Factor>& factors)
{
	vector<string> tokens;
	vector<string> treatment;
	std::string row;
	int count = 0;
	list<ManifoldMatrix> rows;
	manifold_generator gen;

	while ( getline ( fin, row ) )
	{
		tokens.clear();
		treatment.clear();
		tokenizeCsvLine(tokens,row,'\t');
		for_each( tokens.begin(), tokens.end(), RUMBA::strip );

		for ( unsigned int i = 0; i < factors.size(); ++i )
		{
			const int col = factors[i].col_num;
			if ( col < 0 || col >= static_cast<int>(tokens.size()) )
				throw RUMBA::Exception(
					string("Malformed row: no column for field ") + factors[i].name );
			treatment.push_back ( tokens[col] );
		}

		rows.push_back(
				getDesignMatrixRow( Treatment(factors, treatment), factors ) 
				);
		++count;
	}

	// The line count is only known once the whole file has been read, hence
	// the intermediate list of rows.
	ManifoldMatrix result (gen( count, count_treatments(factors) + 1 ));

	int i = 0;
	for ( list<ManifoldMatrix>::iterator it = rows.begin(); it != rows.end(); ++it, ++i )
	{
		assert ( i < result.rows() && it->rows() == 1 && it->cols() <= result.cols() );
		result.put(i, 0, *it );
	}

	for ( int r = 0; r < result.rows(); ++r )
		result.element( r, result.cols() - 1 ) = 1;

	return result;
}


/*
 * Return a matrix in which each row of M is repeated n times consecutively:
 * rows i*n .. i*n+n-1 of the result are copies of row i of M.
 *
 * Requires n > 0. The check now runs before the n*M.rows() allocation, so a
 * bad n is caught instead of being passed to makeMatrix first.
 */
RUMBA::ManifoldMatrix repeatRows(const RUMBA::ManifoldMatrix& M, int n)
{
	assert(n>0);
	RUMBA::ManifoldMatrix R = RUMBA::makeMatrix(n*M.rows(), M.cols());
	for ( int i = 0; i < M.rows(); ++i )
	{
		// Extract row i once, rather than once per copy.
		RUMBA::ManifoldMatrix row = M.subMatrix(i,1,0,M.cols());
		for ( int j = 0; j < n; ++j )
			R.put(i*n+j,0,row);
	}
	return R;
}


