/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_distribute.c */

#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "stddefs.h"
#include "x_distribute.h"
#include "x_for.h"
#include "x_loops.h"
#include "x_clauses.h"
#include "x_reduction.h"
#include "ast_types.h"
#include "ast_xform.h"
#include "ast_free.h"
#include "ast_copy.h"
#include "ast_print.h"
#include "ast_assorted.h"
#include "ast_arith.h"
#include "str.h"
#include "ompi.h"
#include "codetargs.h"


/* Map a kind of loop temporary variable to the name (prefix) used for the
 * corresponding #distribute temporary; all names start with "dist_".
 * An unknown kind is reported through exit_error().
 */
static char *dist_varname(schedvartype_e svt)
{
	char *name = (char *) NULL;

	switch (svt)
	{
		case LOOP_PREFIX: name = "dist_";         break;
		case LOOP_ITER:   name = "dist_iter_";    break;
		case LOOP_NITERS: name = "dist_niters_";  break;
		case LOOP_FITER:  name = "dist_fiter_";   break;
		case LOOP_LITER:  name = "dist_liter_";   break;
		default:
			exit_error(1, "[dist_varname]: unknown variable type (%d)", svt);
			break;
	}
	return name;
}


/* When producing normalized loops, an optimized version can be
 * generated when only a single loop is involved. The optimization,
 * however, generates a loop with complex first and third parts, which may
 * not always be desired (e.g. distparfor_mainpart() turns it off).
 * Hence the flag to activate/deactivate the optimization at will.
 */
static bool distribute_optimize_single_loop = true;


/**
 * @brief Produce the main, normalized loop body
 * 
 * A single loop becomes:
 *   for (iter = fiter; iter < liter; iter++) {
 *     <var> = lb +/- iter*step
 *     <body>
 *   }
 * optimized (when distribute_optimize_single_loop is set) as:
 *   for (iter = fiter, var = lb +/- fiter*step; iter < liter;
 *        iter++, var +/-= step) {
 *     <body>
 *   }
 * For a collapsed loop nest, the non-optimized version is always output
 * and all the <var>s are recovered from iter, innermost loop first, using
 * a scratch variable pp_:
 *   pp_ = 1;
 *   <var_k> = lb_k +/- step_k*((iter / pp_) % <iters_k>); pp_ *= <iters_k>;
 *   ...
 * (pp_ is not multiplied after the outermost loop).
 * No ordered-clause handling is generated here.
 *
 * @param loopinfo  analyzed loop (nest) information; forps[], itersym[]
 *                  and varname() are used
 * @param origbody  the original loop body; it is embedded as is (not
 *                  copied) in the produced loop
 * @return the normalized loop statement built by loop_normalize()
 */
aststmt dist_mainpart(fordata_t *loopinfo, aststmt origbody)
{
	int i;
	aststmt idx;                           /* index recovery statement(s) */
	symbol var = loopinfo->forps[0].var;   /* needed only in 1 loop */

	/* Single loop, optimized: the index is recovered in the for header, so
	 * no separate recovery statement is needed. */
	if (loopinfo->collapsenum <= 1 && distribute_optimize_single_loop)
		return
			loop_normalize(Symbol(loopinfo->varname(LOOP_ITER)), 
			               IdentName(loopinfo->varname(LOOP_FITER)), 
			               Assignment(
			                 Identifier(var),
			                 ASS_eq,
			                 BinaryOperator(loopinfo->forps[0].incrop, 
			                   ast_expr_copy(loopinfo->forps[0].lb),
			                   BinaryOperator(BOP_mul,
			                     IdentName(loopinfo->varname(LOOP_FITER)),
			                     ast_expr_copy(loopinfo->forps[0].step)
			                   )
			                 )
			               ),
			               IdentName(loopinfo->varname(LOOP_LITER)), 
			               Assignment(
			                 Identifier(var), 
			                 bop2assop(loopinfo->forps[0].incrop),
			                 ast_expr_copy(loopinfo->forps[0].step)
			               ),
			               origbody, NULL, NULL);

	if (loopinfo->collapsenum > 1)         /* Recover all indices */
	{
		idx = AssignStmt(IdentName("pp_"), OneExpr());
		for (i = loopinfo->collapsenum - 1; i >= 0; i--)
		{
			idx = BlockList(
			        idx,
			        AssignStmt(
			          Identifier(loopinfo->forps[i].var),
			          BinaryOperator(
			            loopinfo->forps[i].incrop, //BOP_add,
			            ast_expr_copy(loopinfo->forps[i].lb),
			            BinaryOperator(
			              BOP_mul,
			              ast_expr_copy(loopinfo->forps[i].step),
			              Parenthesis(
			                BinaryOperator(
			                  BOP_mod,
			                  Parenthesis(
			                    BinaryOperator(
			                      BOP_div,
			                      IdentName(loopinfo->varname(LOOP_ITER)),
			                      IdentName("pp_")
			                    )
			                  ),
			                  Identifier(loopinfo->itersym[i])
			                )
			              )
			            )
			          )
			        )
			      );

			if (i != 0)
				idx = BlockList(
				        idx,
				        Expression(Assignment(IdentName("pp_"), ASS_mul,
				                              Identifier(loopinfo->itersym[i]))
				                  )
				      );
		}
	}
	else  /* Single loop, non-optimized: <var> = lb +/- step*iter */
		idx = AssignStmt(
		        Identifier(var),
		        BinaryOperator(
		          loopinfo->forps[0].incrop, //BOP_add,
		          ast_expr_copy(loopinfo->forps[0].lb),  /* the only loop */
		          BinaryOperator(
		            BOP_mul,
		            ast_expr_copy(loopinfo->forps[0].step),
		            IdentName(loopinfo->varname(LOOP_ITER))
		          )
		        )
		      );

	return
		loop_normalize(Symbol(loopinfo->varname(LOOP_ITER)), 
		               IdentName(loopinfo->varname(LOOP_FITER)), NULL, 
		               IdentName(loopinfo->varname(LOOP_LITER)), NULL, 
		               origbody, idx, NULL);
}


/* Copies of those clauses of the #distribute parallel for being transformed
 * that must be passed on to the generated #parallel for; set by
 * xform_distparfor() and consumed by distparfor_mainpart().
 */
static ompclause _dpf_pfclauses = NULL;

/* Dress up the normal distribute mainpart with a #parallel for.
 * The non-optimized normalized loop of dist_mainpart() is wrapped in a
 * #pragma omp parallel for, which gets the clauses kept in _dpf_pfclauses
 * plus:
 *   private(<var>)                 for a single loop, or
 *   private(pp_, <var_0>, ...)     for a collapsed nest, and
 *   firstprivate(dist_fiter_, dist_liter_ [, <iters_0>, ...])
 * so that the loop temporaries are handled correctly inside the team.
 */
static aststmt distparfor_mainpart(fordata_t *loopinfo, aststmt origbody)
{
	int saved_opt = distribute_optimize_single_loop, k;
	aststmt loop;
	astdecl privlist, fiplist;
	ompclause cls, privcl;

	/* The optimized single-loop header is not wanted under a #parallel for */
	distribute_optimize_single_loop = false; 
	loop = dist_mainpart(loopinfo, origbody);
	distribute_optimize_single_loop = saved_opt; 

	/* The chunk bounds computed by the team are passed into the threads */
	fiplist = IdList(IdentifierDecl(Symbol(dist_varname(LOOP_FITER))),
	                 IdentifierDecl(Symbol(dist_varname(LOOP_LITER))));
	if (loopinfo->collapsenum != 1)
	{
		/* pp_ and the recovered indices are per-thread scratch values; the
		 * per-loop iteration counts are needed for the index recovery */
		privlist = IdentifierDecl(Symbol("pp_"));
		for (k = 0; k < loopinfo->collapsenum; k++)
		{
			privlist = IdList(privlist, IdentifierDecl(loopinfo->forps[k].var));
			fiplist = IdList(fiplist, IdentifierDecl(loopinfo->itersym[k]));
		}
	}
	else
		privlist = IdentifierDecl(loopinfo->forps[0].var);

	/* Append our clauses after the ones inherited from the directive */
	privcl = VarlistClause(OCPRIVATE, privlist);
	if (_dpf_pfclauses)
		cls = OmpClauseList(_dpf_pfclauses, privcl);
	else
		cls = privcl;
	cls = OmpClauseList(cls, VarlistClause(OCFIRSTPRIVATE, fiplist));

	/* Wrap the loop in the #parallel for statement */
	loop = OmpStmt(OmpConstruct(DCPARFOR, OmpDirective(DCPARFOR, cls), loop));
	ast_parentize(loop);
	return loop;
}


/* dist_schedule(static) without a chunk size. Produces:
 *   if (_ort_get_distribute_chunk(niters, &fiter, &liter))
 *     { <normalized loop over [fiter, liter)> }
 * i.e. each team executes the loop at most once, over the range the
 * runtime call fills in. Iteration-variable declarations are appended to
 * code->decls; code->mainpart (the original body) is replaced.
 */
void dist_schedule_static(fordata_t *loopinfo, foresult_t *code)
{
	astexpr gotchunk;
	aststmt loop;

	code->decls = Block2(code->decls, for_iterdecls(loopinfo));

	gotchunk = parse_expression_string("_ort_get_distribute_chunk(%s, &%s, &%s)",
	             loopinfo->varname(LOOP_NITERS), loopinfo->varname(LOOP_FITER),
	             loopinfo->varname(LOOP_LITER));
	loop = loopinfo->mainpart_func(loopinfo, code->mainpart);

	code->mainpart = If(gotchunk, Compound(loop), NULL);
}


/* dist_schedule(static, chunk): chunks are dealt to the teams round-robin,
 * starting from each team's number. Produces:
 *   for (dist_chid_ = omp_get_team_num(); ; dist_chid_ += dist_TN_) {
 *     fiter = dist_chid_*(chunk);
 *     if (fiter >= niters) break;
 *     liter = fiter + (chunk);
 *     if (liter > niters) liter = niters;
 *     <normalized loop over [fiter, liter)>
 *   }
 * A constant chunk size is used textually; otherwise it is first evaluated
 * into a CHUNKSIZE variable (declared as int for the vulkan target).
 */
void dist_schedule_static_with_chunksize(fordata_t *loopinfo, foresult_t *code)
{
	aststmt decls = for_iterdecls(loopinfo), body;
	astexpr chexpr = loopinfo->schedchunk;
	char *chunk, *fiter, *liter, *niters;

	if (chexpr != NULL && chexpr->type == CONSTVAL)
		chunk = chexpr->u.str;           /* constant: use its text directly */
	else
	{
		chunk = CHUNKSIZE;               /* non-constant: evaluate into a var */
		decls = BlockList(
		          decls,
		          Declaration(
		            (xformingFor == CODETARGID(vulkan)) ? 
		              Usertype(Symbol("int")) : ITERCNT_SPECS,
		            InitDecl(
		              Declarator(NULL, IdentifierDecl(Symbol(chunk))),
		              ast_expr_copy(chexpr)
		            )
		          )
		        );
	}

	/* int dist_chid_, dist_TN_ = omp_get_num_teams(); */
	decls = BlockList(
	          decls,
	          Declaration(
	            Declspec(SPEC_int),
	            DeclList(
	              Declarator(NULL, IdentifierDecl(Symbol("dist_chid_"))),
	              InitDecl(
	                Declarator(NULL, IdentifierDecl(Symbol("dist_TN_"))),
	                Call0_expr("omp_get_num_teams")
	              )
	            )
	          )
	        );
	code->decls = Block2(code->decls, decls);

	fiter  = loopinfo->varname(LOOP_FITER);
	liter  = loopinfo->varname(LOOP_LITER);
	niters = loopinfo->varname(LOOP_NITERS);

	/* Chunk bounds computation, followed by the normalized loop */
	body = loopinfo->mainpart_func(loopinfo, code->mainpart);
	body = BlockList(
	         parse_blocklist_string(
	           "%s = dist_chid_*(%s);"
	           "if (%s >= %s) break;"
	           "%s = %s + (%s);"
	           "if (%s > %s) %s = %s;",
	           fiter, chunk,
	           fiter, niters,
	           liter, fiter, chunk,
	           liter, niters, liter, niters
	         ),
	         body
	       );

	/* The loop over this team's chunks */
	code->mainpart = For(parse_blocklist_string("dist_chid_ = omp_get_team_num();"),
	                     NULL,
	                     parse_expression_string("dist_chid_ += dist_TN_"),
	                     Compound(body));
}


/* Possible clauses:
 * private, firstprivate, lastprivate, collapse, dist_schedule.
 */

/**
 * @brief Transform a #distribute construct (also used for the #distribute
 *        part of #distribute parallel for).
 *
 * The associated loop (nest) is analyzed and normalized; the OmpStmt at *t
 * is freed and replaced in place by a compound statement containing, in
 * order: the commented-out directive, declarations, initializations, the
 * prologue (niters computation), the scheduled main part and the epilogue
 * (lastprivate assignments, if any).
 *
 * @param t          pointer to the #distribute OmpStmt; overwritten with
 *                   the generated code (its parent link is preserved)
 * @param fmainpart  produces the normalized loop around the original body
 *                   (dist_mainpart or distparfor_mainpart in this file)
 */
void _do_distribute(aststmt *t, aststmt (*fmainpart)(fordata_t*, aststmt))
{
	aststmt   s = (*t)->u.omp->body, parent = (*t)->parent, v, 
	          lasts = NULL, stmp, embdcls = NULL;
	forparts_t forps[MAXLOOPS];
	astexpr   expr, elems;
	symbol    itersym[MAXLOOPS];
	int       i = 0, collapsenum = 1,  nestnum, namelen;
	bool      haslast = false, hasboth = false, hasred = false;
	astexpr   dist_schedchunk = NULL;  /* the chunksize expression */
	char      iterstr[128];
	ompclause sch = xc_ompcon_get_clause((*t)->u.omp, OCDISTSCHEDULE),
	          col = xc_ompcon_get_clause((*t)->u.omp, OCCOLLAPSE);
	symtab    dvars;
	stentry   ste;
	fordata_t info = { 0 };
	foresult_t code = { NULL };

	v = ompdir_commented((*t)->u.omp->directive); /* Put directive in comments */
	
	/*
	 * Preparations
	 */

	if (sch)
	{
		assert(sch->subtype == OC_static);  /* sanity */
		dist_schedchunk = sch->u.expr;
	}

	if (col)
	{
		if ((collapsenum = col->subtype) >= MAXLOOPS)
			exit_error(1, "(%s, line %d) ompi error:\n\t"
				"cannot collapse more than %d FOR loops.\n",
				(*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l,MAXLOOPS);
	}

	/* Collect all data clause vars - we need to check if any vars
	 * are both firstprivate and lastprivate; notice that if on a 
	 * composite #distribute parallel for construct, there can also 
	 * exist reduction() clauses; they should not be taken into account 
	 * as #distribute does not accept reduction clauses---they will be 
	 * handled by the #parallel for part.
	 */
	dvars = xc_validate_store_dataclause_vars((*t)->u.omp->directive);
	for (ste = dvars->top; ste; )
		if (ste->ival == OCREDUCTION)
		{
			stentry next = ste->stacknext;   /* grab it before ste is removed */
			symtab_remove(dvars, ste->key, IDNAME);
			ste = next;
		}
		else
			ste = ste->stacknext;

	/* Analyze the loop(s) */
	nestnum = collapsenum;
	loopnest_analyze(s, nestnum, collapsenum, forps, *t, dvars, &embdcls);
	
	/* Prepare the loop info (haslast is filled in below, once it is known) */
	info.ordplain = false;
	info.collapsenum = collapsenum;
	info.doacrossnum = 0;
	info.schedtype = OC_static;
	info.schedchunk = dist_schedchunk;
	info.forps = forps;
	info.itersym = itersym;
	info.mainpart_func = fmainpart;
	info.varname = dist_varname;

	/* Remember the last loop; form the per-loop iteration count symbols
	 * (dist_iters_<var>_), checking that the name fits in iterstr */
	s = forps[collapsenum-1].s;
	for (i = 0; i < nestnum; i++)
	{
		namelen = snprintf(iterstr, sizeof(iterstr), "%siters_%s_", 
		                   info.varname(LOOP_PREFIX), forps[i].var->name);
		if (namelen < 0 || namelen >= (int) sizeof(iterstr))
			exit_error(1, "(%s, line %d) ompi error:\n\t"
				"loop variable name '%s' is too long.\n",
				(*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l,
				forps[i].var->name);
		itersym[i] = Symbol(iterstr);
	}

	/*
	 * Declarations and initializations
	 */
	
	/* declarations from the collected vars (not the clauses!) */
	code.decls = verbit("/* declarations (if any) */");
	stmp = xc_stored_vars_declarations(&haslast, &hasboth, &hasred);
	/* haslast is only set by the call above; it must not be read earlier */
	info.haslast = haslast;
	
	if (stmp)
		code.decls = Block2(code.decls, stmp);
	if (embdcls)
		code.decls = BlockList(code.decls, embdcls);

	/* initialization statements for firstprivate non-scalar vars */
	code.inits = verbit("/* initializations (if any) */");
	if ((stmp = xc_ompdir_fiparray_initializers((*t)->u.omp->directive)) != NULL)
		code.inits = Block2(code.inits, stmp);
	
	/* assignments for lastprivate vars */
	if (haslast)
		lasts = xc_ompdir_lastprivate_assignments((*t)->u.omp->directive);

	/*
	 * Prologue
	 */
	
	/* Append our new code: niters_ = <total number of iterations>; */
	if (collapsenum == 1)
		elems = CastLong(loop_iters(&forps[0]));
	else
		for (elems = Identifier(itersym[0]), i = 1; i < collapsenum; i++)
			elems = BinaryOperator(BOP_mul, elems, Identifier(itersym[i]));
	expr = elems;

	stmp = Expression(     /* niters_ = ... */
	           Assignment(IdentName(dist_varname(LOOP_NITERS)), ASS_eq, expr)
	         );

	if (hasboth)   /* a var is both fip & lap; this needs a barrier here :-( */
		stmp = BlockList(stmp, BarrierCall(OMPIBAR_OMPIADDED));
	
	code.prologue = stmp;    /* Guaranteed to be non-NULL */

	/*
	 * Main part
	 */
	
	/* Just leave the original body and let the schedules utilize it */
	code.mainpart = s->body;
	
	/*
	 * Epilogue
	 */
	
	/* Start with an empty expression statement.
	 * NOTE(review): an older comment said a cancellation label goes here,
	 * but no label is generated — confirm intent. */
	code.epilogue = Expression(NULL);

	/* Add lastprivate assignments */
	if (lasts)
	{
		if (collapsenum > 1)
		{
			/* Advance every index var by one step before the assignments */
			aststmt idx;
		
			idx = Expression(Assignment(Identifier(forps[0].var), 
			                            bop2assop(forps[0].incrop), 
			                            ast_expr_copy(forps[0].step)));
			for (i = 1; i < collapsenum; i++)
				idx = BlockList(
				        idx,
				        Expression(Assignment(Identifier(forps[i].var), 
				                              bop2assop(forps[i].incrop), 
				                              ast_expr_copy(forps[i].step))
				        )
				      );
			lasts = BlockList(idx, lasts);
		}

		/* if (iter && iter == niters) { <lastprivate assignments> } */
		code.epilogue = 
		  BlockList(
		    code.epilogue,
		    If(
		      BinaryOperator(BOP_land,
		        IdentName(info.varname(LOOP_ITER)),
		        BinaryOperator(BOP_eqeq,
		          IdentName(info.varname(LOOP_ITER)),
		          IdentName(info.varname(LOOP_NITERS))
		        )
		      ),
		      lasts->type == STATEMENTLIST ?  Compound(lasts) : lasts,
		      NULL
		    )
		  );
	}

	/*
	 * Get loop specific code and combine the parts
	 */
	
	/* schedule-specific actions */
	if (dist_schedchunk)
		dist_schedule_static_with_chunksize(&info, &code);
	else
		dist_schedule_static(&info, &code);
	
	(*t)->u.omp->body = NULL;     /* Make it NULL so as to free it easily */
	ast_free(*t);                 /* Get rid of the OmpStmt */
	*t = Block6(v, code.decls, code.inits, code.prologue, code.mainpart, 
	            code.epilogue);
	*t = Compound(*t);
	(*t)->parent = parent;
}


/* Transform a plain #distribute construct at *t: the normalized loop from
 * dist_mainpart() is used as is (no #parallel for wrapping). Afterwards the
 * (externally defined) dist_combined flag is cleared.
 */
void xform_distribute(aststmt *t)
{
	aststmt (*mainpart)(fordata_t *, aststmt) = dist_mainpart;

	_do_distribute(t, mainpart);
	dist_combined = false;
}


/* Given the clauses of a #distribute parallel for, return a copy of those
 * that must be given to the #parallel for construct (NULL if none).
 * Clauses not listed below (e.g. collapse, dist_schedule) are not copied.
 */
static ompclause parfor_clauses(ompclause all)
{
	ompclause rest = NULL, one = all;

	if (all == NULL)
		return NULL;

	if (all->type == OCLIST)
	{
		/* Collect from the rest of the list first, then this element.
		 * List elements are expected to be single clauses; a nested list
		 * would have to be collected recursively and joined with the rest.
		 */
		rest = parfor_clauses(all->u.list.next);
		one = all->u.list.elem;
		assert(one != NULL && one->type != OCLIST);
	}

	switch (one->type)
	{
		case OCIF:             /* Only #parallel clauses */
		case OCNUMTHREADS:
		case OCPROCBIND:
		case OCSHARED:
		case OCCOPYIN:
		case OCDEFAULT:
		case OCORDERED:        /* Only #for clauses */
		case OCSCHEDULE:
		case OCNOWAIT:
		case OCREDUCTION:      /* for only; not for parallel */
		/* OCLINEAR, OCORDER are not passed on */
		case OCFIRSTPRIVATE:
		case OCLASTPRIVATE:
		case OCPRIVATE:        /* normally this should go EXCLUSIVELY to #for */
			break;
		default:
			return rest;         /* #distribute-only clause */
	}

	/* Append a copy of this clause to the ones already collected */
	if (rest == NULL)
		return ast_ompclause_copy(one);
	return OmpClauseList(rest, ast_ompclause_copy(one));
}


/* A #distribute parallel for is transformed as if it were a #distribute
 * construct, with the inner generated loop wrapped in a #parallel for
 * construct. This is neither elegant nor fast when the #distribute schedule
 * has a chunk size, since a team of threads is created for each chunk.
 */
void xform_distparfor(aststmt *t)
{
	ompclause allcl = (*t)->u.omp->directive->clauses;

	/* Keep (copies of) the clauses destined for the #parallel for */
	_dpf_pfclauses = parfor_clauses(allcl);

	/* Transform like a plain #distribute, wrapping the inner loop in a
	 * #parallel for. All clauses are left on the directive, so
	 * _do_distribute must not be tricked by them (e.g. it ignores
	 * reduction vars). Removing the clauses given to the #parallel for
	 * would be cleaner, but is not done yet.
	 */
	_do_distribute(t, distparfor_mainpart);
	dist_combined = false;

	/* Transform the produced code again (it now contains a #parallel for) */
	ast_stmt_xform(t);
}


/* #distribute simd is not supported: only a message is printed on stderr;
 * the construct at *t is not transformed.
 */
void xform_distsimd(aststmt *t)
{
	(void) t;   /* unused */
	fputs("#pragma omp distribute simd: not supported yet...\n", stderr);
}

/* #distribute parallel for simd is not supported: only a message is printed
 * on stderr; the construct at *t is not transformed.
 */
void xform_distparforsimd(aststmt *t)
{
	(void) t;   /* unused */
	fputs("#pragma omp distribute parallel for simd: not supported yet...\n",
	      stderr);
}
