/*
 * past2scop.c: This file is part of the IR-Converter project.
 *
 * IR-Converter: a library to convert PAST to ScopLib
 *
 * Copyright (C) 2011 Louis-Noel Pouchet
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 *
 */
#if HAVE_CONFIG_H
# include <irconverter/config.h>
#endif

#include <irconverter/common.h>
#include <irconverter/past2scop.h>
#include <irconverter/stack.h>
#include <irconverter/past_parser.h>

#include <assert.h>
#include <stdlib.h>
#include <string.h>

/**
 * past_visitor callback: for every past_variable node, push its symbol
 * onto the s_stack_t* pointed to by 'args'. Duplicates are not filtered.
 */
static
void collect_symbols (s_past_node_t* node, void* args)
{
  if (! past_node_is_a (node, past_variable))
    return;

  s_stack_t** symbol_stack = (s_stack_t**) args;
  PAST_DECLARE_TYPED(variable, var, node);
  stack_push (symbol_stack, var->symbol);
}


/**
 * Append to 'param_list' every symbol referenced in 'expr' that is
 * neither already in 'param_list' nor the iterator of a past_for found
 * while walking from expr's parent up to and including 'top' (or up to
 * the root when 'top' is never met).
 */
static
void
collect_parameters (s_stack_t** param_list, s_past_node_t* expr,
		    s_past_node_t* top)
{
  s_stack_t* loop_iters = NULL;
  s_stack_t* expr_syms = NULL;
  s_stack_t* cell;
  s_past_node_t* anc;

  // Iterators of the surrounding loops.
  for (anc = expr->parent; anc != NULL; anc = anc->parent)
    {
      if (past_node_is_a (anc, past_for))
	{
	  PAST_DECLARE_TYPED(for, loop, anc);
	  stack_push (&loop_iters, loop->iterator->symbol);
	}
      if (anc == top)
	break;
    }

  // Every symbol appearing in the expression.
  past_visitor (expr, collect_symbols, (void*)&expr_syms, NULL, NULL);

  // Keep only symbols that are not iterators and not yet recorded.
  for (cell = expr_syms; cell != NULL; cell = cell->next)
    {
      if (stack_contains (&loop_iters, cell->data) ||
	  stack_contains (param_list, cell->data))
	continue;
      stack_push (param_list, cell->data);
    }

  stack_free (&expr_syms);
  stack_free (&loop_iters);
}


/**
 * Duplicate 'elt' with strdup when 'data_is_char' is non-zero; otherwise
 * return the very same pointer (shared, not copied).
 */
static
char* mydup (char* elt, int data_is_char)
{
  return data_is_char ? strdup (elt) : elt;
}



/**
 * Build the iteration-domain constraint matrix of a statement from the
 * control nodes (past_for and past_affineguard) held in 'cstack'.
 *
 * @param cstack        Control stack of the statement, walked from
 *                      *cstack through ->next. Presumably the first cell
 *                      is the innermost node (TODO confirm stack_push
 *                      prepends).
 * @param newscop       Scop under construction; provides the parameter
 *                      names and nb_parameters.
 * @param data_is_char  Forwarded unchanged to past_parser.
 * @param subs          Not used by this function.
 * @return the row-wise concatenation of the constraints parsed from each
 *         loop's init and test expressions and from each guard. Guards
 *         'lhs == v', 'lhs <= v' and 'lhs < v' whose lhs is a past_mod
 *         and whose rhs is a past_value are skipped. When no past_for
 *         encloses the statement, a fake iterator "fk" is used and two
 *         rows pinning it to 0 are appended.
 */
static
scoplib_matrix_p
build_domain (s_stack_t** cstack, scoplib_scop_p newscop, int data_is_char,
	      s_past_node_t* subs)
{
  int has_fake_iter = 0;
  char** parameters = newscop->parameters;
  char fkiter[] = "fk";
  // 1- Compute the number of iterators.
  // Both helper stacks are filled by walking 'cstack', so popping them
  // yields the nodes in the reverse of that walk. Note that 'ctrl'
  // stores the stack cells themselves, hence 'tmp->data' in step 3.
  s_stack_t* tmp;
  s_stack_t* iters = NULL;
  s_stack_t* ctrl = NULL;
  for (tmp = *cstack; tmp; tmp = tmp->next)
    {
      if (past_node_is_a (tmp->data, past_for))
	stack_push (&iters, tmp->data);
      stack_push (&ctrl, tmp);
    }

  // Without any enclosing loop, a fake iterator keeps the domain
  // one-dimensional.
  if (stack_size (&iters) == 0)
    has_fake_iter = 1;
  int nb_iters = stack_size (&iters) + has_fake_iter;
  // NULL-terminated list of iterator names handed to past_parser.
  char* iterators[nb_iters + 1];
  int i;
  for (i = 0; iters; ++i)
    {
      s_past_for_t* fornode = stack_pop (&iters);
      // NOTE(review): a NULL entry would leave iterators[i] unset.
      if (fornode)
	iterators[i] = fornode->iterator->symbol->data;
    }
  if (has_fake_iter)
    iterators[i++] = fkiter;
  iterators[i] = NULL;

  // 2- Start from an empty (NULL) matrix; constraints are appended by
  // concatenation. scoplib_matrix_concat is relied on to accept a NULL
  // first operand and to return a fresh copy, since its inputs are
  // freed right after each call.
  scoplib_matrix_p mat = NULL;
  // 3- Fill in.
  for (tmp = stack_pop (&ctrl); tmp; tmp = stack_pop (&ctrl))
    {
      if (past_node_is_a (tmp->data, past_for))
	{
	  // Constraints from the loop's init and test expressions.
	  PAST_DECLARE_TYPED(for, pf, tmp->data);
	  scoplib_matrix_p matlb = past_parser (pf->init, iterators, parameters,
						data_is_char);
	  scoplib_matrix_p matub = past_parser (pf->test, iterators, parameters,
					       data_is_char);
	  scoplib_matrix_p matf = scoplib_matrix_concat (matlb, matub);
	  scoplib_matrix_free (matlb);
	  scoplib_matrix_free (matub);
	  scoplib_matrix_p mat2 = scoplib_matrix_concat (mat, matf);
	  if (mat)
	    scoplib_matrix_free (mat);
	  scoplib_matrix_free (matf);
	  mat = mat2;
	}
      else if (past_node_is_a (tmp->data, past_affineguard))
	{
	  PAST_DECLARE_TYPED(affineguard, pa, tmp->data);
	  if (past_node_is_a (pa->condition, past_equal))
	    {
	      PAST_DECLARE_TYPED(binary, pb, pa->condition);
	      if (past_node_is_a (pb->lhs, past_mod) &&
		  past_node_is_a (pb->rhs, past_value))
		{
		  // Ignore the 'xxx % y == 0' conditionals
		  continue;
		}
	    }
	  else if (past_node_is_a (pa->condition, past_leq) ||
		   past_node_is_a (pa->condition, past_lt))
	    {
	      PAST_DECLARE_TYPED(binary, pb, pa->condition);
	      if (past_node_is_a (pb->lhs, past_mod) &&
		  past_node_is_a (pb->rhs, past_value))
		{
		  // Ignore the 'xxx % y <= z' conditionals
		  continue;
		}

	    }
	  // Any other guard is parsed directly into constraints.
	  scoplib_matrix_p matif =
	    past_parser (pa->condition, iterators, parameters, data_is_char);
	  scoplib_matrix_p mat2 = scoplib_matrix_concat (mat, matif);
	  scoplib_matrix_free (matif);
	  if (mat)
	    scoplib_matrix_free (mat);
	  mat = mat2;
	}
    }

  if (has_fake_iter)
    {
      // Pin the fake iterator (column 1, the only iterator column) to 0:
      // rows [1, 1, 0..0] and [1, -1, 0..0], i.e. fk >= 0 and -fk >= 0
      // assuming scoplib's encoding where a leading 1 marks '>= 0'.
      int nb_param = newscop->nb_parameters;
      scoplib_matrix_p fkctrl =
  	scoplib_matrix_malloc (2, 3 + nb_param);
      SCOPVAL_set_si(fkctrl->p[0][0], 1);
      SCOPVAL_set_si(fkctrl->p[0][1], 1);
      SCOPVAL_set_si(fkctrl->p[1][0], 1);
      SCOPVAL_set_si(fkctrl->p[1][1], -1);
      scoplib_matrix_p mat2 = scoplib_matrix_concat (mat, fkctrl);
      if (mat)
  	scoplib_matrix_free (mat);
      scoplib_matrix_free (fkctrl);
      mat = mat2;
    }

  return mat;
}

/**
 * Return the num-th statement of 'scop' (statements are numbered from 1
 * in list order), or NULL if the list is shorter than that.
 */
static
scoplib_statement_p
find_statement (scoplib_scop_p scop, int num)
{
  scoplib_statement_p cur = scop->statement;
  int idx = 1;

  while (cur != NULL && idx != num)
    {
      cur = cur->next;
      idx++;
    }

  return cur;
}

/* Associates a PAST node with the rank it received while the tree was
   traversed (see traverse_schedule; ranks start at 1).  */
typedef struct s_tuple
{
  int			id;	/* Traversal rank of 'node'.  */
  s_past_node_t*	node;	/* The ranked node.  */
} s_tuple_t;

/**
 * past_visitor callback used by compute_full_schedule. For each past_for,
 * past_cloogstmt or past_statement node, increment the counter in
 * args[0] and push a new (rank, node) s_tuple_t onto the stack in args[1].
 */
static
void traverse_schedule (s_past_node_t* node, void* args)
{
  int is_ranked = past_node_is_a (node, past_for)
    || past_node_is_a (node, past_cloogstmt)
    || past_node_is_a (node, past_statement);
  if (! is_ranked)
    return;

  void** slots = (void**) args;
  int* counter = slots[0];
  s_stack_t** ranking = slots[1];

  s_tuple_t* entry = XMALLOC(s_tuple_t, 1);
  entry->id = ++(*counter);
  entry->node = node;
  stack_push (ranking, entry);
}

/**
 * Rank every loop and statement node of 'root' through past_visitor,
 * pushing one s_tuple_t per ranked node onto 'schedule'. The first
 * ranked node gets id 1.
 */
static
void compute_full_schedule (s_past_node_t* root, s_stack_t** schedule)
{
  int counter = 0;
  void* payload[2] = { &counter, schedule };

  past_visitor (root, traverse_schedule, (void*) payload, NULL, NULL);
}


/**
 * Return the rank recorded for 'target' in 'full_schedule' (a stack of
 * s_tuple_t). Asserts if the node was never ranked.
 */
static
int
lookup_schedule_id (s_stack_t* full_schedule, s_past_node_t* target)
{
  s_stack_t* cell;
  for (cell = full_schedule; cell; cell = cell->next)
    {
      s_tuple_t* entry = (s_tuple_t*) cell->data;
      if (entry->node == target)
	return entry->id;
    }
  // Loops and cloogstmts are all ranked by compute_full_schedule, so
  // reaching this point means the tree and the schedule disagree.
  assert (! "node missing from the full schedule");
  return 0;
}

/**
 * Build the 2d+1 scattering matrix of the statement 'node'.
 *
 * Even rows hold the beta (constant) components: the ranks, taken from
 * 'full_schedule', of each enclosing past_for followed by the rank of
 * 'node' itself. Odd rows select the loop iterators. Without an enclosing
 * loop, a fake iterator is assumed and its trailing beta component is 0.
 * past2scop_control_only later runs scoplib_scop_normalize_schedule on
 * the resulting scop.
 *
 * @param node           The statement being scheduled.
 * @param full_schedule  Stack of s_tuple_t (rank, node) entries.
 * @param control        Control stack enclosing 'node'; only past_for
 *                       entries are used.
 * @param nb_param       Number of scop parameters.
 * @return a newly allocated (2d+1) x (d + nb_param + 2) matrix, where d
 *         is the number of enclosing loops (1 with the fake iterator).
 */
static
scoplib_matrix_p
compute_schedule (s_past_node_t* node,
		  s_stack_t** full_schedule,
		  s_stack_t** control,
		  int nb_param)
{
  s_stack_t* tmp;
  s_stack_t* loops = NULL;
  for (tmp = *control; tmp; tmp = tmp->next)
    if (past_node_is_a (tmp->data, past_for))
      stack_push (&loops, tmp->data);

  int has_fake_iter = (loops == NULL);
  int beta_size = 1 + stack_size (&loops) + has_fake_iter;
  int beta[beta_size];
  int pos = 0;

  // Beta components of the enclosing loops, then of the statement.
  while (loops)
    {
      s_past_node_t* fornode = stack_pop (&loops);
      beta[pos++] = lookup_schedule_id (*full_schedule, fornode);
    }
  beta[pos++] = lookup_schedule_id (*full_schedule, node);
  // Deal with fake iterator.
  if (has_fake_iter)
    beta[pos++] = 0;
  assert (pos == beta_size);

  scoplib_matrix_p sched = scoplib_matrix_malloc (pos * 2 - 1,
						  pos - 1 + nb_param + 2);
  int i;
  int b = 0;
  for (i = 0; i < sched->NbRows; ++i)
    {
      if (i % 2 == 0)
	{
	  // Even rows: beta component in the constant (last) column.
	  SCOPVAL_set_si (sched->p[i][sched->NbColumns - 1], beta[b++]);
	}
      else
	{
	  // Odd rows: coefficient 1 on iterator i/2.
	  assert (i/2 + 1 < sched->NbColumns);
	  SCOPVAL_set_si (sched->p[i][i/2 + 1], 1);
	}
    }

  return sched;
}

static
void
traverse_scop_pref (s_past_node_t* node, void* args)
{
  if (past_node_is_a (node, past_for) ||
      past_node_is_a (node, past_affineguard))
    {
      s_stack_t** control_stack = ((void**)args)[0];
      stack_push (control_stack, node);
    }
  else if (past_node_is_a (node, past_cloogstmt))
    {
      PAST_DECLARE_TYPED(cloogstmt, pc, node);
      scoplib_statement_p newstmt = scoplib_statement_malloc ();
      // Fill-in domain.
      scoplib_matrix_list_p ml = scoplib_matrix_list_malloc ();
      s_stack_t** control_stack = ((void**)args)[0];
      scoplib_scop_p newscop = ((void**)args)[2];
      scoplib_scop_p oldscop = ((void**)args)[3];
      int data_is_char = (int)(long)(((void**)args)[4]);
      s_stack_t** full_schedule_stack = ((void**)args)[5];
      scoplib_matrix_p domain = build_domain (control_stack, newscop,
      					      data_is_char, pc->substitutions);
      ml->elt = domain;
      newstmt->domain = ml;

      // Fill-in iterator list.
      int nb_param = newscop->nb_parameters;
      newstmt->nb_iterators = domain->NbColumns - nb_param - 2;
      newstmt->iterators = XMALLOC(char*, newstmt->nb_iterators + 1);
      s_stack_t* tmp;
      s_stack_t* iters = NULL;
      for (tmp = *control_stack; tmp; tmp = tmp->next)
	{
	  if (past_node_is_a (tmp->data, past_for))
	    stack_push (&iters, tmp->data);
	}
      int i, j, k;
      for (i = 0; iters; ++i)
	{
	  s_past_for_t* fornode = stack_pop (&iters);
	  if (fornode->iterator->symbol->is_char_data)
	    newstmt->iterators[i] = strdup (fornode->iterator->symbol->data);
	  else
	    newstmt->iterators[i] = fornode->iterator->symbol->data;
	}
      // Deal with the fake iterator.
      int has_fake_iter = 0;
      if (i == 0)
	{
	  newstmt->iterators[i++] = strdup ("fk");
	  has_fake_iter = 1;
	}
      newstmt->iterators[i] = NULL;
      // Body remains NULL.

      scoplib_statement_p oldstmt = find_statement (oldscop, pc->stmt_number);
      assert (oldstmt);

      // Compute substitution matrix.
      s_past_node_t* sub;
      scoplib_matrix_p subs = NULL;
      assert (pc->substitutions);
      for (sub = pc->substitutions; sub; sub = sub->next)
      	{
      	  scoplib_matrix_p mat = past_parser (sub, newstmt->iterators,
      	  				      newscop->parameters,
      	  				      data_is_char);
      	  assert (mat->NbColumns == newstmt->nb_iterators + newscop->nb_parameters + 2);
      	  scoplib_matrix_p tmp = scoplib_matrix_concat (subs, mat);
      	  if (subs != NULL)
	    scoplib_matrix_free (subs);
      	  subs = tmp;
      	}

      // Fill-in read matrix.
      scoplib_matrix_p read = scoplib_matrix_malloc (oldstmt->read->NbRows,
						     domain->NbColumns);
      for (i = 0; i < read->NbRows; ++i)
      	{
      	  scoplib_int_t tmp;
      	  SCOPVAL_init(tmp);
      	  SCOPVAL_assign(read->p[i][0], oldstmt->read->p[i][0]);
	  int offset = 0;
      	  for (j = 0; j < subs->NbRows; ++j)
      	    {
	      // Skip the substitution that corresponds to newly
	      // created iterators (eg, tile iterators).
	      if (j  < oldstmt->nb_iterators &&
		  oldstmt->iterators[j][0] == 'f' &&
		  oldstmt->iterators[j][1] == 'k')
		{
		  ++offset;
		  continue;
		}
	      // Process the substitution.
      	      for (k = 1; k < read->NbColumns; ++k)
      	      	{
      	      	  SCOPVAL_multo(tmp, oldstmt->read->p[i][j + 1 - offset],
      	      	  		subs->p[j][k]);
      	      	  SCOPVAL_addto(read->p[i][k], read->p[i][k], tmp);
      	      	}
	    }
	  // Reinsert the base values.
	  for (j = 0; j < nb_param + 1; ++j)
	    SCOPVAL_addto(read->p[i][read->NbColumns - 1 - j],
			  read->p[i][read->NbColumns - 1 - j],
			  oldstmt->read->p[i][oldstmt->read->NbColumns - 1 -j]);

      	}
      newstmt->read = read;

      // Fill-in written matrix.
      scoplib_matrix_p write = scoplib_matrix_malloc (oldstmt->write->NbRows,
      						      domain->NbColumns);
      for (i = 0; i < write->NbRows; ++i)
      	{
      	  scoplib_int_t tmp;
      	  SCOPVAL_init(tmp);
      	  SCOPVAL_assign(write->p[i][0], oldstmt->write->p[i][0]);
	  int offset = 0;
	  for (j = 0; j < subs->NbRows; ++j)
      	    {
	      // Skip the substitution that corresponds to newly
	      // created iterators (eg, tile iterators).
	      if (j  < oldstmt->nb_iterators &&
		  oldstmt->iterators[j][0] == 'f' &&
		  oldstmt->iterators[j][1] == 'k')
		{
		  ++offset;
		  continue;
		}
	      // Process the substitution.
      	      for (k = 1; k < write->NbColumns; ++k)
      		{
      		  SCOPVAL_multo(tmp, oldstmt->write->p[i][j + 1 - offset],
				subs->p[j][k]);
      		  SCOPVAL_addto(write->p[i][k], write->p[i][k], tmp);
      		}
      	    }
	  // Reinsert the base values.
	  for (j = 0; j < nb_param + 1; ++j)
	      SCOPVAL_addto(write->p[i][write->NbColumns - 1 - j],
			    write->p[i][write->NbColumns - 1 - j],
			    oldstmt->write->p[i]
			    [oldstmt->write->NbColumns - 1 -j]);

      	}
      newstmt->write = write;

      // Compute schedule.
      newstmt->schedule =
      	compute_schedule (node, full_schedule_stack, control_stack, nb_param);

/*       if (data_is_char) */
/* 	newstmt->body = */
/* 	  strdup (find_statement (oldscop, pc->stmt_number)->body); */
/*       else */
/* 	newstmt->body = find_statement (oldscop, pc->stmt_number)->body; */

      // Store the new statement.
      if (newscop->statement == NULL)
	newscop->statement = newstmt;
      else
	{
	  scoplib_statement_p stm = newscop->statement;
	  for (; stm->next; stm = stm->next)
	    ;
	  stm->next = newstmt;
	}

      // Be clean.
      if (subs)
      	scoplib_matrix_free (subs);
    }
}


/**
 * Postfix callback of past_visitor used by past2scop_control_only: pops
 * the control stack (args[0]) when leaving a past_for or
 * past_affineguard node, undoing the push made in traverse_scop_pref.
 */
static
void
traverse_scop_post (s_past_node_t* node, void* args)
{
  int is_control = past_node_is_a (node, past_for)
    || past_node_is_a (node, past_affineguard);
  if (! is_control)
    return;

  s_stack_t** ctrl_stack = ((void**) args)[0];
  stack_pop (ctrl_stack);
}


/**
 * Returns a scoplib that describes the control in the past
 * tree. Statement bodies are not filled, meaning the scoplib cannot
 * be used for code generation.
 *
 * @param root          PAST tree to convert.
 * @param orig_scop     Scop the tree was derived from. Its parameters,
 *                      arrays, context and option tags are copied into the
 *                      result, and its statements (looked up by each
 *                      cloogstmt's stmt_number) provide the access
 *                      matrices that are rewritten.
 * @param data_is_char  Non-zero when names are C strings; parameter and
 *                      array names are then duplicated instead of shared
 *                      with 'orig_scop'.
 * @return a newly allocated scop, or NULL when 'root' is NULL or when
 *         'orig_scop' is NULL (that case is not implemented yet).
 */
scoplib_scop_p
past2scop_control_only (s_past_node_t* root, scoplib_scop_p orig_scop,
			int data_is_char)
{
  if (! root)
    return NULL;
  if (! orig_scop)
    return NULL; // Will be implemented later.

  int i;
  scoplib_scop_p ret = scoplib_scop_malloc ();

  // Copy the scop-level information.
  ret->parameters = XMALLOC(char*, orig_scop->nb_parameters + 1);
  for (i = 0; i < orig_scop->nb_parameters; ++i)
    ret->parameters[i] = mydup (orig_scop->parameters[i], data_is_char);
  ret->parameters[i] = NULL;
  ret->nb_parameters = orig_scop->nb_parameters;

  ret->arrays = XMALLOC(char*, orig_scop->nb_arrays + 1);
  for (i = 0; i < orig_scop->nb_arrays; ++i)
    ret->arrays[i] = mydup (orig_scop->arrays[i], data_is_char);
  ret->arrays[i] = NULL;
  ret->nb_arrays = orig_scop->nb_arrays;

  ret->context = scoplib_matrix_copy (orig_scop->context);
  if (orig_scop->optiontags)
    ret->optiontags = strdup (orig_scop->optiontags);

  // Shared state of the traversal callbacks. Slot 1 is not read by
  // traverse_scop_pref/traverse_scop_post; it is set only for safety.
  s_stack_t* control_stack = NULL;
  s_stack_t* full_schedule_stack = NULL;
  void* args[6];
  args[0] = &control_stack;
  args[1] = NULL;
  args[2] = ret;
  args[3] = orig_scop;
  args[4] = (void*)((long)data_is_char);
  args[5] = &full_schedule_stack;

  // Rank loop and statement nodes, then build one statement per cloogstmt.
  compute_full_schedule (root, &full_schedule_stack);
  past_visitor (root, traverse_scop_pref, args, traverse_scop_post, args);
  scoplib_scop_normalize_schedule (ret);

  // Release the s_tuple_t entries pushed by traverse_schedule together
  // with their stack cells (XMALLOC is assumed to be malloc-based).
  while (full_schedule_stack)
    free (stack_pop (&full_schedule_stack));

  return ret;
}
