/*
 * past_parser.c: This file is part of the IR-Converter project.
 *
 * IR-Converter: a library to convert PAST to ScopLib
 *
 * Copyright (C) 2011 Louis-Noel Pouchet
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 *
 */
#if HAVE_CONFIG_H
# include <irconverter/config.h>
#endif

#include <irconverter/common.h>
#include <irconverter/past_parser.h>

#include <past/pprint.h>

#include <assert.h>


/******************************************************************************/
/************************ Affine expression processing ************************/
/******************************************************************************/
/* Forward declarations of the file-local helpers, defined further below. */
static scoplib_matrix_p
parseBinExpr (s_past_node_t* lhs, int lhsmult, s_past_node_t* rhs,
	      int rhsmult, int offset, char** iterators, char** parameters,
	      int data_is_char);

static scoplib_matrix_p
parseLinearExpression (s_past_node_t* llhs, int lhsmult,
		       s_past_node_t* lrhs, int rhsmult, int offset,
		       char** iterators, char** parameters, int data_is_char);

static int
convertLinearExpression (s_past_node_t* e, int mult, char** iterators,
			 char** parameters, scoplib_matrix_p matrix);


/**
 * Generic expression parser: translate a PAST expression into scoplib
 * constraint rows. Used for both conditionals and access functions.
 *
 * - a && b: the rows of both operands, concatenated;
 * - a > b, a < b, a >= b, a <= b, a = b (assign): one "... >= 0" row;
 * - a == b: two rows, b - a >= 0 and a - b >= 0;
 * - ceild/div/floord(a, k) with a constant k: a is parsed with a
 *   multiplier of k;
 * - +, -, * and any non-binary node: parsed as one linear expression.
 *
 * @param expr          the expression to parse.
 * @param iterators     NULL-terminated array of iterator names.
 * @param parameters    NULL-terminated array of parameter names.
 * @param data_is_char  forwarded unchanged to the helper parsers.
 * @return a newly built scoplib matrix, owned by the caller. An
 *         unsupported binary operator, or a division by a non-constant,
 *         prints an error and terminates the program.
 */
scoplib_matrix_p
past_parser(s_past_node_t* expr,
	    char** iterators, char** parameters, int data_is_char)
{
  // Conjunction: the constraints of both operands must hold.
  if (past_node_is_a (expr, past_and))
    {
      PAST_DECLARE_TYPED(binary, conj, expr);
      scoplib_matrix_p mat_lhs =
	past_parser (conj->lhs, iterators, parameters, data_is_char);
      scoplib_matrix_p mat_rhs =
	past_parser (conj->rhs, iterators, parameters, data_is_char);
      scoplib_matrix_p both = scoplib_matrix_concat (mat_lhs, mat_rhs);
      scoplib_matrix_free (mat_lhs);
      scoplib_matrix_free (mat_rhs);
      return both;
    }

  // Arithmetic operators and non-binary nodes are plain linear
  // expressions.
  int is_arith = past_node_is_a (expr, past_add) ||
    past_node_is_a (expr, past_sub) ||
    past_node_is_a (expr, past_mul);
  if (is_arith || ! past_node_is_a (expr, past_binary))
    return parseBinExpr (expr, 1, NULL, 0, 0, iterators, parameters,
			 data_is_char);

  PAST_DECLARE_TYPED(binary, bin, expr);
  s_past_node_t* left = bin->lhs;
  s_past_node_t* right = bin->rhs;

  // a > b   <=>  a - b - 1 >= 0
  if (past_node_is_a (expr, past_gt))
    return parseBinExpr (left, 1, right, -1, -1, iterators, parameters,
			 data_is_char);
  // a < b   <=>  b - a - 1 >= 0
  if (past_node_is_a (expr, past_lt))
    return parseBinExpr (left, -1, right, 1, -1, iterators, parameters,
			 data_is_char);
  // a >= b  <=>  a - b >= 0
  if (past_node_is_a (expr, past_geq))
    return parseBinExpr (left, 1, right, -1, 0, iterators, parameters,
			 data_is_char);
  // a <= b  <=>  b - a >= 0
  if (past_node_is_a (expr, past_leq))
    return parseBinExpr (left, -1, right, 1, 0, iterators, parameters,
			 data_is_char);
  // a == b  <=>  b - a >= 0 and a - b >= 0
  if (past_node_is_a (expr, past_equal))
    {
      scoplib_matrix_p upper =
	parseBinExpr (left, -1, right, 1, 0, iterators, parameters,
		      data_is_char);
      scoplib_matrix_p lower =
	parseBinExpr (left, 1, right, -1, 0, iterators, parameters,
		      data_is_char);
      scoplib_matrix_p both = scoplib_matrix_concat (upper, lower);
      scoplib_matrix_free (upper);
      scoplib_matrix_free (lower);
      return both;
    }
  // a = b   ->  a - b >= 0
  if (past_node_is_a (expr, past_assign))
    return parseBinExpr (left, 1, right, -1, 0, iterators, parameters,
			 data_is_char);
  // Division by a constant k: parse the numerator with a multiplier of k.
  if (past_node_is_a (expr, past_ceild) ||
      past_node_is_a (expr, past_div) ||
      past_node_is_a (expr, past_floord))
    {
      if (! past_node_is_a (right, past_value))
	{
	  fprintf (stderr, "[IR-converter] ERROR: Cannot safely manage ceild/floord node\n");
	  exit (1);
	}
      PAST_DECLARE_TYPED(value, divisor, right);
      return parseBinExpr (left, divisor->value.intval, NULL, 0, 0,
			   iterators, parameters, data_is_char);
    }

  fprintf (stderr, "[IR-converter] ERROR: Unkown binop\n");
  exit (1);
}


/**
 * Parse a demangled binary expression, e.g. a '<' b, into scoplib
 * constraint row(s) of the form
 *
 *   lhsmult * lhs + rhsmult * rhs + offset >= 0
 *
 * To conform to the scoplib representation, a < b is represented as
 * b - a - 1 >= 0: a is given a mult of -1, b a mult of 1, and the offset
 * is -1.
 *
 * Operands are simplified before reaching parseLinearExpression:
 * - if either operand contains a modulo, the whole constraint becomes an
 *   always-true row (over-approximation, with a warning on stdout);
 * - ceild/floord(x, k) with a constant k: the rest of the constraint,
 *   offset included, is multiplied by k and the division is replaced by
 *   x. For example, i > ceild(x, k), i.e. i - ceild(x, k) - 1 >= 0,
 *   becomes k*i - x - k >= 0;
 * - min/max(x, y): one constraint per operand, concatenated;
 * - x + c or x - c with a constant c: c is folded into the offset.
 *
 * @param lhs, rhs          operands; rhs may be NULL (no right operand).
 * @param lhsmult, rhsmult  multipliers applied to each operand.
 * @param offset            constant term of the constraint.
 * @return a newly built matrix, owned by the caller. A ceild/floord
 *         whose divisor is not a constant prints an error and terminates
 *         the program.
 */
static
scoplib_matrix_p
parseBinExpr (s_past_node_t* lhs, int lhsmult,
	      s_past_node_t* rhs, int rhsmult,
	      int offset, char** iterators, char** parameters,
	      int data_is_char)
{
  // Over-approximation: ignore conditionals with modulo inside
  // (over-approximated to always true).
  if (past_count_nodetype (lhs, past_mod) > 0 ||
      past_count_nodetype (rhs, past_mod) > 0)
    {
      printf ("[IrConverter][WARNING] %% expression unexpected. Ingoring it...\n");
      return parseLinearExpression (NULL, 1, NULL, 1, 0,
				    iterators, parameters, data_is_char);
    }

  // ceild/floord(x, k) on the lhs:
  //   lhsmult * div(x, k) + rhsmult * rhs + offset >= 0
  //   ==> lhsmult * x + k * rhsmult * rhs + k * offset >= 0
  // The offset must be scaled as well: ceild(x, k) < n means
  // x <= k * (n - 1), not x <= k * n - 1. Assumes k > 0, since a
  // non-positive divisor would flip or collapse the inequality.
  if (past_node_is_a (lhs, past_ceild) || past_node_is_a (lhs, past_floord))
    {
      PAST_DECLARE_TYPED(binary, pb, lhs);
      if (! past_node_is_a (pb->rhs, past_value))
	{
	  // Fail explicitly: an assert alone would let an NDEBUG build
	  // fall off the end of this non-void function.
	  fprintf (stderr, "[IR-converter] ERROR: Cannot safely manage ceild/floord node\n");
	  exit (1);
	}
      PAST_DECLARE_TYPED(value, pv, pb->rhs);
      int k = pv->value.intval;
      // Recurse, rather than parsing linearly right away, so that rhs
      // may itself be a ceild/floord/min/max.
      return parseBinExpr (pb->lhs, lhsmult, rhs, rhsmult * k, offset * k,
			   iterators, parameters, data_is_char);
    }

  // ceild/floord(x, k) on the rhs: symmetric to the lhs case.
  if (past_node_is_a (rhs, past_ceild) || past_node_is_a (rhs, past_floord))
    {
      PAST_DECLARE_TYPED(binary, pb, rhs);
      if (! past_node_is_a (pb->rhs, past_value))
	{
	  fprintf (stderr, "[IR-converter] ERROR: Cannot safely manage ceild/floord node\n");
	  exit (1);
	}
      PAST_DECLARE_TYPED(value, pv, pb->rhs);
      int k = pv->value.intval;
      return parseBinExpr (lhs, lhsmult * k, pb->lhs, rhsmult, offset * k,
			   iterators, parameters, data_is_char);
    }

  // min/max on the lhs: one constraint per operand.
  if (past_node_is_a (lhs, past_min) || past_node_is_a (lhs, past_max))
    {
      PAST_DECLARE_TYPED(binary, pb, lhs);
      scoplib_matrix_p mat1 = parseBinExpr (pb->lhs, lhsmult,
					    rhs, rhsmult,
					    offset,
					    iterators, parameters,
					    data_is_char);
      scoplib_matrix_p mat2 = parseBinExpr (pb->rhs, lhsmult,
					    rhs, rhsmult,
					    offset,
					    iterators, parameters,
					    data_is_char);
      scoplib_matrix_p res = scoplib_matrix_concat (mat1, mat2);
      scoplib_matrix_free (mat1);
      scoplib_matrix_free (mat2);
      return res;
    }

  // min/max on the rhs: one constraint per operand.
  if (past_node_is_a (rhs, past_min) || past_node_is_a (rhs, past_max))
    {
      PAST_DECLARE_TYPED(binary, pb, rhs);
      scoplib_matrix_p mat1 = parseBinExpr (lhs, lhsmult,
					    pb->lhs, rhsmult,
					    offset,
					    iterators, parameters,
					    data_is_char);
      scoplib_matrix_p mat2 = parseBinExpr (lhs, lhsmult,
					    pb->rhs, rhsmult,
					    offset,
					    iterators, parameters,
					    data_is_char);
      scoplib_matrix_p res = scoplib_matrix_concat (mat1, mat2);
      scoplib_matrix_free (mat1);
      scoplib_matrix_free (mat2);
      return res;
    }

  // Fold the trailing constant of "x + c" / "x - c" into the offset, so
  // that x may itself be a ceild/floord/min/max handled above.
  if (past_node_is_a (lhs, past_add) || past_node_is_a (lhs, past_sub))
    {
      PAST_DECLARE_TYPED(binary, pb, lhs);
      if (past_node_is_a (pb->rhs, past_value))
	{
	  PAST_DECLARE_TYPED(value, pv, pb->rhs);
	  int val = pv->value.intval * lhsmult;
	  if (past_node_is_a (lhs, past_sub))
	    val = -val;
	  return parseBinExpr (pb->lhs, lhsmult, rhs, rhsmult,
			       offset + val, iterators,
			       parameters, data_is_char);
	}
    }
  // Also tried when lhs is an add/sub without a constant right operand.
  if (past_node_is_a (rhs, past_add) || past_node_is_a (rhs, past_sub))
    {
      PAST_DECLARE_TYPED(binary, pb, rhs);
      if (past_node_is_a (pb->rhs, past_value))
	{
	  PAST_DECLARE_TYPED(value, pv, pb->rhs);
	  int val = pv->value.intval * rhsmult;
	  if (past_node_is_a (rhs, past_sub))
	    val = -val;
	  return parseBinExpr (lhs, lhsmult, pb->lhs, rhsmult,
			       offset + val, iterators,
			       parameters, data_is_char);
	}
    }

  // Default fall-back: both operands are plain affine expressions.
  return parseLinearExpression (lhs, lhsmult, rhs, rhsmult, offset,
				iterators, parameters, data_is_char);
}


/**
 * Parse the linear affine expression lhsmult * llhs + rhsmult * lrhs +
 * offset. Either operand may be NULL, in which case it contributes
 * nothing.
 *
 * The result is a newly allocated 1-row scoplib matrix with
 * nb_iterators + nb_parameters + 2 columns. Column 0 is set to 1
 * (presumably the scoplib inequality marker, "... >= 0"). Columns
 * 1..nb_iterators hold the iterator coefficients, followed by the
 * parameter coefficients, and the last column holds the constant term.
 * data_is_char is not used here.
 */
static
scoplib_matrix_p
parseLinearExpression (s_past_node_t* llhs,
		       int lhsmult,
		       s_past_node_t* lrhs,
		       int rhsmult,
		       int offset,
		       char** iterators,
		       char** parameters,
		       int data_is_char)
{
  // Both symbol lists are NULL-terminated.
  int n_iters = 0;
  while (iterators[n_iters] != NULL)
    n_iters++;
  int n_params = 0;
  while (parameters[n_params] != NULL)
    n_params++;
  assert (n_iters > 0);

  int const_col = n_iters + n_params + 1;
  scoplib_matrix_p row = scoplib_matrix_malloc (1, const_col + 1);
  SCOPVAL_set_si(row->p[0][0], 1);

  convertLinearExpression (llhs, lhsmult, iterators, parameters, row);
  convertLinearExpression (lrhs, rhsmult, iterators, parameters, row);

  if (offset == 0)
    return row;

  assert (const_col < row->NbColumns);
  scoplib_int_t cst;
  SCOPVAL_init(cst);
  SCOPVAL_set_si(cst, offset);
  SCOPVAL_addto(row->p[0][const_col], row->p[0][const_col], cst);
  SCOPVAL_clear(cst);

  return row;
}


/**
 * Internal helper: evaluate the expression e with the variable named symb
 * set to 1 and every other variable set to 0.
 *
 * Called with symb == NULL, this yields the constant term of e.
 * convertLinearExpression subtracts that constant from the value obtained
 * for a given symbol to get that symbol's coefficient.
 *
 * Variables whose symbol holds character data are matched with strcmp;
 * other variables are matched by pointer identity. +, -, * combine the
 * operand values. ceild/floord divide them with SCOPVAL_divto, and a
 * zero divisor prints a warning and yields 0. Any other binary operator
 * is reported on stdout and triggers assert(0).
 *
 * A unary node evaluates to the value of its operand: the operator
 * itself is ignored. NOTE(review): a unary minus would therefore not be
 * negated; confirm that no such node reaches this function.
 *
 * Other node kinds keep the value set by SCOPVAL_init.
 */
static
scoplib_int_t
computeCoefficient(s_past_node_t* e,
		   char* symb)
{
  scoplib_int_t res;
  SCOPVAL_init(res);

  if (past_node_is_a (e, past_value))
    {
      PAST_DECLARE_TYPED(value, pval, e);
      SCOPVAL_set_si(res, pval->value.intval);
      return res;
    }

  if (past_node_is_a (e, past_variable))
    {
      PAST_DECLARE_TYPED(variable, pvar, e);
      int is_symb;
      if (pvar->symbol->is_char_data)
	is_symb = symb != NULL && strcmp (pvar->symbol->data, symb) == 0;
      else
	is_symb = symb != NULL && pvar->symbol->data == symb;
      SCOPVAL_set_si(res, is_symb ? 1 : 0);
      return res;
    }

  if (past_node_is_a (e, past_binary))
    {
      PAST_DECLARE_TYPED(binary, pbin, e);
      scoplib_int_t left = computeCoefficient (pbin->lhs, symb);
      scoplib_int_t right = computeCoefficient (pbin->rhs, symb);
      int is_div = past_node_is_a (e, past_ceild) ||
	past_node_is_a (e, past_floord);

      if (past_node_is_a (e, past_add))
	SCOPVAL_addto(res, left, right);
      else if (past_node_is_a (e, past_sub))
	SCOPVAL_subtract(res, left, right);
      else if (past_node_is_a (e, past_mul))
	SCOPVAL_multo(res, left, right);
      else if (is_div && SCOPVAL_notzero_p(right))
	SCOPVAL_divto(res, left, right);
      else if (is_div)
	{
	  printf ("[IrConverter][WARNING] Divide by 0 in parser, replace by 0\n");
	  SCOPVAL_set_si(res, 0);
	}
      else
	{
	  printf ("[IrConverter][ERROR] Unsupported node: ");
	  past_pprint (stdout, e);
	  printf ("\n");
	  assert(0);
	}

      SCOPVAL_clear(left);
      SCOPVAL_clear(right);
      return res;
    }

  if (past_node_is_a (e, past_unary))
    {
      PAST_DECLARE_TYPED(unary, pun, e);
      scoplib_int_t operand = computeCoefficient (pun->expr, symb);
      SCOPVAL_assign(res, operand);
    }

  return res;
}


/**
 * Helper: accumulate mult * e into row 0 of matrix.
 *
 * The constant term of e is added to the last column. The coefficient of
 * iterators[i] is added to column 1 + i, and the coefficient of
 * parameters[j] to column 1 + nb_iterators + j (both lists are
 * NULL-terminated).
 *
 * Each coefficient is computed as computeCoefficient(e, symbol) minus
 * computeCoefficient(e, NULL). This recovers the true coefficient only
 * when e is affine. A NULL e contributes nothing.
 *
 * @return always 1.
 */
static
int
convertLinearExpression(s_past_node_t* e,
			int mult,
			char** iterators,
			char** parameters,
			scoplib_matrix_p matrix)
{
  if (e == NULL)
    return 1;
  assert (iterators);
  assert (parameters);

  scoplib_int_t factor;
  SCOPVAL_init(factor);
  SCOPVAL_set_si(factor, mult);
  scoplib_int_t term;
  SCOPVAL_init(term);

  // Constant part of e goes into the last column.
  scoplib_int_t cst;
  SCOPVAL_init(cst);
  cst = computeCoefficient (e, NULL);
  SCOPVAL_multo(term, cst, factor);
  assert (matrix->NbRows > 0);
  SCOPVAL_addto(matrix->p[0][matrix->NbColumns - 1],
		matrix->p[0][matrix->NbColumns - 1], term);

  // Iterators, then parameters, occupy consecutive columns starting at 1.
  char** symbol_lists[2] = { iterators, parameters };
  int col = 1;
  int list;
  for (list = 0; list < 2; ++list)
    {
      char** names = symbol_lists[list];
      int k;
      for (k = 0; names[k] != NULL; ++k, ++col)
	{
	  assert (col < matrix->NbColumns);
	  term = computeCoefficient (e, names[k]);
	  SCOPVAL_subtract(term, term, cst);
	  SCOPVAL_multo(term, term, factor);
	  SCOPVAL_addto(matrix->p[0][col], matrix->p[0][col], term);
	}
    }

  SCOPVAL_clear(term);
  SCOPVAL_clear(factor);
  SCOPVAL_clear(cst);

  return 1;
}


