/*
 * transform.c: this file is part of the Vectorizer project.
 *
 * Vectorizer, a vectorization module.
 *
 * Copyright (C) 2010 Louis-Noel Pouchet
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 3
 * of the License, or (at your option) any later version.
 *
 * The complete GNU Lesser General Public Licence Notice can be found
 * as the `COPYING.LESSER' file in the root directory.
 *
 * Author:
 * Louis-Noel Pouchet <pouchet@cse.ohio-state.edu>
 */
#if HAVE_CONFIG_H
# include <vectorizer/config.h>
#endif

#include <vectorizer/common.h>
#include <vectorizer/transform.h>
#include <candl/options.h>
#include <candl/violation.h>
#include <candl/candl.h>
#include <vectorizer/list.h>
#include <clasttools/clastext.h>
#include <clasttools/utils.h>

// We manually set the maximum loop depth to 128.
#define VECTORIZER_MAX_LOOP_DEPTH 128

/**
 * Bookkeeping record attached to one parfor loop of the clast while
 * deciding which parallel loop to sink.
 */
struct vector_metrics
{
  // The parfor node this record describes.
  struct clast_parfor* node;
  // Nesting depth of the loop, as given to vector_metrics_alloc
  // (0 for an outer-most loop).
  int cur_level;
  // Heap-allocated copy of the leading loop_nest counters, filled in
  // by vector_metrics_alloc.
  int* group;
  // Set to 1 when the loop has been selected for sinking, 0 otherwise.
  int to_sink;
  // Not computed anywhere in this file.
  int tripcount;
  // Written by compute_reuse_distance (which currently stores 0).
  int reuse_distance;
  // Written by compute_alignment (which currently stores 0).
  int alignment;
  // Composite cost: reuse_distance * (1 + alignment), see
  // compute_vector_cost.
  int cost;
};
typedef struct vector_metrics s_vector_metrics_t;

/**
 * Allocate and initialize the metrics record of a parfor loop.
 *
 * @param f           The parfor node described by the record.
 * @param loop_level  Nesting depth of the loop (0 for an outer-most loop).
 * @param loop_nest   Position counters of the current loop nest; entries
 *                    0 to loop_level (inclusive) are read.
 * @return A new record with to_sink and all metrics set to 0. The caller
 *         owns both the record and its 'group' array.
 */
static
s_vector_metrics_t*
vector_metrics_alloc (struct clast_parfor* f, int loop_level, int* loop_nest)
{
  int i;
  s_vector_metrics_t* v = XMALLOC(s_vector_metrics_t, 1);

  v->node = f;
  v->cur_level = loop_level;
  // Keep loop_level + 1 entries: the counters of the enclosing loops
  // plus the counter of the loop itself. This avoids a zero-sized
  // allocation for outer-most loops and keeps group[cur_level]
  // addressable for code that indexes the array by loop level.
  v->group = XMALLOC(int, loop_level + 1);
  for (i = 0; i <= loop_level; ++i)
    v->group[i] = loop_nest[i];
  v->to_sink = 0;
  // Start every metric from a defined value; they are (re)computed
  // later by the compute_* passes.
  v->tripcount = 0;
  v->reuse_distance = 0;
  v->alignment = 0;
  v->cost = 0;

  return v;
}

/**
 * Debugging helper: print the first 'size' entries of 'beta' on stdout,
 * space-separated, on one line prefixed by "beta: ".
 */
static
void
print_beta (int* beta, int size)
{
  int pos = 0;

  printf ("beta: ");
  while (pos < size)
    {
      printf ("%d ", beta[pos]);
      ++pos;
    }
  printf ("\n");
}


/**
 * Walk a clast statement list and create one metrics record per parfor
 * loop encountered.
 *
 * @param s           First statement of the list to traverse.
 * @param metrics     NULL-terminated array of records. Each new record is
 *                    stored in the first NULL slot, so the array must be
 *                    large enough and NULL-filled by the caller.
 * @param loop_level  Number of loops enclosing 's'.
 * @param loop_nest   Array of VECTORIZER_MAX_LOOP_DEPTH counters.
 *                    loop_nest[d] is incremented for each loop met at
 *                    depth d, and all deeper counters are reset to 0 once
 *                    the loop body has been traversed.
 */
static
void
traverse_clast_parfor_collect (struct clast_stmt* s,
                               s_vector_metrics_t** metrics,
                               int loop_level,
                               int* loop_nest)
{
  int i;

  // Traverse the clast.
  for ( ; s; s = s->next)
    {
      struct clast_stmt* loop_body = NULL;
      int is_loop = 0;

      if (CLAST_STMT_IS_A(s, stmt_parfor))
        {
          // loop_nest[loop_level] is written below: stay inside the array.
          assert (loop_level < VECTORIZER_MAX_LOOP_DEPTH);
          // Find the first free slot of the output array.
          for (i = 0; metrics[i]; ++i)
            ;
          loop_nest[loop_level] += 1;
          metrics[i] = vector_metrics_alloc ((struct clast_parfor*) s,
                                             loop_level, loop_nest);
          loop_body = ((struct clast_parfor*) s)->body;
          is_loop = 1;
        }
      else if (CLAST_STMT_IS_A(s, stmt_for))
        {
          assert (loop_level < VECTORIZER_MAX_LOOP_DEPTH);
          loop_nest[loop_level] += 1;
          // Read the body through the node's own type rather than
          // through a clast_vectorfor cast.
          loop_body = ((struct clast_for*) s)->body;
          is_loop = 1;
        }
      else if (CLAST_STMT_IS_A(s, stmt_vectorfor))
        {
          assert (loop_level < VECTORIZER_MAX_LOOP_DEPTH);
          loop_nest[loop_level] += 1;
          loop_body = ((struct clast_vectorfor*) s)->body;
          is_loop = 1;
        }
      else if (CLAST_STMT_IS_A(s, stmt_guard))
        // Guards do not add a loop level.
        traverse_clast_parfor_collect (((struct clast_guard*) s)->then,
                                       metrics, loop_level, loop_nest);
      else if (CLAST_STMT_IS_A(s, stmt_block))
        // Blocks do not add a loop level either.
        traverse_clast_parfor_collect (((struct clast_block*) s)->body,
                                       metrics, loop_level, loop_nest);

      if (is_loop)
        {
          traverse_clast_parfor_collect (loop_body, metrics,
                                         loop_level + 1, loop_nest);
          // Leaving the loop: forget the counters of the inner levels.
          for (i = loop_level + 1; i < VECTORIZER_MAX_LOOP_DEPTH; ++i)
            loop_nest[i] = 0;
        }
    }
}

/**
 * Tell whether a clast expression references the name 'iter'.
 *
 * @param e     Expression to inspect; NULL is accepted and yields 0.
 * @param iter  Name to look for (compared with strcmp).
 * @return Non-zero if 'iter' is referenced, 0 otherwise. For reductions
 *         the value is the sum of the results of all elements, so it
 *         must only be used as a truth value.
 */
static
int
traverse_clast_expr_has_iter (struct clast_expr* e, char* iter)
{
  if (!e)
    return 0;
  switch (e->type)
    {
    case clast_expr_name:
      return strcmp (((struct clast_name*)e)->name, iter) == 0;
    case clast_expr_term:
      return traverse_clast_expr_has_iter
        (((struct clast_term*)e)->var, iter);
    case clast_expr_red:
      {
        int i;
        int val = 0;
        struct clast_reduction* r = (struct clast_reduction*) e;
        for (i = 0; i < r->n; ++i)
          val += traverse_clast_expr_has_iter (r->elts[i], iter);
        return val;
      }
    case clast_expr_bin:
      // Only the left-hand side operand is inspected.
      // NOTE(review): presumably the RHS of a clast_binary is a constant
      // and cannot hold an iterator -- confirm against the clast API.
      return traverse_clast_expr_has_iter
        (((struct clast_binary*)e)->LHS, iter);
    default:
      // Unknown expression type.
      assert (0);
      break;
    }
  // Reached only when assertions are compiled out (NDEBUG): return a
  // defined value instead of falling off the end of the function.
  return 0;
}

/**
 * Look for the name 'iter' in the lower/upper bounds of the for and
 * parfor loops of a clast statement list.
 *
 * Loop bodies, 'then' branches of guards and block bodies are searched
 * recursively. The scan of the current list stops at the first loop
 * whose own bounds reference 'iter'. Loops of other kinds (such as
 * vectorfor) are not inspected.
 *
 * @param s           First statement of the list to traverse.
 * @param iter        Iterator name to look for.
 * @param reads_iter  Set to 1 when a match is found; never reset to 0
 *                    here, so the caller must initialize it.
 */
static
void
traverse_clast_has_iter_in_bounds (struct clast_stmt* s,
                                   char* iter,
                                   int *reads_iter)
{
  struct clast_stmt* cur = s;

  while (cur)
    {
      if (CLAST_STMT_IS_A(cur, stmt_parfor))
        {
          struct clast_parfor* pf = (struct clast_parfor*) cur;
          int in_bounds = traverse_clast_expr_has_iter (pf->LB, iter)
            || traverse_clast_expr_has_iter (pf->UB, iter);

          if (in_bounds)
            {
              *reads_iter = 1;
              return;
            }
          traverse_clast_has_iter_in_bounds (pf->body, iter, reads_iter);
        }
      else if (CLAST_STMT_IS_A(cur, stmt_for))
        {
          struct clast_for* sf = (struct clast_for*) cur;
          int in_bounds = traverse_clast_expr_has_iter (sf->LB, iter)
            || traverse_clast_expr_has_iter (sf->UB, iter);

          if (in_bounds)
            {
              *reads_iter = 1;
              return;
            }
          traverse_clast_has_iter_in_bounds (sf->body, iter, reads_iter);
        }
      else if (CLAST_STMT_IS_A(cur, stmt_guard))
        {
          struct clast_guard* g = (struct clast_guard*) cur;
          traverse_clast_has_iter_in_bounds (g->then, iter, reads_iter);
        }
      else if (CLAST_STMT_IS_A(cur, stmt_block))
        {
          struct clast_block* b = (struct clast_block*) cur;
          traverse_clast_has_iter_in_bounds (b->body, iter, reads_iter);
        }

      cur = cur->next;
    }
}

/**
 * Return 1 if the parfor loop 'f' can be sunk, 0 otherwise.
 *
 * A loop is rejected when its iterator is used in the bounds of a
 * for/parfor loop found in its body.
 */
static
int
is_sinkable (struct clast_parfor* f)
{
  int used_in_inner_bounds = 0;
  char* iterator = (char*) (f->iterator);

  // Scan the domains of the enclosed loops for the iterator.
  traverse_clast_has_iter_in_bounds (f->body, iterator,
                                     &used_in_inner_bounds);

  // If loop is parallel, then it is always sinkable to inner level.
  /// FIXME: Double-check it is true!
  if (used_in_inner_bounds)
    return 0;
  return 1;
}

/**
 * Return true if the loop is the outer-most loop.
 *
 * The file also contains a rule for the 2nd outer-most loop (true when
 * its body directly contains another loop), but that rule is currently
 * switched off by the constant '&& 0' in its condition.
 *
 */
static
int
is_outer_par (s_vector_metrics_t* m)
{
  int depth = m->cur_level;
  struct clast_stmt* cur;

  if (depth == 0)
    return 1;

  /// FIXME: this is here to deal with 1st seq loop then 1 par loop,
  /// and to preserve the next-outer par. Must be optionized.
  if (depth == 1 && 0)
    {
      for (cur = m->node->body; cur; cur = cur->next)
        {
          int is_loop = CLAST_STMT_IS_A(cur, stmt_for) ||
            CLAST_STMT_IS_A(cur, stmt_parfor) ||
            CLAST_STMT_IS_A(cur, stmt_vectorfor);
          if (is_loop)
            return 1;
        }
    }

  return 0;
}

/**
 * Set the reuse distance of every record of a NULL-terminated group.
 * Placeholder implementation: the value stored is always 0.
 */
static
void
compute_reuse_distance (s_vector_metrics_t** group)
{
  s_vector_metrics_t** it;

  for (it = group; *it; ++it)
    (*it)->reuse_distance = 0;
}

/**
 * Set the alignment metric of every record of a NULL-terminated group.
 * Placeholder implementation: the value stored is always 0.
 */
static
void
compute_alignment (s_vector_metrics_t** group)
{
  s_vector_metrics_t** it;

  for (it = group; *it; ++it)
    (*it)->alignment = 0;
}

/**
 * Compute the composite cost of every record of a NULL-terminated
 * group from its reuse distance and alignment metrics.
 */
static
void
compute_vector_cost (s_vector_metrics_t** group)
{
  s_vector_metrics_t** it;

  /* This should be machine-specific. */
  for (it = group; *it; ++it)
    {
      s_vector_metrics_t* m = *it;
      int alignment_weight = 1 + m->alignment;

      m->cost = m->reuse_distance * alignment_weight;
    }
}


/**
 * Sink a parallel loop to the inner-most level.
 *
 * @FIXME implement a better vectorizer!
 *
 */
s_vectorizer_metrics_t*
vectorizer_transform_sinkparallel (struct clast_stmt* root,
				   scoplib_scop_p scop,
				   CandlDependence* deps,
				   s_vectorizer_options_t* options)
{
  int i, j;

  // 1- Collect all parfor loops, store them in an array.
  int size = clasttools_utils_number_parfor (root);
  s_vector_metrics_t** parfor_nodes = XMALLOC(s_vector_metrics_t*, size + 1);
  for (i = 0; i < size + 1; ++i)
    parfor_nodes[i] = NULL;
  int beta_string[VECTORIZER_MAX_LOOP_DEPTH];
  for (i = 0; i < VECTORIZER_MAX_LOOP_DEPTH; ++i)
    beta_string[i] = 0;
  traverse_clast_parfor_collect (root, parfor_nodes, 0, beta_string);

  // 2- Remove all non-sinkable loops (ie, tile loops).
  int count = 0;
  s_vector_metrics_t** sink_nodes = XMALLOC(s_vector_metrics_t*, size + 1);
  for (i = 0; i < size; ++i)
    if (is_sinkable (parfor_nodes[i]->node))
      sink_nodes[count++] = parfor_nodes[i];
    else
      XFREE(parfor_nodes[i]);
  sink_nodes[count] = NULL;
  XFREE(parfor_nodes);
  printf ("[Vectorizer] CLAST has %d parfor nodes, %d sinkable\n", size, count);
  size = count;

  // 3- Remove all first outer parallel loops, if coarse-grain
  // parallelization is to be preserved.
  s_vector_metrics_t** vect_nodes = sink_nodes;
  if (options->keep_outer_parallel)
    {
      count = 0;
      vect_nodes = XMALLOC(s_vector_metrics_t*, size + 1);
      for (i = 0; i < size; ++i)
	if (! is_outer_par (sink_nodes[i]))
	  vect_nodes[count++] = sink_nodes[i];
	else
	  XFREE(sink_nodes[i]);
      vect_nodes[count] = NULL;
      XFREE(sink_nodes);
      printf ("[Vectorizer] CLAST has %d sinkable nodes, %d non outer\n",
	      size, count);
      size = count;
    }

  // 4- For a given loop nest, select the best loop to sink.
  s_vector_metrics_t** group = XMALLOC(s_vector_metrics_t*, size + 1);

  // Iterate on all loop nest at depth 'level'.
  int mlev = 0;
  for (i = 0; i < size; ++i)
    mlev = vect_nodes[i]->cur_level < mlev ? mlev : vect_nodes[i]->cur_level;

  int level = 0;
  if (options->keep_outer_parallel)
    level = 1;
  for (; level <= mlev; ++level)
    {
      for (i = 0; i < size; )
	{
	  count = 0;
	  int offset = 0;
	  // Seek the first loop nest with enough loop depth.
	  for (; i < size && vect_nodes[i]->cur_level < level; ++i, ++offset)
	    ;
	  if (i == size)
	    break;
	  int gid = vect_nodes[i]->group[level];
	  group[count++] = vect_nodes[i];
	  for (j = i + 1; j < size; ++j)
	    {
	      int k;
	      for (k = 0; k <= level; ++k)
		if (vect_nodes[j]->group[k] != gid)
		  break;
	      if (k == level + 1)
		group[count++] = vect_nodes[j];
	    }
	  group[count] = NULL;
	  compute_reuse_distance (group);
	  compute_alignment (group);
	  compute_vector_cost (group);
	  int min = -1;
	  for (j = 0; j < count; ++j)
	    if (min == -1)
	      min = group[j]->cost;
	    else
	      min = min < group[j]->cost ? min : group[j]->cost;
	  // Select the inner-most loop with minimal cost for sinking.
	  for (j = count - 1; j >= 0 && group[j]->cost != min; --j)
	    ;
	  group[j]->to_sink = 1;
	  // If a loop in the same loop nest was already selected for
	  // sinking, reset its sink bit.
	  /// FIXME: It's ad-hoc, must do better.
	  int k;
	  if (! options->sink_all_candidates)
	    {
	      for (k = 0; k < size; ++k)
		if (vect_nodes[k]->to_sink && vect_nodes[k] != group[j])
		  {
		    int l;
		    for (l = 0; l < vect_nodes[k]->cur_level
			   && l < group[j]->cur_level; ++l)
		      if (vect_nodes[k]->group[l] != group[j]->group[l])
			break;
		    if (l == vect_nodes[k]->cur_level ||
			l == group[j]->cur_level)
		      vect_nodes[k]->to_sink = 0;
		  }
	    }
	  i += count + offset;
	}
    }
  s_vectorizer_metrics_t* output_metrics = vectorizer_metrics_malloc ();

  // 5- Sink the loops to sink.
  for (i = 0; i < size; ++i)
    if (vect_nodes[i]->to_sink)
      {
	printf ("[Vectorizer] Sinking vectorizable loop #%d (iterator:%s)\n",
		i, vect_nodes[i]->node->iterator);

	struct clast_parfor* f = vect_nodes[i]->node;

	// Retrieve the pointer node.
	struct clast_stmt** parent =
	  clasttools_utils_get_pointer (root, (struct clast_stmt*) f);
	// Retrieve all inner-most loops.
	struct clast_stmt** inner_loops =
	  clasttools_utils_get_all_inner_loops (f->body);

	if (inner_loops[0] != NULL)
	  {
	    // Parent node now points to the sink loop body.
	    *parent = f->body;
	    struct clast_stmt* body_last = f->body;
	    while (body_last->next != NULL)
	      body_last = body_last->next;
	    body_last->next = ((struct clast_stmt*) f)->next;

	    // Inner loops now point to a new loop, mirroring the sink loop.
	    for (j = 0; inner_loops[j]; ++j)
	      {
		struct clast_parfor* newf =
		  new_clast_parfor (strdup (f->iterator),
				    clasttools_utils_dup_expr (f->LB),
				    clasttools_utils_dup_expr (f->UB),
				    f->stride);
		newf->body = ((struct clast_for*)(inner_loops[j]))->body;
		((struct clast_for*)(inner_loops[j]))->body =
		  (struct clast_stmt*) newf;
	      }
	  }
	XFREE(inner_loops);
      }

  // Be clean.
  XFREE(vect_nodes);

  return output_metrics;
}
