/*
 * vectorize.c: This file is part of the PAST-vectorizer project.
 *
 * Pvectorizer: a library to increase SIMD capability of PAST trees.
 *
 * Copyright (C) 2011 Louis-Noel Pouchet
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 *
 */
#if HAVE_CONFIG_H
# include <pvectorizer/config.h>
#endif

#include <assert.h>

#include <pvectorizer/common.h>
#include <pvectorizer/vectorize.h>

#include <candl/candl.h>
#include <irconverter/past2scop.h>

// One entry of the NULL-terminated table built by
// traverse_tree_index_for: either a loop (real or fake) or a statement
// of the tree.
typedef struct s_process_data
{
  // Tree node the entry refers to; NULL marks the end of the table.
  s_past_node_t*	node;
  // For loop entries: number of loop entries recorded before this one
  // (negated for fake loops). For statement entries: number of
  // statement entries recorded before this one.
  int			id;
  // Set to 1 by vectorize_loop_nest once the loop has been examined.
  int			is_processed;
  // 1 for loop entries (real or fake), 0 for statement entries.
  int			is_loop;
} s_process_data_t;

// Visitor callback filling the table passed through 'data' (an array
// of s_process_data_t whose unused slots have a NULL node).
//
// - Each for loop gets a loop entry whose id is the number of loop
//   entries already present in the table.
// - Each cloog statement gets a statement entry whose id is the number
//   of statement entries already present. When the statement has no
//   enclosing for loop, a fake loop entry (id = minus the number of
//   loop entries) is inserted right before it.
static
void traverse_tree_index_for (s_past_node_t* node, void* data)
{
  s_process_data_t* table = (s_process_data_t*) data;

  if (past_node_is_a (node, past_for))
    {
      // Find the first free slot, counting the loops seen on the way.
      int slot = 0;
      int nb_loops = 0;
      while (table[slot].node != NULL)
        {
          if (table[slot].is_loop)
            ++nb_loops;
          ++slot;
        }
      table[slot].node = node;
      table[slot].id = nb_loops;
      table[slot].is_processed = 0;
      table[slot].is_loop = 1;
    }
  if (past_node_is_a (node, past_cloogstmt))
    {
      // Special case: statements not surrounded by any loop in the
      // tree that are surrounded by a fake loop in the scop
      // representation.
      s_past_node_t* enclosing = node->parent;
      while (enclosing && !past_node_is_a (enclosing, past_for))
        enclosing = enclosing->parent;

      // Find the first free slot, counting loops and statements.
      int slot = 0;
      int nb_loops = 0;
      int nb_stmts = 0;
      while (table[slot].node != NULL)
        {
          if (table[slot].is_loop)
            ++nb_loops;
          else
            ++nb_stmts;
          ++slot;
        }

      if (enclosing == NULL)
        {
          // No surrounding loop: record the fake loop first.
          table[slot].node = node;
          table[slot].id = -nb_loops;
          table[slot].is_processed = 0;
          table[slot].is_loop = 1;
          ++slot;
        }
      table[slot].node = node;
      table[slot].id = nb_stmts;
      table[slot].is_processed = 0;
      table[slot].is_loop = 0;
    }
}

// Return 1 if the parallel for loop 'n' has its type field set to
// e_past_tile_loop, 0 otherwise. 'n' must be a past_parfor node.
static
int is_tile_loop (s_past_node_t* n)
{
  assert (past_node_is_a (n, past_parfor));
  PAST_DECLARE_TYPED(parfor, loop, n);
  if (loop->type == e_past_tile_loop)
    return 1;
  return 0;
}

// Check whether the loop nest rooted at 'n' is perfectly nested.
//
// Starting from the body of 'n', the function descends through nested
// for loops, blocks and affine guards (then-clause only) as long as
// the current node has no sibling.
//
// Returns 1 when the descent reaches:
//  - a node with siblings for which past_count_for_loops reports no
//    loop, or
//  - a single node that is neither a for loop, a block nor an affine
//    guard.
// Returns 0 otherwise: 'n' is not a for loop, an empty body is
// reached, or loops are found at a level that has several nodes.
//
// The 'next' pointer of 'n' is cleared during the check and restored
// before returning.
static
int is_perfectly_nested (s_past_node_t* n)
{
  int result = 0;
  s_past_node_t* next = n->next;
  n->next = NULL;
  if (past_node_is_a (n, past_for))
    {
      PAST_DECLARE_TYPED(for, pf, n);
      s_past_node_t* body = pf->body;
      while (body)
        {
          if (body->next)
            {
              // Several nodes at this level: the nest is perfect only
              // if no loop is reported for them.
              result = past_count_for_loops (body) ? 0 : 1;
              break;
            }
          if (past_node_is_a (body, past_for))
            {
              PAST_DECLARE_TYPED(for, inner, body);
              body = inner->body;
            }
          else if (past_node_is_a (body, past_block))
            {
              // Use the block type here, as sink_loop_inner_most does:
              // a block must not be accessed through the for-loop
              // structure.
              PAST_DECLARE_TYPED(block, pb, body);
              body = pb->body;
            }
          else if (past_node_is_a (body, past_affineguard))
            {
              PAST_DECLARE_TYPED(affineguard, pa, body);
              body = pa->then_clause;
            }
          else
            {
              // Single non-loop, non-container node: inner-most body.
              result = 1;
              break;
            }
        }
    }

  n->next = next;
  return result;
}

// Visitor callback: 'data' is an array of two pointers. Every variable
// node whose symbol equals the first pointer gets its symbol replaced
// by the second one.
static
void traverse_change_symbol (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_variable))
    return;

  void** symbols = (void**) data;
  PAST_DECLARE_TYPED(variable, var, node);
  if (var->symbol == symbols[0])
    var->symbol = symbols[1];
}

// Move the parallel loop 'node' to the inner-most position of the loop
// nest it roots.
//
// The nest is followed through for loops and blocks; the last for loop
// met on that chain becomes the new parent of 'node', the body of that
// loop becomes the body of 'node', and the former body of 'node' takes
// the place of 'node' in the tree (followed by the former siblings of
// 'node').
//
// Nothing is done when 'node' is already the inner-most loop, or when
// 'keep_outer_par' is set, the nest counts exactly two for loops and
// the inner-most one is not a parallel loop.
//
// 'node' must be a past_parfor node whose body is a single node.
static
void sink_loop_inner_most (s_past_node_t* node,
                           int keep_outer_par)
{
  assert (past_node_is_a (node, past_parfor));
  PAST_DECLARE_TYPED(parfor, pf, node);

  // Find the inner-most loop. Blocks are traversed but never recorded
  // in 'f', so that 'f' always designates a for loop even when the
  // inner-most loop body is wrapped in a block.
  s_past_node_t* f = node;
  s_past_node_t* cur = node;
  do
    {
      if (past_node_is_a (cur, past_for))
        {
          PAST_DECLARE_TYPED(for, pf2, cur);
          f = cur;
          cur = pf2->body;
        }
      else if (past_node_is_a (cur, past_block))
        {
          PAST_DECLARE_TYPED(block, pb, cur);
          cur = pb->body;
        }
      else
        break;
    }
  while (1);

  // Already in good position.
  if (f == node)
    return;
  assert (past_node_is_a (f, past_for));

  // If we are to keep the outer parallel loop in a 2d-loop nest whose
  // other loop is sequential, skip the sinking.
  if (keep_outer_par)
    if (past_count_for_loops (node) == 2 && !past_node_is_a (f, past_parfor))
      return;

  // Sink the loop 'node' after loop 'f'.
  s_past_node_t* old_parent = node->parent;

  assert (pf->body->next == NULL);
  // Chain the former siblings of 'node' after the last node of its
  // body, then put that body where 'node' was.
  s_past_node_t* tmp;
  for (tmp = pf->body; tmp && tmp->next; tmp = tmp->next)
    ;
  tmp->next = node->next;
  node->next = NULL;
  if (pf->body)
    past_replace_node (node, pf->body);

  // Exchange bodies: 'node' takes over the body of 'f' and becomes
  // the sole body of 'f'.
  PAST_DECLARE_TYPED(for, pfs, f);
  pf->body = pfs->body;
  pfs->body = node;
  node->next = NULL;
  past_set_parent (old_parent);
}


// Select, in the perfectly nested loop nest rooted at 'node', the
// parallel loop to use as candidate vector loop and sink it to the
// inner-most position.
//
// 'prog' is the NULL-terminated table built by traverse_tree_index_for;
// it gives the loop/statement ids used to look statements up in
// 'cprogram'. 'keep_outer_par' is forwarded to sink_loop_inner_most.
//
// For each parallel loop of the nest two counters are computed over the
// read and written access matrices of the statements of the nest:
//  - metric_1: references to the loop iterator that are not in the
//    fastest varying dimension;
//  - metric_2: references to the loop iterator that are in the fastest
//    varying dimension.
// Loops that are not parallel, or with metric_2 == 0, get metric_1 = -1
// (discarded). The first loop with the lowest non-discarded metric_1 is
// sunk; when every loop is discarded the first loop of the nest
// ('node') is passed to sink_loop_inner_most.
//
// Every parallel loop examined has is_processed set to 1 in 'prog'.
static
void vectorize_loop_nest (s_past_node_t* node, s_process_data_t* prog,
                          CandlProgram* cprogram, int keep_outer_par)
{
  int i, j, k, l, m, n;
  // 1- Collect all loops of the nest (candidate vector loops) and the
  // nodes forming the inner-most body.
  int num_for_loops = past_count_for_loops (node);
  int num_stmts = past_count_statements (node);
  // Oversize the data structure, to deal with fake iterators.
  s_past_node_t* inloops[num_for_loops + 2 * num_stmts + 1];
  s_past_node_t* stmts[num_stmts + 1];
  int count_l = 0;
  int count_s = 0;
  PAST_DECLARE_TYPED(for, pf, node);
  s_past_node_t* body = pf->body;

  inloops[count_l++] = node;
  while (body)
    {
      if (past_node_is_a (body, past_for))
        {
          PAST_DECLARE_TYPED(for, inner, body);
          inloops[count_l++] = body;
          body = inner->body;
        }
      else if (past_node_is_a (body, past_block))
        {
          // Use the block type here, as sink_loop_inner_most does: a
          // block must not be accessed through the for-loop structure.
          PAST_DECLARE_TYPED(block, pb, body);
          body = pb->body;
        }
      else
        {
          // Inner-most body reached: record the node and its siblings.
          while (body)
            {
              stmts[count_s++] = body;
              body = body->next;
            }
        }
    }
  inloops[count_l] = NULL;
  stmts[count_s] = NULL;

  int metric_1[count_l + 1];
  int metric_2[count_l + 1];
  for (i = 0; i <= count_l; ++i)
    metric_1[i] = metric_2[i] = 0;

  // 2- Iterate on all loops.
  for (l = 0; inloops[l]; ++l)
    {
      for (j = 0; prog[j].node && prog[j].node != inloops[l]; ++j)
        ;
      // Skip loops missing from the table (the search stopped on the
      // NULL sentinel) and non-parallel loops.
      if (! prog[j].node || ! past_node_is_a (prog[j].node, past_parfor))
        {
          metric_1[l] = -1;
          continue;
        }
      int loop_id = prog[j].id;
      s_process_data_t* loopdata = &(prog[j]);
      int loop_pos = -1;

      // Iterate on all statements.
      for (i = 0; stmts[i]; ++i)
        for (j = 0; prog[j].node; ++j)
          if (prog[j].node == stmts[i] && prog[j].is_loop == 0)
            {
              CandlStatement* stmt = cprogram->statement[prog[j].id];
              // Get the loop pos index. It is computed on the first
              // matching statement only and reused for the others.
              if (loop_pos == -1)
                for (loop_pos = 0; loop_pos < stmt->depth &&
                       stmt->index[loop_pos] != loop_id; ++loop_pos)
                  ;

              // 1- Metric (a): count the number of ref to the loop iterator
              // that is not in the fastest varying dimension.
              // 2- Metric (b): count the number of ref to the loop iterator
              // that is in the fastest varying dimension.
              // NOTE(review): rows whose first column is 0 are treated as
              // further dimensions of the access started above, and the
              // iterator coefficient is read at column loop_pos + 1 --
              // presumably the Candl access matrix layout; confirm.
              CandlMatrix* mat;
              for (n = 0, mat = stmt->read; n < 2; ++n, mat = stmt->written)
                for (k = 0; k < mat->NbRows; ++k)
                  {
                    // fvd - 1 is the last row before the next row whose
                    // first column is non-zero (or the last row).
                    int fvd = k + 1;
                    while (fvd < mat->NbRows &&
                           CANDL_get_si(mat->p[fvd][0]) == 0)
                      ++fvd;
                    if (fvd > 0 && CANDL_get_si(mat->p[fvd - 1][0]) == 0)
                      {
                        for (m = k; m < fvd - 1; ++m)
                          if (CANDL_get_si(mat->p[m][loop_pos + 1]) != 0)
                            metric_1[l]++;
                        if (CANDL_get_si(mat->p[fvd - 1][loop_pos + 1]) != 0)
                          metric_2[l]++;
                      }
                  }
            }
      // Discard loops which are not referenced (ie, kind-a otl/tile loops)
      if (metric_2[l] == 0)
        metric_1[l] = -1;

      // Mark the loop as processed.
      loopdata->is_processed = 1;
    }

  // 3- Move the first found loop with lowest metric inward.
  int min_metric = -1;
  for (l = 0; inloops[l]; ++l)
    if (metric_1[l] != -1)
      {
        if (min_metric == -1)
          min_metric = metric_1[l];
        else
          min_metric = min_metric < metric_1[l] ? min_metric : metric_1[l];
      }
  // A match always exists: either some loop holds the minimum, or all
  // loops are discarded and min_metric is -1, which matches index 0.
  for (l = 0; inloops[l]; ++l)
    if (metric_1[l] == min_metric)
      break;

  sink_loop_inner_most (inloops[l], keep_outer_par);
}


// Visitor callback used by is_triangular_used. 'data' is an array of
// two pointers: the symbol to look for, and the address of an int flag.
//
// The flag is set to 1 when 'node' is a variable node holding that
// symbol, unless a cloog statement or an affine guard is met when
// walking up from the variable before reaching a for loop.
static
void traverse_is_triangular_used (s_past_node_t* node, void* data)
{
  if (! past_node_is_a (node, past_variable))
    return;

  // Eliminate statements.
  // Eliminate inner-most affine guards.
  s_past_node_t* up = node->parent;
  while (up)
    {
      if (past_node_is_a (up, past_cloogstmt) ||
          past_node_is_a (up, past_affineguard))
        return;
      if (past_node_is_a (up, past_for))
        break;
      up = up->parent;
    }

  void** args = (void**) data;
  PAST_DECLARE_TYPED(variable, var, node);
  if (var->symbol != args[0])
    return;

  int* flag = args[1];
  *flag = 1;
}

// Return 1 if the iterator symbol of the for loop 'node' is reported
// as used by traverse_is_triangular_used when visiting the first node
// of the loop body, 0 otherwise.
//
// The first body node is detached from its siblings during the visit
// and re-attached afterwards.
static
int is_triangular_used (s_past_node_t* node)
{
  assert (past_node_is_a (node, past_for));
  PAST_DECLARE_TYPED(for, loop, node);

  // Traverse the body, ensure the loop iterator is not used in loop
  // bound or conditional.
  int found = 0;
  void* visitor_args[2];
  visitor_args[0] = loop->iterator->symbol;
  visitor_args[1] = &found;

  s_past_node_t* saved_next = loop->body->next;
  loop->body->next = NULL;
  past_visitor (loop->body, traverse_is_triangular_used,
                (void*) visitor_args, NULL, NULL);
  loop->body->next = saved_next;

  return found;
}



// Entry point of the vectorizer: for every parallel, non-tile,
// perfectly nested loop nest of the tree 'root' whose iterator is not
// reported as used by is_triangular_used, call vectorize_loop_nest.
//
// 'program' is handed to past2scop_control_only together with 'root' to
// build the scop from which the Candl program is derived.
// 'keep_outer_par' is forwarded to vectorize_loop_nest.
//
// The scop and Candl program built here are released before returning.
void
pvectorizer_vectorize (scoplib_scop_p program, s_past_node_t* root,
                       int keep_outer_par)
{
  // 1- Extract all point loop nests, perfectly nested
  int num_for_loops = past_count_for_loops (root);
  int num_stmts = past_count_statements (root);
  // Oversize the data structure, to deal with fake iterators.
  int num_entries = num_for_loops + 2 * num_stmts;
  s_process_data_t prog[num_entries + 1];
  int i;
  // Clear every slot, including the extra trailing one: it is the
  // sentinel that ends the table scans when all other slots are used.
  for (i = 0; i <= num_entries; ++i)
    prog[i].node = NULL;
  past_visitor (root, traverse_tree_index_for, (void*)prog, NULL, NULL);

  // 2- Get the associated scop.
  scoplib_scop_p scop =
    past2scop_control_only (root, program, 1);
  CandlProgram* cprogram = candl_program_convert_scop (scop, NULL);

  // 3- Iterate on all loops.
  for (i = 0; prog[i].node; ++i)
    {
      // Skip statements.
      if (! prog[i].is_loop)
        continue;

      // Skip already processed loops.
      if (prog[i].is_processed)
        continue;

      // Skip fake loops.
      if (prog[i].id < 0)
        continue;

      // Skip non-parallel loops.
      if (! past_node_is_a (prog[i].node, past_parfor))
        continue;

      // Skip tile loops.
      if (is_tile_loop (prog[i].node))
        continue;

      // Skip non-perfectly nested loop nests.
      if (! is_perfectly_nested (prog[i].node))
        continue;

      // Skip triangular loops.
      if (is_triangular_used (prog[i].node))
        continue;

      // Process the loop nest.
      vectorize_loop_nest (prog[i].node, prog, cprogram, keep_outer_par);
    }

  // Be clean.
  candl_program_free (cprogram);
  scoplib_scop_shallow_free (scop);
}
