/*
 * GenericUtility.cpp: This file is part of the Parametric Tiling project.
 * 
 * Parametric Tiling: A CLAST-to-CLAST parametric tiling software
 * 
 * Copyright (C) 2011 Sanket Tavargeri
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 * 
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 * 
 * Author:
 * Sanket Tavargeri <sanket.tavargeri@gmail.com>
 * 
 */

#if HAVE_CONFIG_H
# include <ptile/config.h>
#endif

# include <ptile/GenericUtility.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include<string>
#include<iostream>
#include<vector>
#include <set>
#include<sstream>

using namespace std;



    // Orders two strings, given by pointer, by their text: returns true when
    // *s1 sorts strictly before *s2. Both pointers are dereferenced
    // unconditionally, so neither may be null.
    bool GeneralUtilityClass::StringComparator::operator()(string* s1, string* s2) const
    {
      const string& lhs = *s1;
      const string& rhs = *s2;
      return lhs < rhs;
    }


    // Returns the text "int".
    string GeneralUtilityClass::GetIntegerKeyword()
  {
    const string keyword("int");
    return keyword;
  }// string GeneralUtilityClass::GetIntegerKeyword()
	

    // Returns the value of GetNumberOfPartitionsString() wrapped in square
    // brackets, i.e. "[" + <partitions string> + "]".
    string GeneralUtilityClass::GetArrayAppendixString()
  {
    string appendix("[");
    appendix += GetNumberOfPartitionsString();
    appendix += "]";
    return appendix;
  }// string GeneralUtilityClass::GetArrayAppendixString()

    // Returns the text "partition".
    string GeneralUtilityClass::GetPartitionString()
  {
    return string("partition");
  }// string GeneralUtilityClass::GetPartitionString()

    // Returns the text "temp".
    string GeneralUtilityClass::GetTempString()
  {
    const string name("temp");
    return name;
  }// string GeneralUtilityClass::GetTempString()

    // Returns the text "id".
    string GeneralUtilityClass::GetIdString()
  {
    return string("id");
  }// string GeneralUtilityClass::GetIdString()

    // Returns the text "tile".
    string GeneralUtilityClass::GetTileString()
  {
    const string name("tile");
    return name;
  }// string GeneralUtilityClass::GetTileString()

    // Returns the text "_range" (note the leading underscore).
    string GeneralUtilityClass::GetRangeString()
  {
    return string("_range");
  }// string GeneralUtilityClass::GetRangeString()

    // Returns a freshly malloc()'ed, NUL-terminated copy of `s`.
    //
    // The buffer is obtained with malloc(), so the caller must release it
    // with free(). Returns NULL when the allocation fails (previously the
    // unchecked pointer was written through, which is undefined behaviour).
    char* GeneralUtilityClass::StringToChar(string s)
  {
    // One extra byte for the terminating NUL that c_str() guarantees.
    const size_t bytes = s.length() + 1;

    char* copy = static_cast<char*>(malloc(bytes));
    if (copy == NULL)
      {
        return NULL;
      }// if (copy == NULL)

    memcpy(copy, s.c_str(), bytes);
    return copy;
  }// char* GeneralUtilityClass::StringToChar(string s)

    // Returns the text "estimatedTiles".
    string GeneralUtilityClass::GetEstimatedTilesString()
  {
    const string name("estimatedTiles");
    return name;
  }// string GeneralUtilityClass::GetEstimatedTilesString()

    // Returns the text "NUMBER_OF_PARTITIONS".
    string GeneralUtilityClass::GetNumberOfPartitionsString()
  {
    return string("NUMBER_OF_PARTITIONS");
  }// string GeneralUtilityClass::GetNumberOfPartitionsString()

    // Returns the text "i".
    string GeneralUtilityClass::GetIteratorString()
  {
    const string name("i");
    return name;
  }// string GeneralUtilityClass::GetIteratorString()

    // Returns the result of cloog_stride_alloc() called with a stride value
    // of 1 and an offset value of 0.
    CloogStride* GeneralUtilityClass::GetDefaultStride()
  {
    cloog_int_t unitStride;
    cloog_int_t zeroOffset;

    cloog_int_init(unitStride);
    cloog_int_init(zeroOffset);

    cloog_int_set_si(unitStride, 1);
    cloog_int_set_si(zeroOffset, 0);

    // NOTE(review): the two initialised temporaries are never released in
    // this function. If cloog_stride_alloc() copies its arguments (to be
    // confirmed against the CLooG sources) they presumably need a matching
    // clear call after the allocation when CLooG is built on GMP.
    return cloog_stride_alloc(unitStride, zeroOffset);
  }// CloogStride* GeneralUtilityClass::GetDefaultStride()

    // Returns the text "NUMBER_OF_PARTITIONS" -- the same literal that
    // GetNumberOfPartitionsString() returns.
    string GeneralUtilityClass::GetMaxNumberOfPartitionsString()
  {
    const string name("NUMBER_OF_PARTITIONS");
    return name;
  }// string GeneralUtilityClass::GetMaxNumberOfPartitionsString()

    // Returns the text "totalEstimatedTiles".
    string GeneralUtilityClass::GetTotalEstimatedTilesString()
  {
    return string("totalEstimatedTiles");
  }// string GeneralUtilityClass::GetTotalEstimatedTilesString()



    // Returns the text "stride".
    string GeneralUtilityClass::GetStrideString()
  {
    return string("stride");
  }// string GeneralUtilityClass::GetStrideString()


    // Follows the `next` links starting at `head` and returns the node whose
    // `next` is NULL. Returns NULL when `head` itself is NULL.
    s_past_node_t* GeneralUtilityClass::GetLastNode(s_past_node_t* head)
  {
    s_past_node_t* last = NULL;

    // Remember every node we step over; the final one remembered is the tail.
    for (s_past_node_t* cursor = head; cursor != NULL; cursor = cursor->next)
      {
        last = cursor;
      }// for (cursor ...)

    return last;
  }// s_past_node_t* GeneralUtilityClass::GetLastNode(s_past_node_t* head)

    // Returns the text "wPTile".
    string GeneralUtilityClass::GetGlobalEstimateWavefrontName()
  {
    const string name = "wPTile";
    return name;
  }// string GeneralUtilityClass::GetGlobalEstimateWavefrontName()


    // Returns the text "current_".
    string GeneralUtilityClass::GetCurrentEstimatePrefix()
  {
    const string prefix = "current_";
    return prefix;
  }// string GeneralUtilityClass::GetCurrentEstimatePrefix()

    // Returns the text "global_".
    string GeneralUtilityClass::GetGlobalEstimatePrefix()
  {
    const string prefix = "global_";
    return prefix;
  }// string GeneralUtilityClass::GetGlobalEstimatePrefix()

    // Returns the text "local_".
    string GeneralUtilityClass::GetLocalEstimatePrefix()
  {
    const string prefix = "local_";
    return prefix;
  }// string GeneralUtilityClass::GetLocalEstimatePrefix()

    // Returns the text "_min".
    string GeneralUtilityClass::GetGlobalEstimateMin_suffix()
  {
    const string suffix = "_min";
    return suffix;
  }// string GeneralUtilityClass::GetGlobalEstimateMin_suffix()

    // Returns the text "_max".
    string GeneralUtilityClass::GetGlobalEstimateMax_suffix()
  {
    const string suffix = "_max";
    return suffix;
  }// string GeneralUtilityClass::GetGlobalEstimateMax_suffix()

// Builds a PAST for-node whose header is assembled from `iterator`:
//   init      : past_assign     (iterator, lb)
//   test      : past_leq        (iterator, ub)
//   increment : past_inc_before (iterator)
// `lb` and `ub` are attached to the new tree as-is (not cloned). The last two
// arguments of past_for_create() are passed as NULL.
//
// The iterator symbol is looked up / registered in options->symboltable once
// per use, and every use gets its own variable node.
s_past_for_t* GeneralUtilityClass::CreateForLoop (const char* iterator, s_past_node_t* lb, s_past_node_t* ub, s_ptile_options_t* options)
{
  // iterator = lb
  s_past_variable_t* initVar =
    past_variable_create (symbol_add_from_char (options->symboltable, iterator));
  s_past_node_t* initExpr =
    &past_binary_create (past_assign, &initVar->node, lb)->node;

  // iterator <= ub
  s_past_variable_t* testVar =
    past_variable_create (symbol_add_from_char (options->symboltable, iterator));
  s_past_node_t* testExpr =
    &past_binary_create (past_leq, &testVar->node, ub)->node;

  // ++iterator
  s_past_variable_t* stepVar =
    past_variable_create (symbol_add_from_char (options->symboltable, iterator));
  s_past_node_t* stepExpr =
    &past_unary_create (past_inc_before, &stepVar->node)->node;

  // The loop's own iterator handle.
  s_past_variable_t* loopVar =
    past_variable_create (symbol_add_from_char (options->symboltable, iterator));

  return past_for_create (initExpr, testExpr, loopVar, stepExpr, NULL, NULL);
} // s_past_for_t* CreateForLoop (const char* iterator, s_past_node_t* lb, s_past_node_t* ub, s_ptile_options_t* options)

// Returns a newly malloc()'ed C string made of `iterator` followed by
// ExpressionLibrary::GetTileIteratorSuffix(1, 1).
//
// The caller owns the result and must release it with free().
// Returns NULL when `iterator` is NULL or when the allocation fails; both
// cases previously led to a write/read through a null pointer.
char* GeneralUtilityClass::GetTileIteratorFromIterator (char *iterator)
{
	if (iterator == NULL)
	{
		return NULL;
	}// if (iterator == NULL)

	const string suffix = ExpressionLibrary::GetTileIteratorSuffix(1, 1);
	const size_t nameLength = strlen (iterator);

	// name + suffix + terminating NUL
	char* tileIterator = (char*) malloc (nameLength + suffix.length() + 1);
	if (tileIterator == NULL)
	{
		return NULL;
	}// if (tileIterator == NULL)

	memcpy (tileIterator, iterator, nameLength);
	// c_str() is NUL-terminated, so copying length()+1 bytes also terminates the result.
	memcpy (tileIterator + nameLength, suffix.c_str(), suffix.length() + 1);
	return tileIterator;
}// char* GetTileIteratorFromIterator (char *iterator)


// Builds a PAST expression describing a tile bound for `iteratorName`.
//
// purpose == 1 (full-tile separation), with
//   it = iteratorName + ExpressionLibrary::GetTileIteratorSuffix(1, 1) and
//   Ti = ExpressionLibrary::GetTileSizePrefix(1, 1) + iteratorName:
//     boundType == -1 : it * Ti
//     boundType ==  1 : (it * Ti) + (Ti + -1)
//
// purpose == 2 (parallel global estimate): a single variable named
//     [prefix] + iteratorName + "_min"   when boundType == -1
//     [prefix] + iteratorName + "_max"   when boundType ==  1
//   where [prefix] is GetGlobalEstimatePrefix() for prefix == 1,
//   GetLocalEstimatePrefix() for prefix == 2, and empty otherwise.
//
// Returns NULL for every other purpose / boundType combination. All symbols
// are created with a NULL symbol table. Nothing is allocated any more for
// combinations that end up returning NULL, and the upper-bound branch no
// longer builds a symbol, a variable node and a constant node that were never
// attached to the returned tree.
    s_past_node_t*
  GeneralUtilityClass::GetTileBoundForIterator(const char *iteratorName,
			  int boundType, int purpose, int prefix)
  {
    const int fullTileSeparation = 1;
    const int parallelGlobalEstimate = 2;
    const int globalPrefix = 1;
    const int localPrefix = 2;
    const int lowerBound = -1;
    const int upperBound = 1;

    // Neither purpose produces anything for other bound types; bail out
    // before allocating symbols that would have no owner.
    if (boundType != lowerBound && boundType != upperBound)
      {
        return NULL;
      }// if (boundType != lowerBound && boundType != upperBound)

    if (purpose == fullTileSeparation)
      {
        string tileIterator = string(iteratorName);
        tileIterator += ExpressionLibrary::GetTileIteratorSuffix(1, 1);
        string tileSize = string(iteratorName);
        tileSize = ExpressionLibrary::GetTileSizePrefix(1, 1) + tileSize;

        // it * Ti -- this is the whole lower bound and the first half of the
        // upper bound.
        s_symbol_t* tisymb = symbol_add_from_char(NULL, tileIterator.c_str());
        s_symbol_t* tssymb = symbol_add_from_char(NULL, tileSize.c_str());

        s_past_node_t* tileIteratorNode = &past_variable_create(tisymb)->node;
        s_past_node_t* tileSizeNode = &past_variable_create(tssymb)->node;

        s_past_node_t* tileOrigin =
          &past_binary_create(past_mul, tileIteratorNode, tileSizeNode)->node;

        if (boundType == lowerBound)
          {
            return tileOrigin;
          }// if (boundType == lowerBound)

        // Upper bound: the second occurrence of Ti gets its own symbol and
        // variable node, as before.
        s_symbol_t* tssymb2 = symbol_add_from_char(NULL, tileSize.c_str());
        s_past_node_t* tileSizeNode2 = &past_variable_create(tssymb2)->node;

        u_past_value_data_t val;
        val.intval = -1;
        s_past_node_t* minusOne =
          &past_value_create(e_past_value_int, val)->node;

        // Ti + -1
        s_past_node_t* tileExtent =
          &past_binary_create(past_add, tileSizeNode2, minusOne)->node;

        // (it * Ti) + (Ti + -1)
        return &past_binary_create(past_add, tileOrigin, tileExtent)->node;
      }// if (purpose == fullTileSeparation)

    if (purpose == parallelGlobalEstimate)
      {
        string boundName(iteratorName);
        if (boundType == lowerBound)
          {
            boundName += GetGlobalEstimateMin_suffix();
          }// if (boundType == lowerBound)
        else
          {
            boundName += GetGlobalEstimateMax_suffix();
          }// else

        if (prefix == globalPrefix)
          {
            boundName = GetGlobalEstimatePrefix() + boundName;
          }// if (prefix == globalPrefix)
        else if (prefix == localPrefix)
          {
            boundName = GetLocalEstimatePrefix() + boundName;
          }// else if (prefix == localPrefix)

        s_symbol_t* boundSymbol = symbol_add_from_char(NULL, boundName.c_str());
        return &past_variable_create(boundSymbol)->node;
      }// if (purpose == parallelGlobalEstimate)

    return NULL;
  }// s_past_node_t* GetTileBoundForIterator(const char *iteratorName, int boundType, int purpose, int prefix)

    // Returns true when some element of *iteratorNames is equal (strcmp) to
    // `name`, false otherwise.
    //
    // A NULL `iteratorNames` or a NULL `name` is treated as "not present"
    // instead of being dereferenced. The container is walked with an iterator,
    // so there is no signed/unsigned comparison against size().
    bool GeneralUtilityClass::IsPresentInNames(vector<string>* iteratorNames, const char *name)
  {
    if (iteratorNames == NULL || name == NULL)
      {
        return false;
      }// if (iteratorNames == NULL || name == NULL)

    for (vector<string>::const_iterator it = iteratorNames->begin();
         it != iteratorNames->end(); ++it)
      {
        if (strcmp(it->c_str(), name) == 0)
          {
            return true;
          }// if (strcmp(it->c_str(), name) == 0)
      }// for (it ...)

    return false;
  }// bool IsPresentInNames(vector<string>* iteratorNames, const char *name)


    // Does nothing: the body is empty (the only statement it ever held was a
    // debugging printf, which is disabled).
    void GeneralUtilityClass::test()
  {
  }// void GeneralUtilityClass::test()
  
  
  // Argument bundle handed, as the opaque `void* data` pointer, to
  // traverse_tree_replace_iter() by GetBoundEliminatingPointLoopIterators().
  struct GeneralUtilityClass::gu_args
  {
    // Names that identify the variables to replace (checked with IsPresentInNames()).
    vector<string>* iteratorNames;
    // Forwarded unchanged to GetTileBoundForIterator(): -1 = lower bound, 1 = upper bound.
    int boundType;
    // Forwarded unchanged to GetTileBoundForIterator(): 1 = full-tile separation,
    // 2 = parallel global estimate.
    int purpose;
    // Forwarded unchanged to GetTileBoundForIterator(): 1 = global prefix, 2 = local prefix.
    int prefix;
  };
  
   
  // Visitor callback (see the past_visitor() call in
  // GetBoundEliminatingPointLoopIterators()). `data` must point to a gu_args.
  //
  // When `node` is a past_variable whose symbol text is listed in
  // arg->iteratorNames, the node is replaced by the expression returned from
  // GetTileBoundForIterator(). Every other node is left alone.
  void GeneralUtilityClass::traverse_tree_replace_iter(s_past_node_t* node, void* data)
  {
    if (! PAST_NODE_IS_A(node, past_variable))
      {
        return;
      }

    struct gu_args* arg = (struct gu_args*)data;

    PAST_DECLARE_TYPED(variable, pv, node);
    char* symval = (char*) pv->symbol->data;
    if (! IsPresentInNames(arg->iteratorNames, symval))
      {
        return;
      }

    s_past_node_t* newvar =
      GetTileBoundForIterator(symval, arg->boundType, arg->purpose,
                              arg->prefix);

    // GetTileBoundForIterator() returns NULL for unsupported boundType /
    // purpose values. Keep the original variable in that case rather than
    // handing a NULL replacement to past_replace_node().
    if (newvar != NULL)
      {
        past_replace_node(node, newvar);
      }
  }
  
    // Produces a version of expression `e` in which every variable named in
    // *iteratorNames is replaced by GetTileBoundForIterator(name, boundType,
    // purpose, prefix).
    //
    // Return value:
    //   - NULL when `e` is NULL;
    //   - when `e` is itself a single variable: the freshly built bound if the
    //     variable is listed, otherwise `e` itself (NOT a copy);
    //   - otherwise a clone of `e` (past_clone) with the replacements applied
    //     in place on the clone; `e` is not modified on this path.
    // NOTE(review): callers therefore receive either the input node or a new
    // tree -- confirm how ownership is meant to be handled.
    s_past_node_t*
  GeneralUtilityClass::GetBoundEliminatingPointLoopIterators(s_past_node_t* e, int boundType,
					vector<string>* iteratorNames, 
					int purpose, int prefix)
  {
    if (e == NULL)
      {
        return NULL;
      }

    // A bare variable has no enclosing node inside the clone in which it
    // could be replaced, so it is handled directly.
    if (PAST_NODE_IS_A(e, past_variable))
      {
        PAST_DECLARE_TYPED(variable, pv, e);
        const char* symval = (const char*) pv->symbol->data;
        if (! IsPresentInNames(iteratorNames, symval))
          {
            return e;
          }
        return GetTileBoundForIterator(symval, boundType, purpose, prefix);
      }

    // General case: work on a copy and let the visitor substitute every
    // listed iterator it encounters.
    struct gu_args arg;
    arg.iteratorNames = iteratorNames;
    arg.boundType = boundType;
    arg.purpose = purpose;
    arg.prefix = prefix;

    s_past_node_t* e_clone = past_clone (e);
    past_set_parent (e_clone);
    past_visitor(e_clone, traverse_tree_replace_iter, (void*)&arg, NULL, NULL);

    return e_clone;
  }// s_past_node_t* GetBoundEliminatingPointLoopIterators(s_past_node_t* e, int boundType, vector<string>* iteratorNames, int purpose, int prefix)
  
  
    // Visitor callback (see the past_visitor() call in CollectIteratorNames()).
    // `data` must point to a vector<string>.
    //
    // For a past_for or past_parfor node, appends the text of the loop
    // iterator's symbol to the vector. Other nodes are ignored. A past_parfor
    // node is accessed through s_past_for_t here.
    void
  GeneralUtilityClass::traverse_collect_iterator_names(s_past_node_t* inputNode, void* data)
  {
    const bool isLoop =
      PAST_NODE_IS_A (inputNode, past_for) || PAST_NODE_IS_A (inputNode, past_parfor);
    if (! isLoop)
      {
        return;
      }

    s_past_for_t* loop = (s_past_for_t*) inputNode;
    const char* iteratorText = (const char*) loop->iterator->symbol->data;

    vector<string>* collected = (vector<string>*) data;
    collected->push_back(string(iteratorText));
  }
  
    // Runs past_visitor() over `pointLoops` with
    // traverse_collect_iterator_names and returns the iterator names it
    // gathered, in visiting order.
    //
    // The vector is allocated with `new`; the caller is responsible for
    // deleting it.
    //
    // Original author's remark: all statements are embedded in a common convex
    // hull, and by implication have the same number of for loops around them,
    // so collecting the iterator names from any one statement's loop nest is
    // sufficient.
    vector<string>* GeneralUtilityClass::CollectIteratorNames(s_past_node_t* pointLoops)
  {
    vector<string>* collected = new vector<string>;

    past_visitor(pointLoops, traverse_collect_iterator_names,
                 static_cast<void*>(collected), NULL, NULL);

    return collected;
  }// vector<string>* CollectIteratorNames(s_past_node_t* pointLoops)

    // Starting at `tileLoops`, descends through `body` for as long as the body
    // is itself a past_for node, and returns the deepest for-node reached.
    // Returns NULL when `tileLoops` is not a past_for node.
    //
    // Only the first node of each body is inspected (its `next` siblings are
    // not looked at).
    s_past_node_t* GeneralUtilityClass::FindInnermostForLoopInLoopNest(s_past_node_t* tileLoops)
  {
    if (! PAST_NODE_IS_A(tileLoops, past_for))
      {
        return NULL;
      }

    // Invariant: tileLoops is always a past_for node at the top of the loop.
    for (;;)
      {
        PAST_DECLARE_TYPED(for, pf, tileLoops);
        if (! PAST_NODE_IS_A(pf->body, past_for))
          {
            return tileLoops;
          }
        tileLoops = pf->body;
      }
  }// s_past_node_t* FindInnermostForLoopInLoopNest(s_past_node_t* tileLoops)

    // Computes i1*n2 + i2: the flat index of element (i1, i2) when n2 is the
    // extent of the second dimension.
    int GeneralUtilityClass::getIndex2D(int i1, int i2, int n2)
  {
    const int rowOffset = i1 * n2;
    return rowOffset + i2;
  }// int getIndex2D(int i1, int i2, int n2)

    // Computes i1*n2*n3 + i2*n3 + i3: the flat index of element (i1, i2, i3)
    // when n2 and n3 are the extents of the second and third dimensions.
    int GeneralUtilityClass::getIndex3D(int i1, int i2, int n2, int i3, int n3)
  {
    const int planeOffset = i1 * n2 * n3;
    const int rowOffset = i2 * n3;
    return planeOffset + rowOffset + i3;
  }// int getIndex3D(int i1, int i2, int n2, int i3, int n3)

    // Computes i1*n2*n3*n4 + i2*n3*n4 + i3*n4 + i4: the flat index of element
    // (i1, i2, i3, i4) when n2, n3 and n4 are the extents of the second,
    // third and fourth dimensions.
    int GeneralUtilityClass::getIndex4D(int i1, int i2, int n2, int i3, int n3, int i4, int n4)
  {
    const int cubeOffset = i1 * n2 * n3 * n4;
    const int planeOffset = i2 * n3 * n4;
    const int rowOffset = i3 * n4;
    return cubeOffset + planeOffset + rowOffset + i4;
  }// int getIndex4D(int i1, int i2, int n2, int i3, int n3, int i4, int n4)


    vector<Expression*>* GeneralUtilityClass::FormExpressions(int *StatementDomains, int len, int DomainRows, int DomainColumns,
					      int NumberOfIterators, int NumberOfParameters,
					      int ParamNamesSet, string* ParamNames, int IteratorNamesSet, string* IteratorNames)
  {
    vector<Expression*>* exprs = new vector<Expression*>();

    for (int i = 0; i < DomainRows; i++)
      {				

	// The context - (like N >= 3 and T >= 1) is mixed with the domain. So if the current row does not have any of the iterators and is an expression only in parameters and context, don't form a term for this.				
	bool IteratorPresent = false;
				
	Expression *expr = new Expression;
	expr->terms = new Term[DomainColumns - 1];
	expr->size = DomainColumns - 1;

	string *exponentNames = new string[NumberOfIterators + NumberOfParameters];
	// For iterators, set the tiled iterators as "T" + iterator-name
	// For parameters, set them as they are
	for (int l = 0; l < NumberOfIterators; l++)
	  {
	    if (IteratorNamesSet)
	      {
		exponentNames[l] =  ExpressionLibrary::GetTileSizePrefix(1, 1) + IteratorNames[l];
	      }// if (IteratorNamesSet)
	  }// for (int l = 0; l < NumberOfIterators; l++)

	for (int l = 0; l < NumberOfParameters; l++)
	  {
	    exponentNames[l + NumberOfIterators] = "";
	    if (ParamNamesSet)
	      {
		exponentNames[l + NumberOfIterators] = ParamNames[l];
	      }// if (ParamNamesSet)
	  }// for (int l = 0; l < NumberOfParameters; l++)

	// j = 0 is the equality/inequality column. We assume that it's always 1 because we are considering the
	// inequalities
	for (int j = 1; j < DomainColumns; j++)
	  {
	    (expr->terms[j - 1]).coefficient = StatementDomains[GeneralUtilityClass::getIndex2D(i, j, DomainColumns)];
	    //Add the polynomial
	    (expr->terms[j - 1]).SetIdentityNumeratorDenominator(NumberOfIterators + NumberOfParameters, exponentNames);


	    if (j <= NumberOfIterators)
	      {
		expr->terms[j - 1].type = loop_iterator;
		if (IteratorNamesSet)
		  {
		    expr->terms[j - 1].name = IteratorNames[j - 1];							
		  }

		// The context - (like N >= 3 and T >= 1) is mixed with the domain. So if the current row does not have any of the iterators and is an expression only in parameters and context, don't form a term for this.
		if ((expr->terms[j - 1]).coefficient != 0)
		  {
		    IteratorPresent = true;
		  }// if ((expr->terms[j - 1]).coefficient != 0)
	      }//if (j <= NumberOfIterators)
	    else if (j <= (NumberOfIterators + NumberOfParameters))
	      {
		expr->terms[j - 1].type = parameter;
		if (ParamNamesSet)
		  {
		    expr->terms[j - 1].name = ParamNames[j - 1 - NumberOfIterators];
		  }//if (ParamNamesSet)

	      }//else
	    else
	      {
		//constants
		expr->terms[j - 1].type = constant;
	      }//else

	  }//for (int j = 0; j < DomainColumns; j++)
					

	// The context - (like N >= 3 and T >= 1) is mixed with the domain. So if the current row does not have any of the iterators and is an expression only in parameters and context, don't form a term for this.
	if (IteratorPresent == true)
	  {
	    exprs->push_back(expr);
	  }// if (IteratorPresent == true)

      }//for (int i = 0; i < DomainRows; i++)

    return exprs;
  }//FormExpressions

// Returns the decimal text produced by streaming `Number` with operator<<.
string GeneralUtilityClass::NumberToString ( int Number )
{
	ostringstream formatted;
	formatted << Number;
	return formatted.str();
}// string GeneralUtilityClass::NumberToString ( int Number )

// Walks the statement list starting at `stmt` (following `next`) and numbers
// the loops it finds, starting from `id`.
//
// For every past_for / past_parfor node:
//   - unless skip_otl == 1 and the node's metainfo is the string "otl", a
//     malloc()'ed int holding the next number is stored in metainfo;
//   - otherwise metainfo is set to NULL, so that afterwards metainfo is either
//     a number or NULL and never the "otl" string;
//   - the loop body is then numbered recursively.
// For a past_affineguard node only the then_clause is descended into.
//
// Returns the next unused number. The previous metainfo value is overwritten
// without being released.
int GeneralUtilityClass::NumberForLoopHierarchy (s_past_node_t* stmt, int id, int skip_otl)
{
  int next_id = id;

  s_past_node_t* cursor = stmt;
  while (cursor != NULL)
    {
      if (PAST_NODE_IS_A(cursor, past_for) || PAST_NODE_IS_A(cursor, past_parfor))
        {
          // A one-time loop is recognised by its metainfo being the string "otl".
          const bool skippedOneTimeLoop =
            skip_otl == 1
            && cursor->metainfo != NULL
            && strcmp ((char*) cursor->metainfo, "otl") == 0;

          if (skippedOneTimeLoop)
            {
              // Clear the marker; otherwise a number could not be told apart
              // from the "otl" string later on.
              cursor->metainfo = NULL;
            }
          else
            {
              int* number = (int*) malloc(sizeof(int));
              *number = next_id;
              next_id++;
              cursor->metainfo = number;
            }

          // Nested loops continue the same numbering sequence.
          next_id = NumberForLoopHierarchy(((struct past_for_t*) cursor)->body, next_id, skip_otl);
        }
      else if (PAST_NODE_IS_A(cursor, past_affineguard))
        {
          next_id = NumberForLoopHierarchy(((struct past_affineguard_t*) cursor)->then_clause, next_id, skip_otl);
        }

      cursor = cursor->next;
    }// while (cursor != NULL)

  return next_id;
}// int NumberForLoopHierarchy (s_past_node_t* stmt, int id, int skip_otl)

// Resets metainfo to NULL on every past_for / past_parfor node reachable from
// `stmt`: the statement list is followed through `next`, loop bodies are
// descended into, and for past_affineguard nodes only the then_clause is
// visited.
//
// NOTE(review): the pointer is only overwritten, not released. If it still
// holds the int allocated by NumberForLoopHierarchy() that memory is
// presumably lost here -- confirm who owns metainfo.
void GeneralUtilityClass::DeNumberForLoopHierarchy (s_past_node_t* stmt)
{
  s_past_node_t* cursor = stmt;
  while (cursor != NULL)
    {
      if (PAST_NODE_IS_A(cursor, past_for) || PAST_NODE_IS_A(cursor, past_parfor))
        {
          cursor->metainfo = NULL;
          DeNumberForLoopHierarchy(((struct past_for_t*) cursor)->body);
        }
      else if (PAST_NODE_IS_A(cursor, past_affineguard))
        {
          DeNumberForLoopHierarchy(((struct past_affineguard_t*) cursor)->then_clause);
        }

      cursor = cursor->next;
    }// while (cursor != NULL)
}// void DeNumberForLoopHierarchy (s_past_node_t* stmt)
