/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_loops.c -- utility functions related to OMP loops */

/*
 * 2021/05/8
 *   Created from x_for.c parts.
 */

#include <string.h>
#include <assert.h>
#include "stddefs.h"
#include "ompi.h"
#include "ast_free.h"
#include "ast_copy.h"
#include "x_loops.h"
#include "ast_arith.h"


/**
 * Get the FOR loop nest indices; their number is given by the number in the
 * ordered clause. This is a slimmed version of analyze_omp_for() applied 
 * to the whole loop nest: only the index variable of each loop is taken
 * (from the init part of the FOR), and the loops must be perfectly nested
 * (the body of each loop is the next FOR, possibly enclosed in { }).
 * A non-conformant nest terminates compilation via exit_error().
 * 
 * @param s the body of the #for construct (i.e. the outermost FOR statement)
 * @param orderednum the number in the ordered() clause (i.e. the nest depth);
 *                   must be at least 1
 * @return an array of orderednum symbols, outermost loop first (freeable)
 */
symbol *ompfor_get_indices(aststmt s, int orderednum)
{
	aststmt init, outer = s;   /* outermost FOR; used for error reporting */
	int     lid = 0;
	symbol  *vars;

	/* The do-while below always stores vars[0]; a non-positive depth would
	 * write past the end of the allocation. */
	assert(orderednum > 0);
	vars = smalloc(orderednum*sizeof(symbol));
	
	do 
	{
		assert(s != NULL && s->type == ITERATION && s->subtype == SFOR);
		init = s->u.iteration.init;
		if (init == NULL)
		{
		OMPFOR_ERROR:
			exit_error(1, "(%s, line %d) openmp error:\n\t"
				"non-conformant FOR statement\n", s->file->name, s->l);
		}

		/* Get var from the init part of the FOR */
		if (init->type == EXPRESSION)     /* assignment: var = lb */
		{
			astexpr e = init->u.expr;
			if (e == NULL || e->type != ASS || e->left->type != IDENT)
				goto OMPFOR_ERROR;
			vars[lid] = e->left->u.sym;
		}
		else
			if (init->type == DECLARATION)  /* declaration: type var = lb */
			{
				astdecl d = init->u.declaration.decl;

				if (d->type != DINIT)
					goto OMPFOR_ERROR;
				if (d->decl->type != DECLARATOR)
					goto OMPFOR_ERROR;
				if (d->decl->decl->type != DIDENT)
					goto OMPFOR_ERROR;
				vars[lid] = d->decl->decl->u.id;
			}
			else
				goto OMPFOR_ERROR;

		/* Descend to the next loop of the nest, unless this was the last one */
		if (lid < orderednum - 1)
		{
			s = s->body;
			if (s != NULL && s->type == COMPOUND && s->body != NULL &&
			    s->body->type == ITERATION && s->body->subtype == SFOR)
				s = s->body;  /* { For } -> For */
			/* Report at the outermost FOR. It is an ITERATION statement, not an
			 * OpenMP construct, so its own file/line must be used; reading its
			 * u.omp member would reinterpret the u.iteration union data. */
			if (s == NULL || s->type != ITERATION || s->subtype != SFOR)
				exit_error(1, "(%s, line %d) openmp error:\n\t"
				      "an ordered(%d) clause requires %d perfectly nested FOR loops.\n",
				      outer->file->name, outer->l,
				      orderednum, orderednum);
		}
	}
	while ((++lid) < orderednum);
	return vars;
}


/**
 * Check that a single FOR statement is in OpenMP canonical loop form and
 * extract its loop specification. For a loop nest it has to be called once
 * per loop. Any violation terminates compilation via exit_error(), with a
 * hint about which part of the FOR was at fault.
 *
 * NOTE: the deduced expressions (*lb, *b, *step) are pointers into the AST
 * of s itself, not copies, so they must not be freed by the caller.
 * NOTE: if the test is written as "b relop var", the two operands of the
 * condition are swapped in place (and the operator mirrored) so that the
 * index variable ends up on the left-hand side.
 *
 * @param s      the for statement
 * @param var    receives the iteration index symbol
 * @param lb     receives the lower bound (first iteration) expression
 * @param b      receives the loop bound (right side of the normalized test)
 * @param step   receives the step expression, or NULL when the increment
 *               is a ++/-- operator
 * @param condop receives the condition operator of the normalized test
 *               "var relop b" (BOP_lt, BOP_gt, BOP_leq, BOP_geq or BOP_neq)
 * @param incrop receives the step increment operator (BOP_add or BOP_sub)
 */
void analyze_omp_for(aststmt s,
                     symbol *var, astexpr *lb, astexpr *b, astexpr *step,
                     int *condop, int *incrop)
{
	aststmt init;
	astexpr test, update, other;
	int     relop, unitval, err;
	char    *why = NULL;      /* extra detail appended to the error message */

	assert(s != NULL && s->type == ITERATION && s->subtype == SFOR);
	init   = s->u.iteration.init;
	test   = s->u.iteration.cond;
	update = s->u.iteration.incr;
	if (init == NULL || test == NULL || update == NULL)
		goto nonconformant;

	/* Part 1: "var = lb" or "type var = lb" gives var and lb */
	why = "(first part of the for)";
	switch (init->type)
	{
		case EXPRESSION:
		{
			astexpr e = init->u.expr;

			if (e == NULL || e->type != ASS || e->left->type != IDENT)
				goto nonconformant;
			*var = e->left->u.sym;
			*lb  = e->right;
			break;
		}
		case DECLARATION:
		{
			astdecl d = init->u.declaration.decl;

			/* must check for integral type, too ... */
			if (d->type != DINIT || d->decl->type != DECLARATOR ||
			    d->decl->decl->type != DIDENT)
				goto nonconformant;
			*var = d->decl->decl->u.id;
			*lb  = d->u.expr;
			break;
		}
		default:
			goto nonconformant;
	}

	/* Part 2: "var relop b" gives the condition operator and b */
	why = "(condition operator)";
	if (test->type != BOP)
		goto nonconformant;
	relop = test->opid;
	switch (relop)
	{
		case BOP_lt:
		case BOP_gt:
		case BOP_leq:
		case BOP_geq:
		case BOP_neq:             /* OpenMP 5.0 */
			break;
		default:
			goto nonconformant;
	}
	/* OpenMP 3.0 also allows "b relop var": swap the sides in place and
	 * mirror the operator ('!=' is symmetric) */
	if (test->left->type != IDENT || test->left->u.sym != *var)
	{
		other       = test->left;
		test->left  = test->right;
		test->right = other;
		switch (relop)
		{
			case BOP_lt:  relop = BOP_gt;  break;
			case BOP_leq: relop = BOP_geq; break;
			case BOP_gt:  relop = BOP_lt;  break;
			case BOP_geq: relop = BOP_leq; break;
			default:      relop = BOP_neq; break;
		}
	}
	if (test->left->type != IDENT || test->left->u.sym != *var)  /* neither side */
		goto nonconformant;
	*condop = relop;
	*b      = test->right;

	/* Part 3: the increment gives the step and the increment operator */
	why = "(increment part)";
	if (update->type != PREOP && update->type != POSTOP && update->type != ASS)
		goto nonconformant;
	if (update->left->type != IDENT || update->left->u.sym != *var)
		goto nonconformant;

	if (update->type == PREOP || update->type == POSTOP)
	{
		/* The step is only needed for printing to a string, where it can be
		 * created on the fly; NULL flags this ++/-- special case. */
		*step   = NULL;
		*incrop = (update->opid == UOP_inc) ? BOP_add : BOP_sub;
		return;
	}

	switch (update->opid)
	{
		case ASS_add:             /* var += incr */
		case ASS_sub:             /* var -= incr */
			*incrop = (update->opid == ASS_add) ? BOP_add : BOP_sub;
			*step   = update->right;
			break;
		case ASS_eq:              /* var = var +/- incr, or var = incr + var */
			other   = update->right;
			*incrop = other->opid;
			if (other->type != BOP || (*incrop != BOP_add && *incrop != BOP_sub))
				goto nonconformant;
			if (other->left->type == IDENT && other->left->u.sym == *var)
				*step = other->right;
			else if (*incrop == BOP_add &&
			         other->right->type == IDENT && other->right->u.sym == *var)
				*step = other->left;
			else
				goto nonconformant;
			break;
		default:
			goto nonconformant;
	}

	/* OpenMP 5.0: a '!=' test is only allowed with a step of +1 or -1 */
	if (*condop == BOP_neq)
	{
		why = "('!=' condition requires unit increment)";
		unitval = xar_calc_int_expr(*step, &err);
		if (err || (unitval != 1 && unitval != -1))
			goto nonconformant;
	}
	return;

nonconformant:
	exit_error(1, "(%s, line %d) openmp error:\n\t"
	           "non-conformant FOR statement %s\n", s->file->name, s->l,
	           why ? why : "");
}


/* Returns a parenthesized copy of e, so that e can be spliced as an operand
 * into the larger generated expressions of specs2iters() without its own
 * operators regrouping with the surrounding ones. Without this, a bound
 * like "n - 1" used in "u - l" could be grouped as "u - n - 1", a step
 * "a + b" used as a divisor as "x / a + b", or a step "-1" negated as "--1".
 * (Assumes, as the explicit Parenthesis() calls in specs2iters() suggest,
 * that expressions are printed without automatic parenthesization.) */
static astexpr specs2iters_operand(astexpr e)
{
	return Parenthesis(ast_expr_copy(e));
}


/**
 * Given the loop specifications (l, u, s), this produces a correct
 * (and possibly simplified) expression for the number of iterations,
 * irrespectively of the type of the index variable. u is used as an
 * exclusive limit (e.g. for a step of 1 the count is u - l).
 * There are two cases, depending on the direction of the increment.
 * If stepdir is BOP_add, then the # iterations is given by:
 *   (u >= l) ? ( s > 0 ? ( (u - l + s - 1) / s ) : 0 ) :
 *              ( s < 0 ? ( (l - u - s - 1) / (-s) ) : 0 )
 * If stepdir is BOP_sub, then the # iterations is given by:
 *   (u >= l) ? ( s < 0 ? ( (u - l - s - 1) / (-s) ) : 0 ) :
 *              ( s > 0 ? ( (l - u + s - 1) / s ) : 0 )
 * 
 * In the usual case of a positive constant s, the above is simplified as:
 *   (u >= l) ? ( (u - l + s - 1) / s ) : 0   (for BOP_add)
 *   (u >= l) ? 0 : ( (l - u + s - 1) / s )   (for BOP_sub)
 * and if s==1,
 *   (u >= l) ? (u - l) : 0       (for BOP_add)
 *   (u >= l) ? 0 : (l - u)       (for BOP_sub)
 * 
 * Every occurrence of l, u and s in the result is a fresh, parenthesized
 * copy; the given expressions are not consumed.
 * 
 * @param l the lower bound of the loop
 * @param u the upper bound of the loop
 * @param s the step increment
 * @param stepdir the direction of the increment (BOP_add / BOP_sub)
 * @param plainstep is 0 if the step increment is a full expression, 
 *                  1 if it is equal to 1, or any other non-zero value
 *                  (normally 2) if it is a positive constant.
 * @return an expression for the total number of iterations
 */
astexpr specs2iters(astexpr l, astexpr u, astexpr s, int stepdir, int plainstep)
{
	if (plainstep)
	{
		if (plainstep == 1)       /* step = 1 */
			return
				ConditionalExpr(
					Parenthesis(BinaryOperator(BOP_geq,
						specs2iters_operand(u), specs2iters_operand(l))),
					(stepdir == BOP_add ? 
						Parenthesis(BinaryOperator(BOP_sub,
							specs2iters_operand(u), specs2iters_operand(l))) 
						: ZeroExpr()),
					(stepdir == BOP_sub ? 
						Parenthesis(BinaryOperator(BOP_sub,
							specs2iters_operand(l), specs2iters_operand(u)))
						: ZeroExpr())
				);
		else                      /* step = positive constant */
			return
				ConditionalExpr(
					Parenthesis(BinaryOperator(BOP_geq,
						specs2iters_operand(u), specs2iters_operand(l))),
					(stepdir == BOP_add ? 
						Parenthesis(
							BinaryOperator(BOP_div,
								Parenthesis(
									BinaryOperator(BOP_add,
										BinaryOperator(BOP_sub,
											specs2iters_operand(u), specs2iters_operand(l)),
										BinaryOperator(BOP_sub, specs2iters_operand(s), OneExpr())
									)
								),
								specs2iters_operand(s)
							)
						) 
						: ZeroExpr()),
					(stepdir == BOP_sub ? 
						Parenthesis(
							BinaryOperator(BOP_div,
								Parenthesis(
									BinaryOperator(BOP_add,
										BinaryOperator(BOP_sub,
											specs2iters_operand(l), specs2iters_operand(u)),
										BinaryOperator(BOP_sub, specs2iters_operand(s), OneExpr())
									)
								),
								specs2iters_operand(s)
							)
						) 
						: ZeroExpr())
				);
	}
	
	/* General case: the sign of s is only known at run time */
	return
		ConditionalExpr(
			Parenthesis(BinaryOperator(BOP_geq,
				specs2iters_operand(u), specs2iters_operand(l))),
			/* u >= l: the loop advances only if the index moves upwards */
			Parenthesis(
				ConditionalExpr(
					BinaryOperator(stepdir == BOP_add ? BOP_gt : BOP_lt, 
						specs2iters_operand(s), 
						ZeroExpr()
					),
					Parenthesis(
						BinaryOperator(BOP_div,
							Parenthesis(
								BinaryOperator(BOP_sub,
									BinaryOperator(stepdir,
										BinaryOperator(BOP_sub,
											specs2iters_operand(u), specs2iters_operand(l)),
										specs2iters_operand(s)
									),
									OneExpr()
								)
							),
							stepdir == BOP_add ? 
								specs2iters_operand(s) :
								UnaryOperator(UOP_neg, specs2iters_operand(s))
						)
					),
					ZeroExpr()
				)
			),
			/* u < l: the loop advances only if the index moves downwards */
			Parenthesis(
				ConditionalExpr(
					BinaryOperator(stepdir == BOP_add ? BOP_lt : BOP_gt, 
						specs2iters_operand(s), 
						ZeroExpr()
					),
					Parenthesis(
						BinaryOperator(BOP_div,
							Parenthesis(
								BinaryOperator(BOP_sub,
									BinaryOperator(stepdir == BOP_add ? BOP_sub : BOP_add,
										BinaryOperator(BOP_sub,
											specs2iters_operand(l), specs2iters_operand(u)),
										specs2iters_operand(s)
									),
									OneExpr()
								)
							),
							stepdir == BOP_add ? 
								UnaryOperator(UOP_neg, specs2iters_operand(s)) :
								specs2iters_operand(s)
						)
					),
					ZeroExpr()
				)
			)
		);
}

