/*
 * past_api.c: This file is part of the PAST project.
 *
 * PAST: the PoCC Abstract Syntax Tree
 *
 * Copyright (C) 2011 Louis-Noel Pouchet
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 *
 */
#if HAVE_CONFIG_H
# include <past/config.h>
#endif

#include <assert.h>
#include <past/common.h>
#include <past/past_api.h>




/* Visitor callback: 'args' is {type to match, int* counter}; bumps the
   counter for every visited node that is of the requested type. */
static
void traverse_count_nodetype (s_past_node_t* node, void* args)
{
  void** ctx = (void**)args;
  cs_past_node_type_t* wanted = (cs_past_node_type_t*) ctx[0];
  int* counter = (int*) ctx[1];

  if (! past_node_is_a (node, wanted))
    return;
  *counter += 1;
}

/**
 * Count the nodes of type 'type' in the tree rooted at 'root'.
 *
 * root->next is temporarily cleared during the traversal (and restored
 * afterwards), so siblings of 'root' are not counted. Returns 0 when
 * 'root' is NULL.
 */
int
past_count_nodetype (s_past_node_t* root, cs_past_node_type_t* type)
{
  int total = 0;
  if (root == NULL)
    return total;

  void* ctx[2] = { (void*) type, &total };
  s_past_node_t* saved_next = root->next;
  root->next = NULL;
  past_visitor (root, traverse_count_nodetype, (void*) ctx, NULL, NULL);
  root->next = saved_next;

  return total;
}

/**
 * Return the number of past_for nodes in the tree rooted at 'root'
 * (siblings of 'root' excluded; 0 when 'root' is NULL).
 */
int
past_count_for_loops (s_past_node_t* root)
{
  int loops = past_count_nodetype (root, past_for);
  return loops;
}

/**
 * Return the number of statements in the tree rooted at 'root', i.e. the
 * number of past_cloogstmt nodes plus the number of past_statement nodes.
 */
int
past_count_statements (s_past_node_t* root)
{
  int n_cloog = past_count_nodetype (root, past_cloogstmt);
  int n_plain = past_count_nodetype (root, past_statement);
  return n_cloog + n_plain;
}

/* Count the past_for nodes on the parent path from 'root' up to 'top',
   both ends included. If 'top' is never reached (e.g. NULL), the walk
   continues up to the tree root. Returns 0 when 'root' is NULL or when
   'root' == 'top'. */
static
int
past_loop_depth_sub (s_past_node_t* root, s_past_node_t* top)
{
  if (! root || root == top)
    return 0;

  int depth = 0;
  s_past_node_t* n;
  for (n = root; n; n = n->parent)
    {
      if (past_node_is_a (n, past_for))
	++depth;
      if (n == top)
	break;
    }

  return depth;
}

/* Visitor callback: 'data' is {top node, int* best}. For each statement
   (past_cloogstmt / past_statement), raises *best to the statement's loop
   depth relative to 'top' when that depth is larger. */
static
void traverse_max_loop_depth (s_past_node_t* node, void* data)
{
  int is_stmt = past_node_is_a (node, past_cloogstmt) ||
    past_node_is_a (node, past_statement);
  if (! is_stmt)
    return;

  void** ctx = (void**)data;
  int* best = ctx[1];
  int d = past_loop_depth_sub (node, ctx[0]);
  if (d > *best)
    *best = d;
}

/**
 * Return the maximal loop depth of the statements (past_cloogstmt and
 * past_statement nodes) in the tree rooted at 'root'. Depth counts the for
 * loops between a statement and 'root', 'root' included. Siblings of
 * 'root' are not visited. Returns 0 if 'root' is NULL or holds no
 * statement.
 */
int
past_max_loop_depth (s_past_node_t* root)
{
  int maxdepth = 0;
  /* root->next is dereferenced below: reject NULL like the other
     tree queries of this file do. */
  if (! root)
    return maxdepth;

  void* args[2];
  args[0] = root;
  args[1] = &maxdepth;
  s_past_node_t* next = root->next;
  root->next = NULL;
  past_visitor (root, traverse_max_loop_depth, (void*) args, NULL, NULL);
  root->next = next;

  return maxdepth;
}


/**
 * Return the number of past_for nodes on the path from 'root' up through
 * its parents, 'root' itself included when it is a for loop. Returns 0
 * when 'root' is NULL.
 */
int
past_loop_depth (s_past_node_t* root)
{
  /* No upper bound: walk up to the top of the tree. */
  s_past_node_t* no_top = NULL;
  return past_loop_depth_sub (root, no_top);
}

/* Visitor callback: 'data' is {top node, int* wanted depth, result slot}.
   Stores in the result slot the first visited statement whose loop depth
   relative to 'top' equals the wanted depth; later matches are ignored. */
static
void traverse_find_statement_at_depth (s_past_node_t* node, void* data)
{
  void** ctx = (void**)data;
  if (ctx[2] != NULL)
    return;
  if (! past_node_is_a (node, past_cloogstmt) &&
      ! past_node_is_a (node, past_statement))
    return;

  int* wanted = ctx[1];
  if (past_loop_depth_sub (node, ctx[0]) == *wanted)
    ctx[2] = node;
}

/**
 * Return the first visited statement (past_cloogstmt or past_statement) of
 * the tree rooted at 'root' that is surrounded by exactly 'depth' for
 * loops between it and 'root' ('root' included). Siblings of 'root' are
 * not visited. Returns NULL if there is no such statement or if 'root'
 * is NULL.
 */
s_past_node_t*
past_find_statement_at_depth (s_past_node_t* root, int depth)
{
  /* root->next is dereferenced below. */
  if (! root)
    return NULL;

  void* args[3];
  args[0] = root;
  args[1] = &depth;
  args[2] = NULL;
  s_past_node_t* next = root->next;
  root->next = NULL;
  past_visitor (root, traverse_find_statement_at_depth, (void*)args,
		NULL, NULL);
  root->next = next;

  return args[2];
}

/* Visitor callback: 'data' is {top node, NULL-terminated loop array}.
   A for loop is appended (at the first NULL slot) when no other for loop
   lies on its parent path up to 'top' ('top' included); 'top' itself is
   always appended when it is a for loop. */
static
void traverse_past_outer_loops (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_for))
    return;

  void** ctx = (void**)data;
  s_past_node_t* top = ctx[0];
  s_past_node_t** outer = ctx[1];

  if (node != top)
    {
      s_past_node_t* anc = node->parent;
      while (anc)
	{
	  if (past_node_is_a (anc, past_for))
	    return;
	  if (anc == top)
	    break;
	  anc = anc->parent;
	}
    }

  int slot = 0;
  while (outer[slot])
    ++slot;
  outer[slot] = node;
}

/* Visitor callback: 'data' is {top node, NULL-terminated loop array}.
   Appends every for loop satisfying past_is_inner_for_loop() at the first
   NULL slot of the array. */
static
void traverse_past_inner_loops (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_for) || ! past_is_inner_for_loop (node))
    return;

  s_past_node_t** inner = ((void**)data)[1];
  int slot;
  for (slot = 0; inner[slot]; ++slot)
    ;
  inner[slot] = node;
}

/* Collect the for loops of the tree rooted at 'root' selected by the
   visitor callback 'prefixfun' (which receives {root, array} and appends
   loops at the first NULL slot). Returns a freshly allocated,
   NULL-terminated array owned by the caller; when 'root' is NULL the
   array holds only the terminator. Siblings of 'root' are not visited. */
static
s_past_node_t**
past_innerouter_loops (s_past_node_t* root, past_fun_t prefixfun)
{
  if (root == NULL)
    {
      s_past_node_t** empty = XMALLOC(s_past_node_t*, 1);
      empty[0] = NULL;
      return empty;
    }

  /* Upper bound: every for loop of the subtree could be selected. */
  int capacity = past_count_for_loops (root);
  s_past_node_t** res = XMALLOC(s_past_node_t*, capacity + 1);
  int k;
  for (k = 0; k <= capacity; ++k)
    res[k] = NULL;

  void* ctx[2] = { root, res };
  s_past_node_t* saved_next = root->next;
  root->next = NULL;
  past_visitor (root, prefixfun, (void*)ctx, NULL, NULL);
  root->next = saved_next;

  int used = 0;
  while (res[used])
    ++used;
  if (used >= capacity)
    return res;

  /* Shrink the array to the number of loops actually collected. */
  s_past_node_t** fit = XMALLOC(s_past_node_t*, used + 1);
  for (k = 0; k < used; ++k)
    fit[k] = res[k];
  fit[used] = NULL;
  XFREE(res);

  return fit;
}

/**
 * Return a freshly allocated, NULL-terminated array of the for loops of
 * the tree rooted at 'root' that contain no nested for loop
 * (see past_is_inner_for_loop).
 */
s_past_node_t**
past_inner_loops (s_past_node_t* root)
{
  return past_innerouter_loops (root, traverse_past_inner_loops);
}


/**
 * Return a freshly allocated, NULL-terminated array of the for loops of
 * the tree rooted at 'root' that have no enclosing for loop up to 'root'
 * ('root' included).
 */
s_past_node_t**
past_outer_loops (s_past_node_t* root)
{
  return past_innerouter_loops (root, traverse_past_outer_loops);
}


/* Visitor callback: 'data' is an int* flag, set to 1 when a for loop is
   visited. */
static
void traverse_past_contain_loop (s_past_node_t* node, void* data)
{
  int* found = (int*) data;
  if (past_node_is_a (node, past_for))
    *found = 1;
}

/**
 * Return 1 if the tree rooted at 'root' contains a for loop ('root'
 * included, siblings of 'root' excluded), 0 otherwise or if 'root' is
 * NULL.
 */
int
past_contain_loop (s_past_node_t* root)
{
  int found = 0;
  if (! root)
    return found;

  s_past_node_t* saved_next = root->next;
  root->next = NULL;
  past_visitor (root, traverse_past_contain_loop, &found, NULL, NULL);
  root->next = saved_next;

  return found;
}


/**
 * Return 1 if 'node' is a for loop with no enclosing for loop anywhere
 * above it (following parent links), 0 otherwise.
 */
int
past_is_outer_for_loop (s_past_node_t* node)
{
  if (! past_node_is_a (node, past_for))
    return 0;

  s_past_node_t* anc = node->parent;
  while (anc && ! past_node_is_a (anc, past_for))
    anc = anc->parent;

  return anc == NULL;
}


/**
 * Return 1 if 'node' is a for loop whose body statements contain no for
 * loop, 0 otherwise.
 */
int
past_is_inner_for_loop (s_past_node_t* node)
{
  if (! past_node_is_a (node, past_for))
    return 0;

  PAST_DECLARE_TYPED(for, loop, node);
  s_past_node_t* stmt;
  for (stmt = loop->body; stmt != NULL; stmt = stmt->next)
    if (past_count_for_loops (stmt) != 0)
      return 0;

  return 1;
}


static
void traverse_rebuild_symt (s_past_node_t* node, void* args)
{
  if (past_node_is_a (node, past_variable))
    {
      PAST_DECLARE_TYPED(variable, pv, node);
      s_symbol_table_t* table = args;
      if (pv->symbol->is_char_data)
	{
	  s_symbol_t* sym = symbol_add_from_char (table, pv->symbol->data);
	  if (sym != pv->symbol)
	    {
	      symbol_free (pv->symbol);
	      pv->symbol = sym;
	    }
	}
    }
  else if (past_node_is_a (node, past_for))
    {
      PAST_DECLARE_TYPED(for, pf, node);
      s_symbol_table_t* table = args;
      if (pf->iterator->symbol->is_char_data)
	{
	  s_symbol_t* sym =
	    symbol_add_from_char (table, pf->iterator->symbol->data);
	  if (sym != pf->iterator->symbol)
	    {
	      symbol_free (pf->iterator->symbol);
	      pf->iterator->symbol = sym;
	    }
	}
    }
}
/**
 * If 'node' is a past_root, rebuild its symbol table: allocate one when
 * absent, then register the character-data symbols of all variable nodes
 * and loop iterators under its body (see traverse_rebuild_symt). Does
 * nothing for any other node type.
 */
void
past_rebuild_symbol_table (s_past_node_t* node)
{
  if (past_node_is_a (node, past_root))
    {
      PAST_DECLARE_TYPED(root, pr, node);
      if (! pr->symboltable)
	pr->symboltable = symbol_table_malloc ();
      past_visitor (pr->body, traverse_rebuild_symt, pr->symboltable,
		    NULL, NULL);
    }
}

/**
 * Loop bound hoisting routines. The entry point is
 * past_optimize_loop_bounds(), defined below.
 *
 */
/* Create a fresh symbol named "__loop_bound_K", where K is the number of
   entries already stored in the NULL-terminated array 'symbols'.
   NOTE(review): 'root' is not consulted, so the name is not checked
   against symbols already present in the tree (e.g. from a previous
   hoisting pass). */
static
s_symbol_t* find_available_symbol (s_past_node_t* root,
				   s_past_node_t*** symbols)
{
  (void) root;

  int count = 0;
  while (symbols[count])
    ++count;

  /* Bounded write: the prefix plus any int fits easily in 64 bytes. */
  char buffer[64];
  snprintf (buffer, sizeof (buffer), "__loop_bound_%d", count);

  return symbol_add_from_char (NULL, buffer);
}
/* Visitor callback: 'data' is a NULL-terminated s_symbol_t* array used as
   a set. Each variable's symbol is stored in the first slot that is NULL
   or already holds an equal symbol (symbol_equal). The caller must size
   the array for all variable nodes plus the terminator. */
static
void traverse_collect_symbs (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_variable))
    return;

  PAST_DECLARE_TYPED(variable, pv, node);
  s_symbol_t** slot = data;
  while (*slot && ! symbol_equal (*slot, pv->symbol))
    ++slot;
  *slot = pv->symbol;
}
/* Hoist the bound expression '*loopexpr' of the for loop 'node'.
   A statement 'int __loop_bound_K = <expr>;' is built and '*loopexpr' is
   replaced by a reference to __loop_bound_K. The statement is not inserted
   here: the pair {insert-before node, statement} is appended to the
   NULL-terminated 'insert_at' array for the caller to splice in.
   The insert-before node starts at 'node' and moves up to each enclosing
   for loop whose iterator does not appear in the expression, stopping at
   the first one whose iterator does.
   NOTE(review): the upward walk does not consult 'root' and skips over
   non-for ancestors, so the definition may land outside guards that
   protected the original expression. */
static
void process_hoist_conditional (s_past_node_t* root,
				s_past_node_t*** insert_at,
				s_past_node_t** loopexpr,
				s_past_node_t* node)
{
  // 1- Create a temporary variable and substitute it in the loop.
  s_symbol_t* newvars = find_available_symbol (root, insert_at);
  s_past_variable_t* vard = past_variable_create (newvars);
  s_past_variable_t* type =
    past_variable_create (symbol_add_from_char (NULL, "int"));
  s_past_node_t* decl = past_node_vardecl_create (vard, type);
  s_past_node_t* newass = past_node_statement_create
    (past_node_binary_create (past_assign, decl, *loopexpr));
  // The reference uses its own symbol object carrying the same name.
  s_symbol_t* s2 = symbol_add_from_char (NULL, newvars->data);
  s_past_node_t* var = past_node_variable_create (s2);
  s_past_node_t* lbexpr = *loopexpr;
  *loopexpr = var;

  // 2- Find the best place to put it. Collect the distinct symbols of the
  // expression (at most one per variable node, plus a NULL terminator).
  int nb_syms = past_count_nodetype (lbexpr, past_variable);
  s_symbol_t** syms = XMALLOC(s_symbol_t*, nb_syms + 1);
  int i;
  for (i = 0; i < nb_syms + 1; ++i)
    syms[i] = NULL;
  past_visitor (lbexpr, traverse_collect_symbs, syms, NULL, NULL);

  s_past_node_t* parent;
  s_past_node_t* before = node;
  for (parent = node->parent; parent; parent = parent->parent)
    {
      if (past_node_is_a (parent, past_for))
	{
	  PAST_DECLARE_TYPED(for, pf2, parent);
	  s_symbol_t* ls = pf2->iterator->symbol;
	  for (i = 0; syms[i] && !symbol_equal (syms[i], ls); ++i)
	    ;
	  if (! syms[i])
	    before = parent;
	  else
	    break;
	}
    }
  // The array only borrowed the expression's symbols: free the array
  // itself (it was previously leaked).
  XFREE(syms);

  // 3- Store the insert_before.
  for (i = 0; insert_at[i]; ++i)
    ;
  insert_at[i] = XMALLOC(s_past_node_t*, 2);
  insert_at[i][0] = before;
  insert_at[i][1] = newass;
}
/* Visitor callback: 'data' is an int* flag, set to 1 when the visited node
   is a min, max, ceil, floor, function call, ceild, floord, round, mul,
   add, sub, div or sqrt node. */
static
void traverse_is_complex_bound (s_past_node_t* node, void* data)
{
  cs_past_node_type_t* kinds[] = {
    past_min, past_max, past_ceil, past_floor, past_funcall,
    past_ceild, past_floord, past_round, past_mul, past_add,
    past_sub, past_div, past_sqrt
  };
  int n_kinds = (int) (sizeof (kinds) / sizeof (kinds[0]));
  int k;

  for (k = 0; k < n_kinds; ++k)
    if (past_node_is_a (node, kinds[k]))
      {
	*((int*) data) = 1;
	return;
      }
}
/* Return 1 if the expression tree 'node' contains at least one node kind
   accepted by traverse_is_complex_bound, 0 otherwise. */
static
int is_complex_bound (s_past_node_t* node)
{
  int flag = 0;
  past_visitor (node, traverse_is_complex_bound, (void*) &flag, NULL, NULL);
  return flag != 0;
}
/* Visitor callback of past_optimize_loop_bounds; 'data' is
   {root, insert_at}. For a for loop, hoists:
   - the lower bound (rhs of the init), and
   - the upper bound (rhs of the test), only when the test's lhs is the
     loop iterator itself,
   provided the bound is neither a variable nor a value and
   is_complex_bound() accepts it (min/max/ceil/floor/ceild/floord/round/
   sqrt, function calls, or +, -, *, /). The init and test nodes are
   treated as binary nodes. */
static
void traverse_optimize_loop_bound (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_for))
    return;

  void** ctx = (void**)data;
  PAST_DECLARE_TYPED(for, loop, node);

  if (loop->init)
    {
      PAST_DECLARE_TYPED(binary, lower, loop->init);
      int is_atom = past_node_is_a (lower->rhs, past_variable) ||
	past_node_is_a (lower->rhs, past_value);
      if (! is_atom && is_complex_bound (lower->rhs))
	process_hoist_conditional (ctx[0], ctx[1], &(lower->rhs), node);
    }

  if (! loop->test)
    return;
  PAST_DECLARE_TYPED(binary, cond, loop->test);
  if (! past_node_is_a (cond->lhs, past_variable))
    return;
  PAST_DECLARE_TYPED(variable, itv, cond->lhs);
  if (! symbol_equal (itv->symbol, loop->iterator->symbol))
    return;
  if (past_node_is_a (cond->rhs, past_variable) ||
      past_node_is_a (cond->rhs, past_value))
    return;
  if (is_complex_bound (cond->rhs))
    process_hoist_conditional (ctx[0], ctx[1], &(cond->rhs), node);
}
/**
 * Optimize 'for' loop bounds by hoisting them as much as possible.
 *
 * Complex lower/upper bounds (see traverse_optimize_loop_bound) are
 * replaced by fresh '__loop_bound_K' variables whose 'int' definitions are
 * inserted before the outermost enclosing loop they do not depend on.
 * Does nothing when 'root' is NULL.
 */
void
past_optimize_loop_bounds (s_past_node_t* root)
{
  if (! root)
    return;

  // Each for loop hoists at most two bounds (init and test), plus one
  // NULL terminator. Unlike the other queries of this file, the visitor
  // below runs without detaching root->next, so count the loops of
  // root's siblings too; otherwise the buffer could be undersized.
  int num_for_loops = 0;
  s_past_node_t* n;
  for (n = root; n; n = n->next)
    num_for_loops += past_count_for_loops (n);
  int capacity = 2 * num_for_loops + 1;
  s_past_node_t*** insert_at = XMALLOC(s_past_node_t**, capacity);
  int i;
  for (i = 0; i < capacity; ++i)
    insert_at[i] = NULL;
  void* args[2]; args[0] = root; args[1] = insert_at;

  // process_hoist_conditional walks parent links to pick the insertion
  // point: make sure they are up to date before visiting.
  past_set_parent (root);
  past_visitor (root, traverse_optimize_loop_bound, (void*)args, NULL, NULL);

  // Bounds were replaced by new variable nodes: refresh parent links
  // before splicing the hoisted statements in.
  past_set_parent (root);
  for (i = 0; insert_at[i]; ++i)
    {
      insert_at[i][1]->next = insert_at[i][0];
      past_replace_node (insert_at[i][0], insert_at[i][1]);
      XFREE(insert_at[i]);
    }
  XFREE(insert_at);
}
