/*
 * FullTileMaker.hpp: This file is part of the Parametric Tiling project.
 * 
 * Parametric Tiling: A CLAST-to-CLAST parametric tiling software
 * 
 * Copyright (C) 2011 Sanket Tavargeri
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 * 
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 * 
 * Author:
 * Sanket Tavargeri <sanket.tavargeri@gmail.com>
 * 
 */


# include <ptile/PtileClasses.hpp>
# include <ptile/GenericUtility.hpp>
# include <ptile/FullTileMaker.hpp>

  /*
   * Entry point: builds the chain of full tiles for the point-loop chain
   * 'pointLoops'.
   * Starts at the root level, i.e. with no inherited full-tile condition and no
   * enclosing full-tile loop nest.
   * NOTE(review): the vector returned by CollectIteratorNames is not released
   * here -- confirm whether it is heap-allocated and who owns it.
   */
  s_past_node_t* FullTileMaker::FullTileMakerDriver(s_past_node_t* pointLoops, s_ptile_options_t* options)
  {
    return BuildFullTiles (pointLoops,
                           GeneralUtilityClass::CollectIteratorNames (pointLoops),
                           NULL,   // no outer full-tile condition yet
                           NULL,   // no outer full-tile loop nest yet
                           options);
  }



/*
Chains the full tiles and the original point loops into an if / else-if cascade.

Input:
 fullTiles   - a list (linked through 'next') of 'if' nodes, one per full tile;
               may be NULL.
 point_loops - the original 'for' loop nest, used as the final fallback.

Output:
  if (c1)
  {
	fullTile1;
  }
  else if (c2)
  {
	fullTile2;
  }
  else if (c3)
  {
	fullTile3;
  }
  else
  {
	point_loops
  }

The 'if' nodes of fullTiles are rewired in place. Each 'next' link is cleared and
turned into an 'else_clause' link, and the last node's else_clause becomes
point_loops. If fullTiles is NULL there is nothing to guard, so point_loops is
returned unchanged.
*/
 s_past_node_t* FullTileMaker::BuildIfElseList (s_past_node_t* fullTiles, s_past_node_t* point_loops)
  {
	    // No full tile was produced: only the point loops remain
	    // (previously this dereferenced a NULL pointer).
	    if (fullTiles == NULL)
	      {
		return point_loops;
	      }

	    s_past_node_t* current = fullTiles;

	    // Turn the sibling list into a nested else-if chain.
	    while (current->next != NULL)
	      {
		PAST_DECLARE_TYPED(if, ti, current);
		ti->else_clause = current->next;
		current->next = NULL;
		current = ti->else_clause;
	      }// while (current->next != NULL)

	    // The last full tile falls back to the original point loops.
	    PAST_DECLARE_TYPED(if, lastIf, current);
	    lastIf->else_clause = point_loops;

	    return fullTiles;
  }// s_past_node_t* BuildIfElseList (s_past_node_t* fullTiles, s_past_node_t* point_loops)
  

/*
Builds the guard under which the full tile of 'currentLoopNode' may execute,
given the chain of sibling point loops starting at 'pointLoops'.

Suppose the input is the following and the current loop is i2:
 for (i1 = lb1; i1 <= ub1; i1++)
     S1;
 for (i2 = lb2; i2 <= ub2; i2++)
     S2;
 for (i3 = lb3; i3 <= ub3; i3++)
     S3;

Let tlb2 = it2*T1i2 and tub2 = it2*T1i2 + T1i2 - 1 be the tile bounds of i2, as
produced by GeneralUtilityClass::GetTileBoundForIterator. The condition returned
is the conjunction, in chain order, of:
  ((tlb2 >= lb2) && (tub2 <= ub2))                -- entry condition for i2
  ((tlb2 >= (ub1 + 1)) || (tub2 <= (lb1 - 1)))    -- tile does not overlap i1
  ((tlb2 >= (ub3 + 1)) || (tub2 <= (lb3 - 1)))    -- tile does not overlap i3

Bounds that mention point-loop iterators (listed in iteratorNames) are replaced,
via GetBoundEliminatingPointLoopIterators, by whichever extreme keeps the
condition conservative.

Nodes of the chain that are not 'for'/'parfor' loops (e.g. affine guards) do not
contribute.

Returns NULL if the chain contains no loop. It also returns NULL in release
builds when currentLoopNode is not a loop; debug builds assert instead.
*/
s_past_node_t* FullTileMaker::BuildExclusivityConditionsForCurrentLevelForLoopNest (s_past_node_t* pointLoops, s_past_node_t* currentLoopNode, vector<string>* iteratorNames)
{
  if (! (PAST_NODE_IS_A (currentLoopNode, past_for) || PAST_NODE_IS_A (currentLoopNode, past_parfor)))
   {
	assert (0); // This should have been a 'for' loop.
	return NULL; // With NDEBUG: no tile bounds can be built without a loop.
   }

  s_past_for_t* currentLoop = (s_past_for_t*) currentLoopNode;

  // Tile bounds of the current loop: it*Ti and it*Ti + Ti - 1.
  // They are used as templates and cloned into every generated condition.
  // NOTE(review): the templates themselves are never released.
  const char * iterator = (const char*)currentLoop->iterator->symbol->data;
  s_past_node_t* lb_currentLoop = GeneralUtilityClass::GetTileBoundForIterator(iterator, -1, 1, 0);
  s_past_node_t* ub_currentLoop = GeneralUtilityClass::GetTileBoundForIterator(iterator, 1, 1, 0);

  s_past_node_t* workingCondition = NULL;

  for ( ; pointLoops; pointLoops = pointLoops->next)
	{
		// Only loops contribute to the condition.
		if (! (PAST_NODE_IS_A (pointLoops, past_for) || PAST_NODE_IS_A (pointLoops, past_parfor)))
		{
			continue;
		}

		PAST_DECLARE_TYPED(for, tempLoop, pointLoops);
		// Assumes init is 'i = lb' and test is 'i <= ub': binary nodes whose rhs is the bound.
		PAST_DECLARE_TYPED(binary, lb, tempLoop->init);
		PAST_DECLARE_TYPED(binary, ub, tempLoop->test);

		s_past_node_t* totalCondition = NULL;
		if (tempLoop != currentLoop)
		{
			// Sibling loop: the tile must lie entirely above or entirely below it.
			// lb_currentLoop >= (ub + 1) || ub_currentLoop <= (lb - 1)

			// lb_currentLoop >= (ub + 1), using the maximum ub (boundType = 1) to be conservative
			s_past_node_t* ub_clone = GeneralUtilityClass::GetBoundEliminatingPointLoopIterators(ub->rhs, 1, iteratorNames, 1, 0);
			s_past_node_t* ubPlusOne = past_node_binary_create (past_add, ub_clone, past_node_value_create_from_int (1));
			s_past_node_t* greaterThanCondition = past_node_binary_create (past_geq, past_clone (lb_currentLoop), ubPlusOne);

			// ub_currentLoop <= (lb - 1), using the minimum lb (boundType = -1) to be conservative
			s_past_node_t* lb_clone = GeneralUtilityClass::GetBoundEliminatingPointLoopIterators(lb->rhs, -1, iteratorNames, 1, 0);
			s_past_node_t* lbMinusOne = past_node_binary_create (past_sub, lb_clone, past_node_value_create_from_int (1));
			s_past_node_t* lessThanCondition = past_node_binary_create (past_leq, past_clone (ub_currentLoop), lbMinusOne);

			totalCondition = past_node_binary_create (past_or, greaterThanCondition, lessThanCondition);
		}// if (tempLoop != currentLoop)
		else
		{
			// The current loop itself: the tile must lie inside its bounds.
			// lb_currentLoop >= lb && ub_currentLoop <= ub

			// lb_currentLoop >= lb. Pass boundType = 1, i.e. get the maximum (lb), to be conservative.
			s_past_node_t* lb_clone = GeneralUtilityClass::GetBoundEliminatingPointLoopIterators(lb->rhs, 1, iteratorNames, 1, 0);
			s_past_node_t* greaterThanCondition = past_node_binary_create (past_geq, past_clone (lb_currentLoop), lb_clone);

			// ub_currentLoop <= ub. Pass boundType = -1, i.e. get the minimum (ub), to be conservative.
			s_past_node_t* ub_clone = GeneralUtilityClass::GetBoundEliminatingPointLoopIterators(ub->rhs, -1, iteratorNames, 1, 0);
			s_past_node_t* lessThanCondition = past_node_binary_create (past_leq, past_clone (ub_currentLoop), ub_clone);

			totalCondition = past_node_binary_create (past_and, greaterThanCondition, lessThanCondition);
		}// else

		assert (totalCondition != NULL);

		if (workingCondition == NULL)
		{
			workingCondition = totalCondition;
		}
		else
		{
			workingCondition = past_node_binary_create (past_and, workingCondition, totalCondition);
		}
	} // for ( ; pointLoops; pointLoops = pointLoops->next)

	return workingCondition;
}// s_past_node_t* BuildExclusivityConditionsForCurrentLevelForLoopNest (s_past_node_t* pointLoops, s_past_node_t* currentLoopNode, vector<string>* iteratorNames)


/*
Returns the conjunction (condition1 && condition2).

A NULL operand means "no condition". If either operand is NULL the other one is
returned unchanged, so two NULLs yield NULL.

Neither operand is cloned: the result either is one of the arguments or has them
as direct children.
*/
s_past_node_t* FullTileMaker::CombineConditions (s_past_node_t* condition1, s_past_node_t* condition2)
{
  if (condition1 != NULL && condition2 != NULL)
    {
      return past_node_binary_create (past_and, condition1, condition2);
    }

  // At most one operand is non-NULL here; hand back whichever it is.
  return (condition1 != NULL) ? condition1 : condition2;
}// s_past_node_t* CombineConditions (s_past_node_t* condition1, s_past_node_t* condition2)



/*
Extends a copy of the full-tile loop nest 'prefixBody' with one more tile-sized
loop for 'currentLoop'. Supposing the currentLoop iterator is i, the result is:

  innermost(clone(prefixBody))->body = for (i = it*Ti; i <= it*Ti + Ti - 1; i++)

prefixBody  - loop nest built for the enclosing loops, or NULL at the root level.
              It is cloned and never modified.
currentLoop - the point loop being tiled. Only 'for'/'parfor' nodes produce a
              loop; anything else (or NULL) yields a NULL tile loop.
options     - forwarded to GeneralUtilityClass::CreateForLoop.

The created loop is tagged e_past_tile_loop. When prefixBody is NULL, the tile
loop alone is returned.
*/
s_past_node_t* FullTileMaker::GetFullTile (s_past_node_t* prefixBody, s_past_node_t* currentLoop, s_ptile_options_t* options)
{
	// Private copy of the enclosing nest; its innermost body is overwritten below.
	s_past_node_t* nestCopy = (prefixBody != NULL) ? past_clone (prefixBody) : NULL;

	// Build 'for (i = it*Ti; i <= it*Ti + Ti - 1; i++)' when currentLoop is a loop.
	s_past_node_t* tileLoop = NULL;
	if (currentLoop != NULL
	    && (PAST_NODE_IS_A (currentLoop, past_for) || PAST_NODE_IS_A (currentLoop, past_parfor)))
	{
		PAST_DECLARE_TYPED(for, pointLoop, currentLoop);
		char * iterator = (char*)pointLoop->iterator->symbol->data;

		s_past_node_t* tileLb = GeneralUtilityClass::GetTileBoundForIterator(iterator, -1, 1, 0); // it*Ti
		s_past_node_t* tileUb = GeneralUtilityClass::GetTileBoundForIterator(iterator, 1, 1, 0);  // it*Ti + Ti - 1

		tileLoop = (s_past_node_t*) GeneralUtilityClass::CreateForLoop (iterator, tileLb, tileUb, options);
		((s_past_for_t*)tileLoop)->type = e_past_tile_loop;
	}

	// Root level: nothing encloses the tile loop.
	if (nestCopy == NULL)
	{
		return tileLoop;
	}

	s_past_node_t* innermost = GeneralUtilityClass::FindInnermostForLoopInLoopNest (nestCopy);
	if (innermost == NULL)
	{
		assert (0); // because prefixBody is supposed to be a 'for' loop nest but no 'for' loop is found.
	}
	else
	{
		((s_past_for_t*) (innermost))->body = tileLoop;
	}

	return nestCopy;
}// s_past_node_t* FullTileMaker::GetFullTile (s_past_node_t* prefixBody, s_past_node_t* currentLoop, s_ptile_options_t* options)


/*
Installs 'body' under the innermost loop of a copy of 'forLoopNest'.

forLoopNest - loop nest to copy; it is cloned and never modified. May be NULL.
body        - node to install as the innermost loop's body; not cloned.

Returns the modified copy, or 'body' itself when forLoopNest is NULL.
*/
s_past_node_t* FullTileMaker::GenerateFullTileLoopNest (s_past_node_t* forLoopNest, s_past_node_t* body)
{
	// Without an enclosing nest the body is the whole result.
	if (forLoopNest == NULL)
	{
		return body;
	}

	s_past_node_t* nestCopy = past_clone (forLoopNest);
	s_past_node_t* innermost = GeneralUtilityClass::FindInnermostForLoopInLoopNest (nestCopy);
	if (innermost == NULL)
	{
		assert (0); // a loop nest must contain at least one 'for' loop
	}
	else
	{
		((s_past_for_t*) (innermost))->body = body;
	}

	return nestCopy;
}// s_past_node_t* GenerateFullTileLoopNest (s_past_node_t* forLoopNest, s_past_node_t* body)


/*
This function builds the full tiles from the given loop nest.

pointLoops              - chain (linked through 'next') of point loops and affine guards.
iteratorNames           - names of the point-loop iterators, forwarded to the
                          condition builder.
fullTileConditionPrefix - the full tile condition for the outer loops; NULL at the
                          root level. It is used as a template: every emitted 'if'
                          receives its own clone of it.
fullTilePrefix          - the full tile loop nest for the outer loops; NULL at the
                          root level.
options                 - forwarded to the loop builders.

Returns a chain (linked through 'next') of 'if' nodes, or NULL if none is built.
There is one 'if' per loop whose body starts with a cloogstmt, each guarding the
corresponding full-tile loop nest.

Conditions taken from the input AST (affine guards) are cloned as well, so the
emitted 'if' nodes share no condition subtrees with each other or with pointLoops.
*/
  s_past_node_t*
  FullTileMaker::BuildFullTiles(s_past_node_t* pointLoops, vector<string>* iteratorNames,
		 s_past_node_t* fullTileConditionPrefix,
		 s_past_node_t* fullTilePrefix, s_ptile_options_t* options)
  {
    s_past_node_t* head = NULL; // head of the full tile chain
    s_past_node_t* tail = NULL; // last node of the full tile chain

    // Traverse the AST
    for (s_past_node_t* node = pointLoops; node != NULL; node = node->next)
      {
	s_past_node_t* produced = NULL; // full tiles generated for 'node'

	if (PAST_NODE_IS_A (node, past_for) || PAST_NODE_IS_A (node, past_parfor))
	  {
	    PAST_DECLARE_TYPED(for, tempLoop, node);

	    // A loop without a body yields no full tile.
	    if (tempLoop->body != NULL)
	      {
		s_past_node_t* levelCondition = BuildExclusivityConditionsForCurrentLevelForLoopNest (pointLoops, node, iteratorNames);
		s_past_node_t* currentFullTileLoopNest = GetFullTile (fullTilePrefix, node, options);

		if (past_node_is_a (tempLoop->body, past_cloogstmt))
		  {
		    // The cloogstmt is encountered: build 'if (condition) fullTileLoopNest' here.
		    // Clone the inherited prefix so sibling full tiles do not share it.
		    s_past_node_t* prefixCopy = (fullTileConditionPrefix != NULL) ? past_clone (fullTileConditionPrefix) : NULL;
		    s_past_node_t* condition = CombineConditions (prefixCopy, levelCondition);

		    s_past_node_t* fullTileLoop = GenerateFullTileLoopNest (currentFullTileLoopNest, tempLoop->body);
		    if (past_node_is_a (fullTileLoop, past_for))
		      {
			PAST_DECLARE_TYPED(for, pf, fullTileLoop);
			pf->type = e_past_fulltile_loop;
		      }

		    produced = (s_past_node_t*) past_if_create (condition, fullTileLoop, NULL);
		  }
		else
		  {
		    // Deeper loops follow. The combined condition is only a template that
		    // the recursive call clones for each full tile it emits.
		    s_past_node_t* conditionTemplate = CombineConditions (fullTileConditionPrefix, levelCondition);
		    produced = BuildFullTiles (tempLoop->body, iteratorNames, conditionTemplate, currentFullTileLoopNest, options);
		  }
	      }// if (tempLoop->body != NULL)
	  }
	else if (past_node_is_a (node, past_affineguard))
	  {
	    PAST_DECLARE_TYPED(affineguard, tempIf, node);

	    // Clone the guard's condition so the input AST is never linked into the full tiles.
	    s_past_node_t* guardCondition = (tempIf->condition != NULL) ? past_clone (tempIf->condition) : NULL;
	    s_past_node_t* conditionTemplate = CombineConditions (fullTileConditionPrefix, guardCondition);
	    produced = BuildFullTiles (tempIf->then_clause, iteratorNames, conditionTemplate, fullTilePrefix, options);
	  }

	// Append via the tail pointer, so each append costs only the length of the
	// new chain rather than a walk of the whole accumulated chain.
	if (produced != NULL)
	  {
	    if (head == NULL)
	      {
		head = produced;
	      }
	    else
	      {
		tail->next = produced;
	      }
	    tail = GeneralUtilityClass::GetLastNode (produced);
	  }
      }// for (s_past_node_t* node = pointLoops; node != NULL; node = node->next)

    return head;
  }//   s_past_node_t*  FullTileMaker::BuildFullTiles(s_past_node_t* pointLoops, vector<string>* iteratorNames,	 s_past_node_t* fullTileConditionPrefix, s_past_node_t* fullTilePrefix, s_ptile_options_t* options)

