/*
 * AstGenerator.cpp: This file is part of the Parametric Tiling project.
 *
 * Parametric Tiling: A CLAST-to-CLAST parametric tiling software
 *
 * Copyright (C) 2011 Sanket Tavargeri
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Sanket Tavargeri <sanket.tavargeri@gmail.com>
 *
 */

#if HAVE_CONFIG_H
# include <ptile/config.h>
#endif


#include <stdio.h>
#include <stdlib.h>
#include<string>
#include<iostream>
#include<vector>
#include <set>
#include<sstream>

# include <ptile/AstGenerator.hpp>

   s_past_node_t*
  ASTGenerator::GenerateTileLoops(vector<Expression*> *exprs,
                    s_ptile_options_t* options)
  {
    /*
     * Builds a nest of tile loops from a flat list of bound expressions.
     *
     * exprs   : bound expressions, consumed two at a time. The element at an
     *           even position is used as the lower bound and the element
     *           right after it as the upper bound of the same iterator
     *           (e.g. it >= 0 immediately followed by it <= N).
     * options : forwarded unchanged to EmitForLoop.
     *
     * Returns the first loop that was emitted, viewed as a generic node, or
     * NULL when exprs is empty. Every subsequent loop is attached as the
     * body of the loop emitted just before it.
     *
     * Both bounds are fetched with at(), so a list with an odd number of
     * entries is reported by at() for the missing upper bound.
     */
    s_past_node_t* outermost = NULL;
    s_past_for_t* enclosing = NULL;

    int pos = 0;
    while (pos < exprs->size())
      {
        Expression *lowerBound = exprs->at(pos);
        Expression *upperBound = exprs->at(pos + 1);
        s_past_for_t* loop = EmitForLoop(lowerBound, upperBound, options);
        s_past_node_t* loopAsNode = (s_past_node_t*) loop;

        if (outermost != NULL)
          {
            // Nest the new loop inside the one emitted in the previous round.
            enclosing->body = loopAsNode;
          }
        else
          {
            // Very first loop: it is the root of the whole nest.
            outermost = loopAsNode;
          }

        enclosing = loop;
        pos = pos + 2;
      }// while (pos < exprs->size())

    return outermost;
  }// s_past_node_t* GenerateTileLoops()

   s_past_for_t* ASTGenerator::EmitForLoop(Expression *Lo, Expression *Uo,
                                   s_ptile_options_t* options)
  {
    /*
     * Creates one tile loop "for (it = lower; it <= upper; ++it)".
     *
     * Lo, Uo  : lower and upper bound expressions. Both must have the same
     *           rightmost loop iterator; that iterator is the loop variable
     *           and the remaining terms form the bound.
     *             Expression = term1 + term2 + term3
     *             Term = coefficient * (polynomial1 + polynomial2 + polynomial3)
     *                                / (polynomial1 + polynomial2 + polynomial3)
     * options : supplies the symbol table in which the iterator is registered.
     *
     * Returns the new for node, typed e_past_tile_loop. The last two
     * arguments of past_for_create are passed as NULL; the caller fills in
     * the body afterwards.
     */
    const int index = ExpressionLibrary::Find_Index_Of_RightmostLoopIterator(*Lo);
    const int upperIndex = ExpressionLibrary::Find_Index_Of_RightmostLoopIterator(*Uo);
    if (upperIndex != index)
      {
        // Lower and upper bound do not describe the same iterator: error out.
        assert(0);
      }

    s_past_node_t* lowerExpr = FormClastExprFromExpression(Lo, index, options);
    s_past_node_t* upperExpr = FormClastExprFromExpression(Uo, index, options);

    // Name of the loop iterator, taken from the lower bound expression.
    const char* iterName = (Lo->terms[index].name).c_str();

    // Each loop component gets its own symbol lookup and its own variable node.

    // Init: it = lower
    s_symbol_t* initSymb = symbol_add_from_char (options->symboltable, iterName);
    s_past_node_t* initVar = &past_variable_create(initSymb)->node;
    s_past_node_t* init =
      &past_binary_create (past_assign, initVar, lowerExpr)->node;

    // Test: it <= upper
    s_symbol_t* testSymb = symbol_add_from_char (options->symboltable, iterName);
    s_past_node_t* testVar = &past_variable_create(testSymb)->node;
    s_past_node_t* test =
      &past_binary_create (past_leq, testVar, upperExpr)->node;

    // Increment: ++it
    s_symbol_t* incSymb = symbol_add_from_char (options->symboltable, iterName);
    s_past_node_t* incVar = &past_variable_create(incSymb)->node;
    s_past_node_t* increment =
      &past_unary_create (past_inc_before, incVar)->node;

    // Iterator variable of the loop itself.
    s_symbol_t* iterSymb = symbol_add_from_char (options->symboltable, iterName);
    s_past_variable_t* iter = past_variable_create (iterSymb);

    s_past_for_t* tileLoop =
      past_for_create (init, test, iter, increment, NULL, NULL);
    tileLoop->type = e_past_tile_loop;
    return tileLoop;
  }// s_past_for_t* EmitForLoop(Expression *Lo, Expression *Uo)

   s_past_node_t*
  ASTGenerator::FormClastExprFromExpression(Expression *expr, int index,
                              s_ptile_options_t* options)
  {
    /*
     * Builds round(t0 + t1 + ...) from the terms of an expression.
     *
     * expr    : expression whose terms are summed.
     * index   : position of the term to leave out (callers pass the position
     *           of the iterator the expression is a bound of).
     * options : forwarded unchanged to FormClastExprFromTerm.
     *
     * Terms for which FormClastExprFromTerm yields NULL are skipped. If no
     * term contributes, the sum is the integer constant 0. The sum is always
     * wrapped in a past_round unary node.
     *
     * Each term is folded into the running sum as soon as it is produced, so
     * no temporary array of term nodes is allocated (and none can be leaked).
     */
    s_past_node_t* resultExpr = NULL;

    for (int i = 0; i < expr->size; i++)
      {
        if (i == index)
          {
            // This is the iterator being bounded, not part of the bound.
            continue;
          }

        s_past_node_t* termExpr = FormClastExprFromTerm(&(expr->terms[i]),
                                                        options);
        if (termExpr == NULL)
          {
            // The term contributes nothing to the sum.
            continue;
          }

        if (resultExpr == NULL)
          {
            resultExpr = termExpr;
          }
        else
          {
            resultExpr =
              past_node_binary_create(past_add, resultExpr, termExpr);
          }
      }// for (int i = 0; i < expr->size; i++)

    if (resultExpr == NULL)
      {
        // Make the expression 0
        resultExpr = past_node_value_create_from_int(0);
      }// if (resultExpr == NULL)

    // round(expression)
    return past_node_unary_create(past_round, resultExpr);
  }// s_past_node_t* FormClastExprFromExpression(Expression *expr, int index)

   s_past_node_t*
  ASTGenerator::FormClastExprFromTerm(Term *term,
                        s_ptile_options_t* options)
  {
    /*
     * Builds the expression for one term:
     *
     *   coefficient * name * (sum of numerator polynomials)
     *                      / (sum of denominator polynomials)
     *
     * term    : the term to convert.
     * options : supplies the symbol table used for the term's name, and is
     *           forwarded to FormClastExprFromPolynomial.
     *
     * Returns NULL when the coefficient is 0. Otherwise:
     *  - polynomials for which FormClastExprFromPolynomial yields NULL are
     *    left out of their sum;
     *  - with only a denominator present the fraction becomes 1/denominator;
     *  - the name factor is only added for parameter / iterator terms;
     *  - a coefficient of 1 is not emitted unless it is the whole result.
     *
     * Polynomials are folded into their sum as they are produced, and the
     * coefficient node is only created when it ends up in the result, so the
     * function allocates nothing that is not part of the returned tree.
     */
    if (term->coefficient == 0)
      {
        return NULL;
      }// if (term->coefficient == 0)

    // Numerator: sum of its non-NULL polynomials
    s_past_node_t* num_expr = NULL;
    for (int i = 0; i < term->num_polynomials_in_numerator; i++)
      {
        s_past_node_t* poly =
          FormClastExprFromPolynomial(&(term->numerator[i]), options);
        if (poly == NULL)
          {
            continue;
          }
        num_expr = (num_expr == NULL)
          ? poly
          : &past_binary_create(past_add, num_expr, poly)->node;
      }// for (int i = 0; i < term->num_polynomials_in_numerator; i++)

    // Denominator: sum of its non-NULL polynomials
    s_past_node_t* denom_expr = NULL;
    for (int i = 0; i < term->num_polynomials_in_denominator; i++)
      {
        s_past_node_t* poly =
          FormClastExprFromPolynomial(&(term->denominator[i]), options);
        if (poly == NULL)
          {
            continue;
          }
        denom_expr = (denom_expr == NULL)
          ? poly
          : &past_binary_create(past_add, denom_expr, poly)->node;
      }// for (int i = 0; i < term->num_polynomials_in_denominator; i++)

    // Numerator / Denominator
    s_past_node_t* resultExpr = NULL;
    if (num_expr != NULL && denom_expr != NULL)
      {
        resultExpr = &past_binary_create
          (past_div, num_expr, denom_expr)->node;
      }// if (num_expr != NULL && denom_expr != NULL)
    else if (num_expr != NULL)
      {
        // Only numerator is present
        resultExpr = num_expr;
      }// else if (num_expr != NULL)
    else if (denom_expr != NULL)
      {
        // Only denominator is present
        // Make numerator 1. So the Term will be 1/Denominator
        u_past_value_data_t one;
        one.intval = 1;
        num_expr = &past_value_create(e_past_value_int, one)->node;

        resultExpr = &past_binary_create
          (past_div, num_expr, denom_expr)->node;
      }// else if (denom_expr != NULL)

    // Multiply with the name of the iterator/parameter
    if (term->type == parameter || term->type == loop_iterator
        || term->type == interTile_loop_iterator
        || term->type == intraTile_loop_iterator)
      {
        s_symbol_t* symb = symbol_add_from_char(options->symboltable,
                                                (term->name).c_str());
        s_past_node_t* name = &past_variable_create(symb)->node;
        if (resultExpr == NULL)
          {
            resultExpr = name;
          }
        else
          {
            resultExpr =
              &past_binary_create(past_mul, name, resultExpr)->node;
          }
      }// if the term is a parameter/iterator

    // Multiply with the coefficient. The coefficient is known to be non-zero
    // here. A coefficient of 1 is dropped when there is already something to
    // multiply; when there is nothing else, the coefficient is the result.
    if (resultExpr == NULL || term->coefficient != 1)
      {
        u_past_value_data_t val;
        val.intval = term->coefficient;
        s_past_node_t* value =
          &past_value_create(e_past_value_int, val)->node;
        resultExpr = (resultExpr == NULL)
          ? value
          : &past_binary_create(past_mul, resultExpr, value)->node;
      }// if (resultExpr == NULL || term->coefficient != 1)

    return resultExpr;
  }// s_past_node_t* FormClastExprFromTerm(Term *term)

   s_past_node_t*
  ASTGenerator::FormClastExprFromPolynomial(Polynomial *poly,
                              s_ptile_options_t* options)
  {
    /*
     * Builds the product coefficient * n0^e0 * n1^e1 * ... for a polynomial.
     *
     * poly    : supplies the coefficient and, per variable, a name and an
     *           exponent.
     * options : supplies the symbol table the variable names are added to.
     *
     * Powers are expanded into repeated multiplications (exponent 3 for Ti
     * gives Ti*Ti*Ti). The coefficient factor is omitted when it equals 1.
     *
     * Returns NULL when the coefficient is 0, and also when no variable has
     * a positive exponent.
     * TODO : Need to see what happens if all of the exponents are 0.
     */
    if (poly->coefficient == 0)
      {
        return NULL;
      }// if (poly->coefficient == 0)

    s_past_node_t* product = NULL;
    for (int var = 0; var < poly->size; var++)
      {
        for (int power = 0; power < poly->exponents[var]; power++)
          {
            // One more occurrence of this variable in the product.
            s_symbol_t* symbol = symbol_add_from_char(options->symboltable,
                                                      poly->names[var].c_str());
            s_past_node_t* factor = &past_variable_create(symbol)->node;
            product = (product == NULL)
              ? factor
              : &past_binary_create(past_mul, product, factor)->node;
          }// for (int power = 0; power < poly->exponents[var]; power++)
      }// for (int var = 0; var < poly->size; var++)

    if (product == NULL || poly->coefficient == 1)
      {
        // Nothing to scale, or scaling by 1 would be a no-op.
        return product;
      }

    // Multiply with the coefficient
    u_past_value_data_t coeff;
    coeff.intval = poly->coefficient;
    s_past_node_t* coeffNode =
      &past_value_create(e_past_value_int, coeff)->node;
    return &past_binary_create(past_mul, product, coeffNode)->node;
  }// s_past_node_t* FormClastExprFromPolynomial(Polynomial *poly)



