/*
 * pocc_driver.c: This file is part of the PTile project.
 *
 * PTile: A PAST-to-PAST parametric tiling software
 *
 * Copyright (C) 2011 Sanket Tavargeri
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 *
 */
#if HAVE_CONFIG_H
# include <ptile/config.h>
#endif

#include <ptile/common.h>
#include <ptile/pocc_driver.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <candl/candl.h>
#include <candl/ddv.h>
#include <scoptools/past2scop.h>
#include <past/past_api.h>
#include <irconverter/past2scop.h>
#include <ptile/PTile.hpp>

#include <past/pprint.h>
#include <scoplib/scop.h>


/**
 * Description of one tilable loop nest extracted from the program.
 */
typedef struct s_subscop
{
  /* Root node of the loop nest in the PAST tree. */
  s_past_node_t* root;
  /* ScopLib representation associated with that nest. */
  scoplib_scop_p scop;
  /* Non-zero when the nest has been flagged as parallel. */
  int is_parallel;
} s_subscop_t;

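/**
 * past_macro_stmt
 *
 * A PAST node type local to this file: a node header followed by a
 * pointer to the list of nodes forming its body.
 */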

typedef struct past_macro_stmt_t
{
  /* Generic node header; kept first so the struct can be cast to a node. */
  s_past_node_t node;
  /* First node of the list wrapped by this macro statement. */
  s_past_node_t* body;
} s_past_macro_stmt_t;
PAST_DECLARE_NODE_IN_HIERARCHY_HEADER(macro_stmt);


/**
 * Release a macro statement node: deep-free its body, then free the
 * node itself. The node must be a past_macro_stmt.
 */
static void
past_macro_stmt_free (s_past_node_t* node)
{
  s_past_macro_stmt_t* macro;

  assert (past_node_is_a (node, past_macro_stmt));
  macro = (s_past_macro_stmt_t*) node;
  past_deep_free (macro->body);
  XFREE(macro);
}
/**
 * Visitor hook for macro statement nodes: forwards the prefix/suffix
 * callbacks (and their data) to the body of the macro statement.
 */
static void
past_macro_stmt_visitor (s_past_node_t* node,
                         past_fun_t prefix,
                         void* prefix_data,
                         past_fun_t suffix,
                         void* suffix_data)
{
  assert (past_node_is_a (node, past_macro_stmt));
  PAST_DECLARE_TYPED(macro_stmt, macro, node);
  past_visitor (macro->body, prefix, prefix_data, suffix, suffix_data);
}
PAST_DECLARE_NODE_IN_HIERARCHY_UNIT_1(macro_stmt);

/**
 * Allocate a macro statement node wrapping the node list 'body'.
 *
 * Every node of the 'body' list (following 'next' links) gets the new
 * node as parent. The new node has no parent, no sibling and no
 * meta-information attached.
 *
 * @param body First node of the list to wrap (may be NULL).
 * @return The newly allocated macro statement.
 */
s_past_macro_stmt_t* past_macro_stmt_create (s_past_node_t* body)
{
  s_past_macro_stmt_t* n = XMALLOC(s_past_macro_stmt_t, 1);
  s_past_node_t* child;

  n->node.type = past_macro_stmt;
  n->node.visitor = past_macro_stmt_visitor;
  n->node.parent = NULL;
  n->node.next = NULL;
  // Other code in this file tests node->metainfo against NULL, so do
  // not leave it uninitialized.
  n->node.metainfo = NULL;
  n->body = body;

  // Re-parent all siblings of the body list.
  for (child = body; child; child = child->next)
    child->parent = &(n->node);

  return n;
}

/**
 * Same as past_macro_stmt_create, but returns the result as a generic
 * node pointer (the address of the embedded node header).
 */
s_past_node_t* past_node_macro_stmt_create (s_past_node_t* body)
{
  s_past_macro_stmt_t* macro = past_macro_stmt_create (body);

  return &(macro->node);
}


/**
 * Return 1 if some statement of 'cprogram' has exactly two entries of
 * its loop index vector equal to 'l1' or 'l2', 0 otherwise.
 */
static
int
loops_are_nested (CandlProgram* cprogram, int l1, int l2)
{
  int s;

  for (s = 0; s < cprogram->nb_statements; ++s)
    {
      int hits = 0;
      int d = 0;

      while (d < cprogram->statement[s]->depth)
        {
          int loop_id = cprogram->statement[s]->index[d];

          if (loop_id == l1 || loop_id == l2)
            hits = hits + 1;
          ++d;
        }
      if (hits == 2)
        return 1;
    }

  return 0;
}


/**
 * One entry of the loop table filled by traverse_tree_index_for.
 */
typedef struct s_process_data
{
  /* Node recorded in this slot; NULL marks an unused slot. */
  s_past_node_t* fornode;
  /* Position of the entry in the table. */
  int forid;
  /* Non-zero when the node is an outer for loop. */
  int is_outer;
  /* Parallelism flag. */
  int is_parallel;
} s_process_data_t;

/**
 * Visitor callback: append the visited node to the loop table passed in
 * 'data' (an array of s_process_data_t whose unused slots have a NULL
 * 'fornode').
 *
 * Two kinds of nodes are recorded:
 *  - for loops, with 'is_outer' computed by past_is_outer_for_loop;
 *  - statements that have no for loop among their ancestors. They are
 *    surrounded by a fake loop in the scop representation, and are
 *    recorded with 'is_outer' set to 0 so that the 'loop' is skipped.
 */
static
void traverse_tree_index_for (s_past_node_t* node, void* data)
{
  s_process_data_t* table = (s_process_data_t*) data;
  int slot = 0;
  int outer_flag;

  if (past_node_is_a (node, past_for))
    outer_flag = past_is_outer_for_loop (node);
  else if (past_node_is_a (node, past_cloogstmt))
    {
      // Look for an enclosing for loop.
      s_past_node_t* ancestor = node->parent;

      while (ancestor && !past_node_is_a (ancestor, past_for))
        ancestor = ancestor->parent;
      if (ancestor)
        return;
      outer_flag = 0;
    }
  else
    return;

  // First free slot of the table.
  while (table[slot].fornode != NULL)
    ++slot;
  table[slot].fornode = node;
  table[slot].forid = slot;
  table[slot].is_outer = outer_flag;
  table[slot].is_parallel = 0;
}


#ifndef max
/* Largest of two values. Arguments and result are parenthesized so that
   operators in the arguments cannot re-associate with the comparison.
   Each argument may be evaluated twice: avoid side effects. */
# define max(a,b) ((a) < (b) ? (b) : (a))
#endif

/**
 * Build the list of loop iterators surrounding a statement located at
 * the maximal loop depth under 'root'.
 *
 * The list is collected while walking up the 'parent' links, so the
 * first entry is the innermost iterator. The walk stops at 'root', and
 * never stores more than 'maxdepth' entries so the array cannot be
 * overrun.
 *
 * @param root Root of the tree to inspect.
 * @return A NULL-terminated array allocated with XMALLOC. The array is
 *         owned by the caller; the strings it points to are the symbol
 *         data of the iterators and are not copied.
 */
static
char** compute_iterator_list (s_past_node_t* root)
{
  int maxdepth = past_max_loop_depth (root);
  s_past_node_t* node = past_find_statement_at_depth (root, maxdepth);
  char** ret = XMALLOC(char*, maxdepth + 1);

  int pos = 0;
  for (; node && pos < maxdepth; node = node->parent)
    {
      if (past_node_is_a(node, past_for))
	{
	  PAST_DECLARE_TYPED(for, pf, node);
	  ret[pos++] = pf->iterator->symbol->data;
	}
      // Never climb above the root, whatever its node type is.
      if (node == root)
	break;
    }
  ret[pos] = NULL;

  return ret;
}


/**
 * Create a one-time loop (OTL) around 'body'.
 *
 * The loop iterates on 'iter' from 'bound' to 'bound' (init is
 * "iter = bound", test is "iter <= bound", increment is "iter++"), where
 * 'bound' is:
 *  - 0 when there is neither a previous nor a next loop;
 *  - the right-hand side of the init of 'nextfor', minus 1, when
 *    'nextfor' is a for loop;
 *  - otherwise the right-hand side of the test of 'prevfor', plus 1,
 *    when 'prevfor' is a for loop;
 *  - 0 (with a warning) in the remaining cases.
 *
 * The new loop is typed e_past_otl_loop, inherits the parent of 'body',
 * and its metainfo is set to a heap-allocated "otl" string.
 *
 * @param body    Node list to embed. Must not be NULL.
 * @param refloop Unused.
 * @param iter    Name of the iterator of the new loop.
 * @param prevfor Sibling preceding 'body', or NULL.
 * @param nextfor Sibling following 'body', or NULL.
 * @return The new for node.
 */
static
s_past_node_t*
create_embedding_loop (s_past_node_t* body, s_past_for_t* refloop,
		       const char* iter, s_past_node_t* prevfor,
		       s_past_node_t* nextfor)
{
  s_symbol_t* itersymb = symbol_add_from_char (NULL, iter);
  s_past_node_t* bound;

  if (prevfor == NULL && nextfor == NULL)
    bound = past_node_value_create_from_int (0);
  else if (nextfor != NULL && past_node_is_a (nextfor, past_for))
    {
      PAST_DECLARE_TYPED(for, pf, nextfor);
      PAST_DECLARE_TYPED(binary, pb, pf->init);
      bound = past_clone (pb->rhs);
      bound = past_node_binary_create
	(past_sub, bound, past_node_value_create_from_int (1));
    }
  else if (prevfor != NULL && past_node_is_a (prevfor, past_for))
    {
      PAST_DECLARE_TYPED(for, pf, prevfor);
      PAST_DECLARE_TYPED(binary, pb, pf->test);
      bound = past_clone (pb->rhs);
      bound = past_node_binary_create
	(past_add, bound, past_node_value_create_from_int (1));
    }
  else
    {
      printf ("[PTile][WARNING] Unable to get a good OTL bound\n");
      bound = past_node_value_create_from_int (0);
    }

  // The init and test expressions each need their own copy of the
  // bound: a node must not be attached to two parents.
  s_past_node_t* test_bound = past_clone (bound);

  s_past_node_t* init =
    past_node_binary_create (past_assign,
			     past_node_variable_create (itersymb),
			     bound);
  itersymb = symbol_add_from_char (NULL, iter);
  s_past_node_t* test =
    past_node_binary_create (past_leq,
			     past_node_variable_create (itersymb),
			     test_bound);
  itersymb = symbol_add_from_char (NULL, iter);
  s_past_node_t* increment =
    past_node_unary_create (past_inc_after,
			    past_node_variable_create (itersymb));
  itersymb = symbol_add_from_char (NULL, iter);
  s_past_variable_t* iterator = past_variable_create (itersymb);

  // Remember the parent of the body before it is re-parented.
  s_past_node_t* parent = body->parent;
  s_past_node_t* newfor =
    past_node_for_create (init, test, iterator,
			  increment, body, NULL);
  PAST_DECLARE_TYPED(for, pf, newfor);
  pf->type = e_past_otl_loop;
  newfor->parent = parent;

  // Set the metainfo to character string "otl". This information is
  // required for further processing.
  char *otl = (char*) malloc (sizeof (char) * (strlen ("otl") + 1));
  if (otl != NULL)
    strcpy (otl, "otl");
  else
    printf ("[PTile][WARNING] Unable to allocate the OTL tag\n");
  newfor->metainfo = otl;

  return newfor;
}

/**
 * Count the for loops met when climbing from 'node' (excluded) up to
 * 'top' (included), following the 'parent' links. The climb also stops
 * when the root of the tree is reached.
 */
static
int
past_local_loop_depth (s_past_node_t* node, s_past_node_t* top)
{
  s_past_node_t* cursor = node;
  int nb_loops = 0;

  for (;;)
    {
      if (cursor == NULL || cursor == top)
        break;
      cursor = cursor->parent;
      if (past_node_is_a (cursor, past_for))
        nb_loops = nb_loops + 1;
    }

  return nb_loops;
}

/**
 * Count, among the loops returned by past_outer_loops for 'node', those
 * whose iterator is 'iter'. Iterators stored as character strings are
 * compared by content, the others by address. Returns 0 for a NULL
 * node.
 */
static
int local_past_count_loops_with_iter (s_past_node_t* node, const char* iter)
{
  s_past_node_t** loops;
  int matches = 0;
  int k = 0;

  if (node == NULL)
    return 0;

  loops = past_outer_loops (node);
  while (loops && loops[k])
    {
      PAST_DECLARE_TYPED(for, loop, loops[k]);
      int same;

      if (loop->iterator->symbol->is_char_data)
        same = ! strcmp (iter, loop->iterator->symbol->data);
      else
        same = (iter == loop->iterator->symbol->data);
      if (same)
        ++matches;
      ++k;
    }
  XFREE(loops);

  return matches;
}

/**
 * Visitor callback used by create_uniform_embedding.
 *
 * 'data' is an array of pointers:
 *  - [0] the top node of the nest being processed;
 *  - [1] a pointer to the maximal loop depth (int);
 *  - [2] the NULL-terminated iterator list.
 *
 * Only for nodes are processed. The iterator of the loop is searched in
 * the list; if it is found at any position but the first, the entry
 * preceding it in the list is the iterator looked for in the children of
 * the loop body. Runs of consecutive siblings for which
 * local_past_count_loops_with_iter returns 0 on that iterator are then
 * wrapped in a new loop built by create_embedding_loop.
 */
static
void traverse_create_uniform_embedding (s_past_node_t* node, void* data)
{
  if (past_node_is_a (node, past_for))
    {
      PAST_DECLARE_TYPED(for, pf, node);
      s_past_node_t* top = ((void**)data)[0];
      int* maxdepth = ((void**)data)[1];
      char** iterators = ((void**)data)[2];
      // NOTE(review): 'bodyiter' is never used.
      int bodyiter = 0;
      char* loopiter = pf->iterator->symbol->data;
      // Locate the iterator of this loop in the list.
      while (*iterators && strcmp (*iterators, loopiter))
	++iterators;
      // Step back to the preceding entry of the list. Nothing to do if
      // the iterator was not found, or was found in first position.
      if (*iterators && (iterators != ((void**)data)[2]))
	--iterators;
      else
	return;
      s_past_node_t* cur;
      int num_loops = 0;
      int num_siblings = 0;
      // Count the children of the body, and the loops on the selected
      // iterator reported for them.
      for (cur = pf->body; cur; cur = cur->next, ++num_siblings)
	num_loops += local_past_count_loops_with_iter (cur, *iterators);
      int local_depth = past_local_loop_depth (pf->body, top);
      // Embed when the loop count is non-zero but differs from the number
      // of children, or when it is zero and the body is shallower than
      // the maximal depth.
      if ((num_loops && num_loops != num_siblings) ||
	  (! num_loops && (local_depth < *maxdepth )))
	{
	  s_past_node_t* prevfor = NULL;
	  s_past_node_t* nextfor = NULL;
	  int depth;
	  cur = pf->body;
	  s_past_node_t* curnext;
	  do
	    {
	      // 'depth' counts the conditionals entered during this round.
	      depth = 0;
	      // 'curnext' is only read below when depth != 0, which implies
	      // that 'cur' was not NULL here.
	      if (cur)
		curnext = cur->next;
	      while (cur)
		{
		  // Dig into conditionals.
		  while (past_node_is_a (cur, past_affineguard))
		    {
		      PAST_DECLARE_TYPED(affineguard, pa, cur);
		      cur = pa->then_clause;
		      ++depth;
		    }
		  if (! local_past_count_loops_with_iter (cur, *iterators))
		    {
		      s_past_node_t** addr = past_node_get_addr (cur);
		      if (addr)
			{
			  // Extend the run to all following siblings that
			  // also have no loop on the selected iterator.
			  s_past_node_t* next = cur->next;
			  s_past_node_t* prev = cur;
			  while (next &&
				 ! local_past_count_loops_with_iter
				 (next, *iterators))
			    {
			      prev = next;
			      next = next->next;
			    }
			  nextfor = next;
			  // Detach the run [cur..prev] from the sibling list,
			  // wrap it in a new loop stored in place of 'cur',
			  // and link the new loop to the rest of the list.
			  prev->next = NULL;
			  // NOTE(review): 'parent' is never used.
			  s_past_node_t* parent = cur->parent;
			  *addr = create_embedding_loop (cur, pf, *iterators,
							 prevfor, nextfor);
			  (*addr)->next = next;
			  cur = next;
			  prevfor = *addr;
			}
		      else
			{
			  prevfor = cur;
			  cur = cur->next;
			}
		    }
		  else
		    {
		      prevfor = cur;
		      cur = cur->next;
		    }
		}
	      // A conditional was entered: resume from the sibling saved at
	      // the beginning of the round.
	      if (depth)
		cur = curnext;
	    }
	  while (depth != 0);
	}
    }
}

/**
 * Run traverse_create_uniform_embedding on the tree rooted at 'node'.
 *
 * The visitor receives the root of the nest, a pointer to its maximal
 * loop depth, and the iterator list computed by compute_iterator_list.
 * The iterator list is released once the traversal is over (only the
 * array is freed: the strings belong to the symbols).
 */
static
void
create_uniform_embedding (s_past_node_t* node)
{
  // 1- Compute the maximal loop depth and the iterator list.
  int max_depth = past_max_loop_depth (node);
  char** iterators = compute_iterator_list (node);

  // 2- Pack the arguments of the visitor.
  void* data[4];
  data[0] = node;
  data[1] = &max_depth;
  data[2] = iterators;
  data[3] = NULL;

  // 3- Embed.
  past_visitor (node, traverse_create_uniform_embedding, (void*)data,
		NULL, NULL);

  // 4- Be clean.
  XFREE(iterators);
}

/**
 * Visitor callback: set the loop type of every for node to
 * e_past_point_loop. 'data' is not used.
 */
static void
traverse_mark_point_loop (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_for))
    return;

  PAST_DECLARE_TYPED(for, loop, node);
  loop->type = e_past_point_loop;
}


/**
 * Extract the tilable loop nests of the tree 'root'.
 *
 * An outer for loop is selected when no later loop of the table is both
 * nested with it and not permutable with it (as reported by
 * loops_are_nested and candl_loops_are_permutable), and when its nest
 * contains more than one for loop. For each selected nest:
 *  - 'is_parallel' is set when no dependence is carried by the outer
 *    loop;
 *  - all its loops are marked as point loops and embedding loops are
 *    inserted by create_uniform_embedding;
 *  - a scop is rebuilt with past2scop_control_only, after temporarily
 *    re-creating the affine guards found above the nest.
 *
 * @param program      ScopLib representation of the program.
 * @param root         Root of the PAST tree.
 * @param data_is_char Forwarded to past2scop_control_only.
 * @return An array allocated with XMALLOC, terminated by an entry whose
 *         'root' is NULL. The array is owned by the caller.
 */
static
s_subscop_t*
pocc_create_tilable_nests (scoplib_scop_p program,
			   s_past_node_t* root,
			   int data_is_char)
{
  // Create SCoP corresponding to the transformed code.
  scoplib_scop_p control_scop = program;
  CandlOptions* coptions = candl_options_malloc ();
  CandlProgram* cprogram = candl_program_convert_scop (control_scop, NULL);
  CandlDependence* cdeps = candl_dependence (cprogram, coptions);
  int num_for_loops = past_count_for_loops (root);
  int num_stmts = past_count_statements (root);
  // Oversize the data structures, to deal with fake iterators. One extra
  // slot is kept as a NULL sentinel: the tables are scanned until a NULL
  // entry is found, even when all loops and statements are recorded.
  int num_slots = num_for_loops + num_stmts + 1;
  s_subscop_t* ret = XMALLOC(s_subscop_t, num_slots);
  s_process_data_t prog_loops[num_slots];
  int i, j;
  for (i = 0; i < num_slots; ++i)
    prog_loops[i].fornode = NULL;
  past_visitor (root, traverse_tree_index_for, (void*)prog_loops, NULL, NULL);

  // Recompute the number of actual for loops.
  for (num_for_loops = 0; prog_loops[num_for_loops].fornode; ++num_for_loops)
    ;
  int partid = 0;
  for (i = 0; i < num_for_loops; ++i)
    {
      if (prog_loops[i].is_outer)
	{
	  // Stop at the first loop nested with loop i that is not
	  // permutable with it.
	  for (j = i + 1; j < num_for_loops; ++j)
	    if (loops_are_nested (cprogram, i, j) &&
		! candl_loops_are_permutable (cprogram, cdeps, i, j))
		break;
	  if (j == num_for_loops)
	    {
	      // Ensure there is more than one loop in the nest.
	      if (past_count_for_loops (prog_loops[i].fornode) <= 1)
	      	continue;

	      // Check if the outer loop is parallel.
	      CandlDependence* d;
	      for (d = cdeps; d; d = d->next)
		if (candl_dependence_is_loop_carried (cprogram, d, i))
		  break;
	      ret[partid].is_parallel = (d == NULL);

	      // All loops in the nest are permutable. Process it.
	      ret[partid].root = prog_loops[i].fornode;

	      // Do 'otl' on the loop nest. The nest is temporarily
	      // detached from its next sibling.
	      s_past_node_t* next = ret[partid].root->next;
	      ret[partid].root->next = NULL;
	      // All loops in the nest become point loops.
	      past_visitor (ret[partid].root, traverse_mark_point_loop, NULL,
			    NULL, NULL);

	      create_uniform_embedding (ret[partid].root);


	      // Recompute the scoplib representation, in case new
	      // loops (OTL) have been inserted.

	      // 1- Insert the surrounding guards, PTile needs it.
	      /// FIXME: temporary.
	      s_past_node_t* parent = ret[partid].root->parent;
	      s_past_node_t* temproot = ret[partid].root->parent;
	      s_past_node_t* curnode = ret[partid].root;
	      while (past_node_is_a (temproot, past_affineguard))
		{
		  PAST_DECLARE_TYPED(affineguard, pa, temproot);
		  s_past_node_t* cond =
		    past_node_affineguard_create (pa->condition,
						  curnode);
		  curnode = cond;
		  temproot = temproot->parent;
		}
	      // 2- Get the new scop
	      ret[partid].scop =
	      	past2scop_control_only (curnode, program, data_is_char);
	      // 3- Restore: free the temporary guard nodes only (their
	      // condition and then-clause belong to the original tree).
	      if (curnode != ret[partid].root)
		while (past_node_is_a (curnode, past_affineguard))
		  {
		    PAST_DECLARE_TYPED(affineguard, pa, curnode);
		    s_past_node_t* then_clause = pa->then_clause;
		    XFREE(pa);
		    curnode = then_clause;
		  }
	      ret[partid].root->next = next;
	      ret[partid].root->parent = parent;

	      ++partid;
	    }
	}
    }
  ret[partid].root = NULL;
  if (partid == 0)
    {
      printf ("[PTile][Warning] There is no fully permutable loop nest in the program\n");
    }
  // Be clean.
  candl_dependence_free (cdeps);
  candl_program_free (cprogram);
  candl_options_free (coptions);

  return ret;
}

/**
 * Placeholder: currently does nothing with 'root'.
 */
void
pocc_expand_macro_stmt (s_past_node_t* root)
{
  (void) root;
}


/**
 * Backup record for one statement: the statement node, together with
 * the statement number and substitution list it had before they were
 * overwritten.
 */
typedef struct s_tuple
{
  /* Statement node; NULL marks an unused record. */
  s_past_node_t* stmt;
  /* Saved statement number. */
  int old_number;
  /* Saved substitution list. */
  s_past_node_t* subs;
} s_tuple_t;

/**
 * Visitor callback: save and renumber the visited statement.
 *
 * 'args' is an array of two pointers: [0] the backup table (unused
 * records have a NULL 'stmt'), [1] the statement counter (int).
 *
 * The statement number and substitution list of each cloogstmt node are
 * saved in the first unused record. The substitution list is then
 * replaced by one variable per surrounding for loop (outermost first),
 * and the statement number by the incremented counter.
 */
static
void traverse_backup_stmt_number(s_past_node_t* node, void* args)
{
  if (! past_node_is_a (node, past_cloogstmt))
    return;

  PAST_DECLARE_TYPED(cloogstmt, stmt, node);
  s_tuple_t* table = ((void**)args)[0];
  int* counter = ((void**)args)[1];
  int slot = 0;

  // Save the current state in the first free record.
  while (table[slot].stmt != NULL)
    ++slot;
  table[slot].stmt = node;
  table[slot].old_number = stmt->stmt_number;
  table[slot].subs = stmt->substitutions;

  // Collect surrounding loops. Each new variable is pushed in front of
  // the list, so the outermost iterator ends up first.
  s_past_node_t* chain = NULL;
  s_past_node_t* ancestor;
  for (ancestor = node->parent; ancestor; ancestor = ancestor->parent)
    {
      if (! past_node_is_a (ancestor, past_for))
        continue;

      PAST_DECLARE_TYPED(for, loop, ancestor);
      s_symbol_t* symb;
      if (loop->iterator->symbol->is_char_data)
        symb = symbol_add_from_char (NULL, loop->iterator->symbol->data);
      else
        symb = symbol_add_from_data (NULL, loop->iterator->symbol->data);
      s_past_node_t* var = past_node_variable_create (symb);
      var->next = chain;
      chain = var;
    }

  stmt->substitutions = chain;
  *counter = *counter + 1;
  stmt->stmt_number = *counter;
}

/**
 * Visitor callback: attach a freshly allocated 'visited' flag, set to 0,
 * to the metainfo of the node. Any previous metainfo pointer is
 * overwritten. 'args' is not used.
 *
 * The program is terminated if the flag cannot be allocated.
 */
static void traverse_to_mark_unvisited (s_past_node_t* node, void *args)
{
  int *visited = (int*) malloc (sizeof (int));

  (void) args;
  if (visited == NULL)
    {
      fprintf (stderr, "[PTile][ERROR] Memory overflow\n");
      exit (1);
    }
  (*visited) = 0;
  node->metainfo = (void*) visited;
}

/**
 * Visitor callback: restore the statement number and substitution list
 * of a cloogstmt node from the backup table passed in 'args'.
 *
 * A node is restored only while its metainfo is not NULL; the metainfo
 * is freed and reset once the node is restored, so a node reached
 * several times is processed once. The current substitution list is
 * deep-freed before the saved one is put back.
 */
static
void traverse_restore_stmt_number (s_past_node_t* node, void* args)
{
  if (! past_node_is_a (node, past_cloogstmt))
    return;

  PAST_DECLARE_TYPED(cloogstmt, stmt, node);
  s_tuple_t* table = args;
  int slot = 0;

  // Find the record of this statement.
  while (table[slot].stmt != node)
    ++slot;

  if (node->metainfo == NULL)
    return;

  stmt->stmt_number = table[slot].old_number;
  past_deep_free (stmt->substitutions);
  stmt->substitutions = table[slot].subs;

  free ((int*) node->metainfo);
  node->metainfo = NULL;
}

/**
 * Visitor callback: collect the names of the tile size parameters.
 *
 * 'args' is a NULL-terminated array of strings. A variable node is
 * considered when its symbol is not attached to a symbol table, holds
 * character data, and its name starts with 'T'. The name is stored in
 * the first NULL slot of the array unless it is already present. The
 * string is not copied.
 */
static
void traverse_get_tile_parameters (s_past_node_t* node, void* args)
{
  if (! past_node_is_a (node, past_variable))
    return;

  PAST_DECLARE_TYPED(variable, var, node);
  if (var->symbol->is_attached_to_table)
    return;
  if (! var->symbol->is_char_data || ! var->symbol->data)
    return;
  if (((char*)var->symbol->data)[0] != 'T')
    return;

  char** tile_params = args;
  int slot = 0;
  // Stop on the first free slot, or on an entry with the same name.
  while (tile_params[slot] &&
         strcmp (tile_params[slot], (char*)var->symbol->data))
    ++slot;
  if (! tile_params[slot])
    tile_params[slot] = (char*)var->symbol->data;
}

/**
 * Main entry point. Takes the scoplib representation of the program,
 * and a PAST tree to parametrically tile.
 *
 * Automatically finds permutable loop nests, and parametrically tiles
 * them. The first parallel loop is marked 'past_parfor'.
 *
 * Steps:
 *  1- back up the statement numbers and substitutions, and renumber the
 *     statements;
 *  2- extract the tilable components with pocc_create_tilable_nests;
 *  3- tile each component with parametricallytile, and plug the result
 *     in the tree in place of the original nest;
 *  4- report the tile size parameters found in the tree;
 *  5- rebuild the symbol table and restore the statement numbers.
 *
 * @param program ScopLib representation of the program.
 * @param root    Root of the PAST tree; modified in place.
 * @param ptopts  Options. 'verbose_level' and 'data_is_char' are read,
 *                'RSFME' is overwritten for each component.
 */
void
ptile_pocc_driver (scoplib_scop_p program,
		   s_past_node_t* root,
		   s_ptile_options_t* ptopts)
{  
  if (ptopts->verbose_level > 2)
    printf ("[PTile] Start parametric tiling\n");

  // Set parent, just in case.
  past_set_parent (root);
  // Backup statement names. All records start unused (NULL statement).
  int nb_statements = past_count_statements (root);
  s_tuple_t* backup_stmts = XMALLOC(s_tuple_t, nb_statements);
  int i;
  for (i = 0; i < nb_statements; ++i)
    backup_stmts[i].stmt = NULL;
  int stmt_id = 0;
  void* args[2];
  args[0] = backup_stmts;
  args[1] = &stmt_id;
  past_visitor (root, traverse_backup_stmt_number, (void*)args, NULL, NULL);
  // Extract tileable components.
  s_subscop_t* tileable_comps =
    pocc_create_tilable_nests (program, root, ptopts->data_is_char);

  // Iterate on all tileable components, parametrically tile them.
  for (i = 0; tileable_comps[i].root; ++i)
    {
      // Remember where the nest is stored in the tree and what follows
      // it, then isolate it from its siblings.
      s_past_node_t** addr = past_node_get_addr (tileable_comps[i].root);
      s_past_node_t* next = tileable_comps[i].root->next;
      tileable_comps[i].root->next = NULL;
      // RSFME is enabled for components that are not parallel; 'nofuse'
      // is its negation.
      ptopts->RSFME = ! tileable_comps[i].is_parallel;
      if (ptopts->verbose_level > 0)
	printf ("[PTile] Tiling component #%d\n", i + 1);
      int nofuse = 1;
      if (ptopts->RSFME)
      	nofuse = 0;
      s_past_node_t* newpast =
	parametricallytile (tileable_comps[i].scop,
			    tileable_comps[i].root, ptopts, nofuse);
      // Set parent.
      past_set_parent (newpast);
      // Insert parallel loop information. With RSFME, the candidate is
      // the first node of the body of the generated outer loop;
      // otherwise it is the generated outer loop itself.
      s_past_node_t* n = newpast;
      if (ptopts->RSFME && past_node_is_a (newpast, past_for))
	{
	  PAST_DECLARE_TYPED(for, pf, newpast);
	  n = pf->body;
	}
      if (past_node_is_a (n, past_for))
	{
	  if (n == newpast)
	    newpast = past_for_to_parfor (n);
	  else
	    {
	      // NOTE(review): the result is not stored back in the body of
	      // the outer loop here; presumably past_for_to_parfor updates
	      // the tree itself -- confirm.
	      s_past_node_t* newtree;
	      s_past_node_t* oldnext = n->next;
	      newtree = past_for_to_parfor (n);
	      newtree->next = oldnext;
	    }
	}
      // Restore the tree.
      newpast->next = next;
      assert (addr);
      *addr = newpast;
      pocc_expand_macro_stmt (newpast);
    }

  // Gather newly created tile parameters. The array is sized for the
  // worst case (one entry per variable node) plus a NULL terminator.
  int nb_var = past_count_nodetype (root, past_variable);
  char** tile_params = XMALLOC(char*, nb_var + 1);
  for (i = 0; i < nb_var + 1; ++i)
    tile_params[i] = NULL;
  past_visitor (root, traverse_get_tile_parameters, tile_params, NULL, NULL);
  if (tile_params[0] &&  ptopts->verbose_level > 0)
    {
      printf ("[PTile] Tile size parameters created: ");
      for (i = 0; tile_params[i]; ++i)
	printf ("%s ", tile_params[i]);
      printf ("\n");
    }
  // Only the array is freed: the names belong to the symbols.
  XFREE(tile_params);

  // Update the symbol table.
  past_rebuild_symbol_table (root);


  // A cloogstmt may appear in more than one location - inside a full
  // tile and in a fallback loop nest. Therefore mark the nodes as
  // 'unvisited' and then, when traversing to restore the statement
  // numbers, process only 'unvisited' nodes and mark them as 'visited'.

  // Mark unvisited. This overwrites the metainfo of every node.
 past_visitor (root, traverse_to_mark_unvisited, NULL, NULL, NULL);

  // Restore the statement number.
  past_visitor (root, traverse_restore_stmt_number, backup_stmts, NULL, NULL);
  XFREE(backup_stmts);

  // Set parent, just in case.
  past_set_parent (root);

  // Be clean.
  for (i = 0; tileable_comps[i].root; ++i)
    scoplib_scop_shallow_free (tileable_comps[i].scop);
  XFREE(tileable_comps);

  if (ptopts->verbose_level > 2)
    printf ("[PTile] all done.\n");
}

