/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_loops.c -- utility functions related to OMP loops */

/*
 * 2021/05/8
 *   Created from x_for.c parts.
 */

#include <string.h>
#include <assert.h>
#include "stddefs.h"
#include "ompi.h"
#include "ast_free.h"
#include "ast_copy.h"
#include "x_loops.h"
#include "ast_arith.h"


/** 
 * Analyze a whole loop nest. Given a loop nest (s) like the following:
 *    for (i1=...)
 *      for (i2=...)
 *        ...
 *          for (iK=...)
 *            LoopBody
 * this analyzes all the K (=nestnum) loops (through loop_analyze()) and
 * stores the characteristics of the i-th loop in forps[i]. Each inner loop
 * must be either the body of the previous loop or the single FOR inside a
 * compound statement that forms that body ({ For } is accepted).
 *
 * After analysis, the loop parameters of each forps[i] are replaced by
 * fresh copies:
 *   - step: OneExpr() for ++/--, a plain copy for a constant step, or a
 *           parenthesized copy for a general expression (never NULL after);
 *   - lb:   a parenthesized copy;
 *   - ub:   a parenthesized copy, turned into an exclusive bound for the
 *           <= (ub + 1) and >= (ub - 1) conditions.
 *
 * Finally, if we are in OpenMP mode (omploop != NULL), all iteration
 * variables are checked against the data clauses, which are assumed to be
 * gathered in the dvars table. An index variable that is not in dvars is
 * implicitly private: if it is assigned in the init part it is entered in
 * dvars as OCPRIVATE; if it is declared in-place (e.g. for (int i = ...)),
 * a copy of its declaration (without the initializer) is appended to
 * *embdcls, so that it can be declared externally where needed.
 * An in-place declared index that also appears in a data clause is an
 * error; a FIRSTPRIVATE index is downgraded (with a warning) to PRIVATE and
 * a FIRSTLASTPRIVATE one to LASTPRIVATE.
 *
 * @param s           the outer loop statement
 * @param nestnum     how many loops are in the loop nest
 * @param collapsenum how many loops to collapse (not used by this function)
 * @param forps       array of nestnum entries receiving the analyzed parts
 *                    of each of the loops
 * @param omploop     the OpenMP statement from which the loop nest comes,
 *                    or NULL when not in OpenMP mode
 * @param dvars       all variables from all data clauses of the OpenMP
 *                    statement (only used when omploop != NULL)
 * @param embdcls     declaration statements of all in-loop declared indices
 *                    (only used when omploop != NULL)
 */
void loopnest_analyze(aststmt s, int nestnum, int collapsenum,
       forparts_t *forps, aststmt omploop, symtab dvars, aststmt *embdcls)
{
	int i;
	forparts_t *forp;
	stentry varentry;
	aststmt outer, decl;
	
	i = 0;
	do
	{
		loop_analyze(s, forp = &forps[i]);
		
		if (omploop)  /* OpenMP mode; need to check with the data clauses */
		{
			/* First check if the loop variable has been enlisted; if not,
			 * it is automatically considered private (v25) - so we make it
			 * appear as if there was a private(var) clause.
			 */
			if ((varentry = symtab_get(dvars, forp->var, IDNAME)) == NULL)
			{
				if (s->u.iteration.init->type == EXPRESSION)
					symtab_put(dvars, forp->var, IDNAME)->ival = OCPRIVATE;
				else
				{
					decl = Declaration( /* without the initializer */
					         ast_spec_copy(s->u.iteration.init->u.declaration.spec),
					         ast_decl_copy(s->u.iteration.init->u.declaration.decl->decl)
					       );
					*embdcls = (*embdcls) ? BlockList(*embdcls, decl) : decl;
				}
			}
			else
			{
				if (s->u.iteration.init->type != EXPRESSION)  /* a declaration */
					exit_error(1, "(%s, line %d) openmp error:\n\t"
						"iteration variable '%s' is declared within the FOR statement\n\t"
						"and thus it cannot appear in the directive's data clauses.\n",
						omploop->u.omp->directive->file->name, omploop->u.omp->directive->l, 
						forp->var->name);
				/* Remove the FIRSTPRIVATE attribute if any (there is no use for it) */
				/* Actually, v25 (p.64,l.23) specifies that the iteration variable
				 * can only appear in a PRIVATE or LASTPRIVATE clause, so we should
				 * emit at least a warning.
				 */
				if (varentry->ival==OCFIRSTPRIVATE ||varentry->ival==OCFIRSTLASTPRIVATE)
					warning("(%s, line %d) warning:\n\t"
						"iteration variable '%s' cannot appear in a FIRSTPRIVATE clause.\n\t"
						"  .. let's pretend it was in a PRIVATE clause.\n",
						omploop->u.omp->directive->file->name, omploop->u.omp->directive->l, 
						forp->var->name);
				if (varentry->ival == OCFIRSTPRIVATE)
					varentry->ival = OCPRIVATE;
				else
					if (varentry->ival == OCFIRSTLASTPRIVATE)
						varentry->ival = OCLASTPRIVATE;
			}
		}
		
		/* Fix null or constant-value step; step will be non-NULL from now on */
		if (forp->step == NULL || forp->step->type == CONSTVAL) /* ++/--/+= const */
			forp->step = (forp->step == NULL) ? OneExpr() : ast_expr_copy(forp->step);
		else /* step != NULL && general expression for step */
			forp->step = Parenthesis(ast_expr_copy(forp->step));   /* An expression */
		forp->lb = Parenthesis(ast_expr_copy(forp->lb));
		forp->ub = Parenthesis(
		             (forp->condop==BOP_leq || forp->condop==BOP_geq) ? /* fix ub */
		             BinaryOperator((forp->condop == BOP_leq) ? BOP_add : BOP_sub,
		                            Parenthesis(ast_expr_copy(forp->ub)),
		                            OneExpr()) :
		             ast_expr_copy(forp->ub)
		           );
		
		if (i < nestnum - 1)
		{
			outer = s;   /* remembered for error reporting in non-OpenMP mode */
			s = s->body;
			if (s != NULL && s->type == COMPOUND && s->body != NULL &&
			    s->body->type == ITERATION && s->body->subtype == SFOR)
				s = s->body;  /* { For } -> For */
			if (s == NULL || s->type != ITERATION || s->subtype != SFOR)
				/* omploop is NULL outside OpenMP mode; report the enclosing loop */
				exit_error(1, "(%s, line %d) syntax error:\n\t"
					"%d perfectly nested FOR loops were expected.\n",
					omploop ? omploop->u.omp->directive->file->name : outer->file->name,
					omploop ? omploop->u.omp->directive->l : outer->l,
					nestnum);
		}
	}
	while ((++i) < nestnum);
}


/**
 * Get the index variables of a FOR loop nest; their number is given by the
 * number in the ordered clause. This is a slimmed-down version of
 * loop_analyze() applied to the whole loop nest: only the init part of each
 * loop is examined, which must be either `var = lb` or `type var = lb`.
 * Each inner loop must be either the body of the previous loop or the
 * single FOR inside a compound body ({ For } is accepted).
 * 
 * @param s the body of the #for construct (i.e. the outer loop of the nest)
 * @param orderednum the number in the ordered() clause (i.e. the nest depth);
 *        must be positive
 * @return an array of orderednum symbols (freeable)
 */
symbol *loopnest_get_indices(aststmt s, int orderednum)
{
	aststmt init, outer = s;   /* outer: the outermost FOR, for error messages */
	int     lid = 0;
	symbol  *vars;
	
	assert(orderednum > 0);    /* the loop below always fills vars[0] */
	vars = smalloc(orderednum*sizeof(symbol));
	
	do 
	{
		assert(s != NULL && s->type == ITERATION && s->subtype == SFOR);
		init = s->u.iteration.init;
		if (init == NULL)
		{
		OMPFOR_ERROR:
			exit_error(1, "(%s, line %d) openmp error:\n\t"
				"non-conformant FOR statement\n", s->file->name, s->l);
		}

		/* Get var from the init part of the FOR */
		if (init->type == EXPRESSION)     /* assignment: var = lb */
		{
			astexpr e = init->u.expr;
			if (e == NULL || e->type != ASS || e->left->type != IDENT)
				goto OMPFOR_ERROR;
			vars[lid] = e->left->u.sym;
		}
		else
			if (init->type == DECLARATION)  /* declaration: type var = lb */
			{
				astdecl d = init->u.declaration.decl;

				if (d->type != DINIT)
					goto OMPFOR_ERROR;
				if (d->decl->type != DECLARATOR)
					goto OMPFOR_ERROR;
				if (d->decl->decl->type != DIDENT)
					goto OMPFOR_ERROR;
				vars[lid] = d->decl->decl->u.id;
			}
			else
				goto OMPFOR_ERROR;

		if (lid < orderednum - 1)
		{
			s = s->body;
			if (s != NULL && s->type == COMPOUND && s->body != NULL &&
			    s->body->type == ITERATION && s->body->subtype == SFOR)
				s = s->body;  /* { For } -> For */
			/* outer is a FOR statement (asserted above), not an OpenMP construct,
			 * so its u.omp union member must not be accessed; report its own
			 * location instead. */
			if (s == NULL || s->type != ITERATION || s->subtype != SFOR)
				exit_error(1, "(%s, line %d) openmp error:\n\t"
				      "an ordered(%d) clause requires %d perfectly nested FOR loops.\n",
				      outer->file->name, outer->l,
				      orderednum, orderednum);
		}
	}
	while ((++lid) < orderednum);
	return vars;
}


/**
 * Analyze a single for statement and determine conformance to 
 * OpenMP (canonical loop form) & other stuff that matters. It must be 
 * called repeatedly for for-loop nests. A non-conformant loop causes an
 * exit with an error message.
 *
 * On success fp receives:
 *   s      - the for statement itself;
 *   var    - the iteration variable;
 *   lb     - the initial value of var;
 *   condop - the relational operator (<, >, <=, >=, !=) with var on its left;
 *   ub     - the other operand of the condition;
 *   incrop - BOP_add or BOP_sub (direction of the increment);
 *   step   - the increment expression, or NULL for ++/-- (unit step).
 * If the condition is written as `bound OP var`, its operands are swapped
 * in place and its operator is mirrored (e.g. `ub > i` becomes `i < ub`), so
 * the AST keeps its meaning and repeated analyses give the same result.
 * With a != condition, an assignment-style increment must be the constant
 * 1 or -1.
 * NOTE: the three deduced expressions (lb, ub, step) are not freeable; they
 * point into the original AST.
 * 
 * @param s the for statement
 * @param fp the structure to store the deduced loop parameters
 */
void loop_analyze(aststmt s, forparts_t *fp)
{
	aststmt init;
	astexpr cond, incr, tmp;
	int     rel, stepval;
	const char *xtramsg = NULL;

	assert(s != NULL && s->type == ITERATION && s->subtype == SFOR);
	init = s->u.iteration.init;
	cond = s->u.iteration.cond;
	incr = s->u.iteration.incr;
	if (init == NULL || cond == NULL || incr == NULL)
	{
	OMPFOR_ERROR:
		exit_error(1, "(%s, line %d) openmp error:\n\t"
		           "non-conformant FOR statement %s\n", s->file->name, s->l,
		           xtramsg ? xtramsg : "");
	}

	fp->s = s;    /* remember the statement */
	
	/* Get var and lb from the init part of the FOR
	 */
	xtramsg = "(first part of the for)";
	if (init->type == EXPRESSION)     /* assignment: var = lb */
	{
		astexpr e = init->u.expr;

		if (e == NULL || e->type != ASS || e->left->type != IDENT)
			goto OMPFOR_ERROR;
		fp->var = e->left->u.sym;
		fp->lb  = e->right;
	}
	else
		if (init->type == DECLARATION)  /* declaration: type var = lb */
		{
			astdecl d = init->u.declaration.decl;

			/* must check for integral type, too ... */

			if (d->type != DINIT)
				goto OMPFOR_ERROR;
			if (d->decl->type != DECLARATOR)
				goto OMPFOR_ERROR;
			if (d->decl->decl->type != DIDENT)
				goto OMPFOR_ERROR;
			fp->var = d->decl->decl->u.id;
			fp->lb  = d->u.expr;
		}
		else
			goto OMPFOR_ERROR;

	/* Get condition operator and ub from the cond part of the FOR
	 */
	xtramsg = "(condition operator)";
	if (cond->type != BOP) goto OMPFOR_ERROR;
	rel = cond->opid;
	if (rel != BOP_lt && rel != BOP_gt && rel != BOP_leq && rel != BOP_geq && 
	    rel != BOP_neq) /* OpenMP 5.0 */
		goto OMPFOR_ERROR;
	/* OpenMP 3.0 allows swapping the left & right sides */
	if (cond->left->type != IDENT || cond->left->u.sym != fp->var)
	{
		tmp = cond->left;
		cond->left = cond->right;
		cond->right = tmp;
		rel = (rel == BOP_lt) ? BOP_gt : (rel == BOP_leq) ? BOP_geq :
		      (rel == BOP_gt) ? BOP_lt : (rel == BOP_geq) ? BOP_leq : 
		      BOP_neq;   /* stays the same */
		/* Mirror the operator too; otherwise the swapped AST would mean the
		 * opposite (e.g. `ub > i` would turn into `i > ub`). */
		cond->opid = rel;
	}
	if (cond->left->type != IDENT || cond->left->u.sym != fp->var) /* sanity */
		goto OMPFOR_ERROR;
	fp->condop = rel;
	fp->ub = cond->right;

	/* Last part: get step and increment operator from the incr part of the FOR
	 */
	xtramsg = "(increment part)";
	if (incr->type != PREOP && incr->type != POSTOP && incr->type != ASS)
		goto OMPFOR_ERROR;
	if (incr->left->type != IDENT || incr->left->u.sym != fp->var) /* sanity */
		goto OMPFOR_ERROR;
	if (incr->type != ASS)
	{
		/* step is only needed for printing to a string; nothing else. Thus we can
		   leave it as is and "create" it when printing to the string */
		fp->step = NULL;  /* signal special case of pre/postop */
		fp->incrop = (incr->opid == UOP_inc) ? BOP_add : BOP_sub;
	}
	else   /* ASS */
	{
		if (incr->opid != ASS_eq && incr->opid != ASS_add && incr->opid != ASS_sub)
			goto OMPFOR_ERROR;
		if (incr->opid != ASS_eq)
		{
			fp->incrop = (incr->opid == ASS_add) ? BOP_add : BOP_sub;
			fp->step   = incr->right;
		}
		else
		{
			tmp = incr->right;
			/* check the node kind before trusting its operator */
			if (tmp->type != BOP || (tmp->opid != BOP_add && tmp->opid != BOP_sub))
				goto OMPFOR_ERROR;
			fp->incrop = tmp->opid;
			if (fp->incrop == BOP_sub)      /* var = var - incr */
			{
				if (tmp->left->type != IDENT || tmp->left->u.sym != fp->var)
					goto OMPFOR_ERROR;
				fp->step = tmp->right;
			}
			else                         /* var = var + incr / incr + var */
			{
				if (tmp->left->type != IDENT || tmp->left->u.sym != fp->var)
				{
					/* var = incr + var */
					if (tmp->right->type != IDENT || tmp->right->u.sym != fp->var)
						goto OMPFOR_ERROR;
					fp->step = tmp->left;
				}
				else /* var = var + incr */
					fp->step = tmp->right;
			}
		}
		/* OpenMP 5.0: check that step is +/-1 if condop is != */
		if (fp->condop == BOP_neq)
		{
			int err;
			
			xtramsg = "('!=' condition requires unit increment)";
			stepval = xar_calc_int_expr(fp->step, &err);
			if (err || (stepval != 1 && stepval != -1))
				goto OMPFOR_ERROR;
		}
	}
}


/* Build ((a - b + (st - 1)) / st), parenthesized: the number of steps of
 * (positive) size st needed to go from b up to (exclusive) a, when a >= b.
 * All operands are copied.
 */
static astexpr loop_iters_constdiv(astexpr a, astexpr b, astexpr st)
{
	astexpr span = BinaryOperator(BOP_sub, ast_expr_copy(a), ast_expr_copy(b));
	astexpr bias = BinaryOperator(BOP_sub, ast_expr_copy(st), OneExpr());

	return Parenthesis(
	         BinaryOperator(BOP_div,
	                        Parenthesis(BinaryOperator(BOP_add, span, bias)),
	                        ast_expr_copy(st)));
}


/* Build (st CMP 0 ? ((((a - b) COMB st) - 1) / [-]st) : 0), parenthesized;
 * the divisor is negated when negdivisor is nonzero. All operands are copied.
 */
static astexpr loop_iters_guarded(int cmp, int comb, astexpr a, astexpr b,
                                  astexpr st, int negdivisor)
{
	astexpr guard, numer, divisor;

	guard = BinaryOperator(cmp, ast_expr_copy(st), ZeroExpr());
	numer = Parenthesis(
	          BinaryOperator(BOP_sub,
	            BinaryOperator(comb,
	              BinaryOperator(BOP_sub, ast_expr_copy(a), ast_expr_copy(b)),
	              ast_expr_copy(st)),
	            OneExpr()));
	divisor = negdivisor ? UnaryOperator(UOP_neg, ast_expr_copy(st))
	                     : ast_expr_copy(st);

	return Parenthesis(
	         ConditionalExpr(guard,
	                         Parenthesis(BinaryOperator(BOP_div, numer, divisor)),
	                         ZeroExpr()));
}


/**
 * Produce an expression for the number of iterations of an analyzed loop,
 * irrespectively of the type of the index variable. The loop parameters
 * come from fp: l = fp->lb, u = fp->ub (used as an exclusive bound),
 * s = fp->step, and the increment direction fp->incrop (BOP_add/BOP_sub).
 *
 * General case, direction BOP_add:
 *   (u >= l) ? ( s > 0 ? ( (u - l + s - 1) / s ) : 0 ) :
 *              ( s < 0 ? ( (l - u - s - 1) / (-s) ) : 0 )
 * General case, direction BOP_sub:
 *   (u >= l) ? ( s < 0 ? ( (u - l - s - 1) / (-s) ) : 0 ) :
 *              ( s > 0 ? ( (l - u + s - 1) / s ) : 0 )
 *
 * If s is a constant node (CONSTVAL, assumed positive) this is simplified to:
 *   (u >= l) ? ( (u - l + s - 1) / s ) : 0   (for BOP_add)
 *   (u >= l) ? 0 : ( (l - u + s - 1) / s )   (for BOP_sub)
 * and if s is NULL or the literal constant 1, to:
 *   (u >= l) ? (u - l) : 0                   (for BOP_add)
 *   (u >= l) ? 0 : (l - u)                   (for BOP_sub)
 *
 * The parts of fp are only copied, never consumed. The copies of l, u and s
 * are combined as-is (no Parenthesis node is added around them), so callers
 * are expected to pass parenthesized non-trivial expressions, as
 * loopnest_analyze() does.
 *
 * @param fp the analyzed loop parameters
 * @return an expression for the total number of iterations
 */
astexpr loop_iters(forparts_t *fp)
{
	astexpr lo = fp->lb, hi = fp->ub, st = fp->step;
	int     dir = fp->incrop;
	astexpr upwards, downwards;   /* counts when hi >= lo and when hi < lo */

	if (st != NULL && st->type != CONSTVAL)              /* general step */
	{
		upwards   = loop_iters_guarded(dir == BOP_add ? BOP_gt : BOP_lt, dir,
		                               hi, lo, st, dir != BOP_add);
		downwards = loop_iters_guarded(dir == BOP_add ? BOP_lt : BOP_gt,
		                               dir == BOP_add ? BOP_sub : BOP_add,
		                               lo, hi, st, dir == BOP_add);
	}
	else if (st == NULL || strcmp(st->u.str, "1") == 0)  /* unit step */
	{
		upwards = (dir == BOP_add) ?
		  Parenthesis(BinaryOperator(BOP_sub, ast_expr_copy(hi), ast_expr_copy(lo))) :
		  ZeroExpr();
		downwards = (dir == BOP_sub) ?
		  Parenthesis(BinaryOperator(BOP_sub, ast_expr_copy(lo), ast_expr_copy(hi))) :
		  ZeroExpr();
	}
	else                                /* step = positive constant */
	{
		upwards   = (dir == BOP_add) ? loop_iters_constdiv(hi, lo, st) : ZeroExpr();
		downwards = (dir == BOP_sub) ? loop_iters_constdiv(lo, hi, st) : ZeroExpr();
	}

	return ConditionalExpr(
	         Parenthesis(BinaryOperator(BOP_geq, ast_expr_copy(hi), ast_expr_copy(lo))),
	         upwards,
	         downwards);
}


/**
 * @brief Normalize a loop, producing an equivalent one which iterates
 *        over an integer counter with unit steps and a < condition. 
 *
 * The new loop will be as follows: 
 *    for (idx = fiter [, initexpr]; idx < liter; idx++ [, stepexpr]) {
 *      [prebody]
 *      loopbody
 *      [postbody]
 *    }
 * where _idx_ is the new loop index symbol, _fiter_ is an expression
 * giving the first iteration (defaults to 0 if _fiter_ is NULL), _liter_
 * is the last+1 iteration, and _loopbody_ is the original loop body.
 * The entities in square brackets are optional and provide handy hooks:
 * - _prebody_ and _postbody_ statement hooks are provided for wrapping
 *   _loopbody_; _prebody_ normally contains a statement that recovers the
 *   value of the original loop index variable. If both are NULL, 
 *   _loopbody_ will not be enclosed in a compound.
 * - _initexpr_, if not NULL, will be comma-separated from `idx = fiter`
 * - _stepexpr_, if not NULL, will be comma-separated from `idx++`
 * All given expressions/statements are used directly (not copied).
 *
 * @param idx       the new loop index identifier
 * @param fiter     initializer (NULL means 0)
 * @param initexpr  initializer hook (may be NULL)
 * @param liter     upper bound (exclusive)
 * @param stepexpr  step part hook (may be NULL)
 * @param loopbody  the original loop body
 * @param prebody   statement to precede the loop body with (may be NULL)
 * @param postbody  statement to follow the loop body with (may be NULL)
 * @returns the normalized loop AST
 */
aststmt loop_normalize(symbol idx, astexpr fiter, astexpr initexpr, 
                       astexpr liter, astexpr stepexpr, aststmt loopbody, 
                       aststmt prebody, aststmt postbody)
{
	if (fiter == NULL)  /* documented default: start from 0 */
		fiter = ZeroExpr();

	if (initexpr)     /* initialization part hook */
		initexpr = CommaList(Assignment(Identifier(idx), ASS_eq, fiter), initexpr);
	else              /* Simpler */
		initexpr = Assignment(Identifier(idx), ASS_eq, fiter);

	if (stepexpr)     /* step part hook */
		stepexpr = CommaList(PostOperator(Identifier(idx), UOP_inc), stepexpr);
	else              /* Simpler */
		stepexpr = PostOperator(Identifier(idx), UOP_inc);
		
	if (prebody)
		loopbody = Block2(prebody, loopbody);
	if (postbody)
		loopbody = Block2(loopbody, postbody);
	if (prebody || postbody)  /* only wrap when something was added */
		loopbody = Compound(loopbody);

	return 
		For(Expression(initexpr),
		    BinaryOperator(BOP_lt, Identifier(idx), liter),
		    stepexpr,
		    loopbody
		   );
}
