/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_for_vk.c */

#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "stddefs.h"
#include "x_for.h"
#include "x_clauses.h"
#include "x_reduction.h"
#include "ast_xform.h"
#include "ast_free.h"
#include "ast_copy.h"
#include "ast_print.h"
#include "ast_assorted.h"
#include "ast_arith.h"
#include "ast_types.h"
#include "str.h"
#include "ompi.h"
#include "ox_xform.h"

/* Wrap expression t as "<name>(t)", i.e. a call-shaped expression whose
 * callee is the identifier SPEC_symbols[spec_subtype]; used to convert a
 * recovered loop index back to the loop variable's specifier type.
 */
static astexpr Castorigtype(int spec_subtype, astexpr t)
{
	astexpr castfn = IdentName(SPEC_symbols[spec_subtype]);

	return FunctionCall(castfn, t);
}

/**
 * Produce standard declarations in all split-iterations schedules.
 * Always non-NULL.
 */
aststmt for_iterdecls_vk(fordata_t *loopinfo)
/**
 * Build the declarations shared by all split-iteration schedules:
 *
 *   int niters_ = 0, iter_ = 0, fiter_, liter_ = 0;
 *   int pp_ = 1;                        (only when collapsing > 1 loops)
 *   int <itersym[k]> = <#iterations>;   (per loop of a collapse/doacross nest)
 *   long <DOACCPARAMS>[][3] = { ... };  (only for doacross loops)
 *
 * (The niters_/iter_/fiter_/liter_ names are those returned by
 * loopinfo->varname.)  niters_, iter_ and liter_ are zero-initialized because
 * if a thread gets no iterations, the lastprivate check iter == niters could
 * otherwise succeed.
 *
 * @param loopinfo the analyzed loop (nest) information
 * @return the declaration statement(s); never NULL
 */
{
	aststmt decls;
	astexpr rows = NULL;
	int     k, depth;

	/* int niters_ = 0, iter_ = 0, fiter_, liter_ = 0; */
	decls = Declaration(
	          Usertype(Symbol("int")),
	          DeclList(
	            DeclList(
	              DeclList(
	                InitDecl(Declarator(NULL,
	                           IdentifierDecl(Symbol(loopinfo->varname(LOOP_NITERS)))),
	                         ZeroExpr()),
	                InitDecl(Declarator(NULL,
	                           IdentifierDecl(Symbol(loopinfo->varname(LOOP_ITER)))),
	                         ZeroExpr())
	              ),
	              Declarator(NULL,
	                IdentifierDecl(Symbol(loopinfo->varname(LOOP_FITER))))
	            ),
	            InitDecl(Declarator(NULL,
	                       IdentifierDecl(Symbol(loopinfo->varname(LOOP_LITER)))),
	                     ZeroExpr())
	          )
	        );

	/* Collapsed and doacross nests need per-loop iteration counts */
	if (loopinfo->collapsenum > 1 || loopinfo->doacrossnum > 0)
	{
		if (loopinfo->collapsenum > 1)          /* int pp_ = 1; */
			decls = BlockList(decls,
			          Declaration(Usertype(Symbol("int")),
			            InitDecl(Declarator(NULL, IdentifierDecl(Symbol("pp_"))),
			                     OneExpr())));

		/* The nest depth is the larger of the collapse and doacross depths */
		depth = loopinfo->collapsenum;
		if (loopinfo->doacrossnum > depth)
			depth = loopinfo->doacrossnum;

		for (k = 0; k < depth; k++)             /* int <itersym[k]> = <#iters>; */
			decls = BlockList(decls,
			          Declaration(Usertype(Symbol("int")),
			            InitDecl(Declarator(NULL, IdentifierDecl(loopinfo->itersym[k])),
			                     loop_iters(&loopinfo->forps[k]))));
	}

	if (loopinfo->doacrossnum <= 0)
		return decls;

	/* One initializer row per doacross loop, built from lb, step, incrop and
	 * the loop's iteration-count variable.
	 */
	for (k = 0; k < loopinfo->doacrossnum; k++)
	{
		astexpr row = LongArray3Initer(loopinfo->forps[k].lb,
		                               loopinfo->forps[k].step,
		                               loopinfo->forps[k].incrop,
		                               Identifier(loopinfo->itersym[k]));
		rows = (k == 0) ? row : CommaList(rows, row);
	}

	/* long <DOACCPARAMS>[][3] = { rows }; */
	return BlockList(decls,
	         Declaration(
	           Declspec(SPEC_long),
	           InitDecl(
	             Declarator(NULL,
	               ArrayDecl(ArrayDecl(IdentifierDecl(Symbol(DOACCPARAMS)), NULL, NULL),
	                         NULL,
	                         Constant("3"))),
	             BracedInitializer(rows))));
}


/**
 * @brief Produce the main, normalized loop body
 * 
 * A single loop becomes:
 *   for (iter = fiter; iter < liter; iter++) {
 *     <var> = lb +/- iter*step
 *     <body>
 *   }
 * optimized as:
 *   for (iter = fiter, var = ...; iter < liter; iter++, var +/-= step) {
 *     <body>
 *   }
 * If there is an ordered clause, we insert "_ort_for_curriter(iter_)"
 * just before the body, to let the runtime know our current iteration.
 *
 * For a collapsed loop nest, the non-optimized version is output
 * and multiple <var>s are recovered.
 */
static aststmt for_std_mainpart_vk(fordata_t *loopinfo, aststmt origbody)
/**
 * @brief Produce the main, normalized loop body
 * 
 * A single loop becomes:
 *   for (iter = fiter; iter < liter; iter++) {
 *     <var> = lb +/- iter*step
 *     <body>
 *   }
 * optimized as:
 *   for (iter = fiter, var = ...; iter < liter; iter++, var +/-= step) {
 *     <body>
 *   }
 * If there is an ordered clause, we insert "_ort_for_curriter(iter_)"
 * just before the body, to let the runtime know our current iteration.
 *
 * For a collapsed loop nest, the non-optimized version is output
 * and multiple <var>s are recovered, innermost first:
 *   pp_ = 1;
 *   var_k = <type_k>(lb_k +/- step_k * ((iter / pp_) % iters_k)); pp_ *= iters_k;
 * Each recovered index is cast to the declared type of its own loop variable.
 *
 * @param loopinfo the analyzed loop (nest) information
 * @param origbody the original body of the innermost loop
 * @return the normalized loop statement
 */
{
	int i;
	aststmt idx = NULL;                    /* needed only for loop nest */
	symbol var = loopinfo->forps[0].var;   /* needed only in 1 loop */
	stentry e = symtab_get(stab, var, IDNAME);

	if (loopinfo->collapsenum > 1)         /* Recover all indices */
	{
		idx = AssignStmt(IdentName("pp_"), OneExpr());
		for (i = loopinfo->collapsenum - 1; i >= 0; i--)
		{
			/* Cast to the type of *this* loop's variable (loops in a nest may use
			 * different types); fall back to the outermost one if not found.
			 */
			stentry ei = symtab_get(stab, loopinfo->forps[i].var, IDNAME);

			if (ei == NULL)
				ei = e;
			idx = BlockList(
			        idx,
			        AssignStmt(
			          Identifier(loopinfo->forps[i].var),
			          Castorigtype(ei->spec->subtype, BinaryOperator(
			            loopinfo->forps[i].incrop, //BOP_add,
			            ast_expr_copy(loopinfo->forps[i].lb),
			            BinaryOperator(
			              BOP_mul,
			              ast_expr_copy(loopinfo->forps[i].step),
			              Parenthesis(
			                BinaryOperator(
			                  BOP_mod,
			                  Parenthesis(
			                    BinaryOperator(
			                      BOP_div,
			                      IdentName(loopinfo->varname(LOOP_ITER)),
			                      IdentName("pp_")
			                    )
			                  ),
			                  Identifier(loopinfo->itersym[i])
			                )
			              )
			            )
			          ))
			        )
			      );
			if (i != 0)   /* pp_ *= iters of this loop, for the next outer one */
				idx = BlockList(
				        idx,
				        Expression(Assignment(IdentName("pp_"), ASS_mul,
				                              Identifier(loopinfo->itersym[i]))
				                  )
				      );
		}
	}
	
#define ORTCURRITER Expression(FunctionCall(IdentName("_ort_for_curriter"), \
                               IdentName(loopinfo->varname(LOOP_ITER))))
	if (loopinfo->collapsenum > 1) 
		return
			loop_normalize(Symbol(loopinfo->varname(LOOP_ITER)), 
			          IdentName(loopinfo->varname(LOOP_FITER)), NULL, 
			          IdentName(loopinfo->varname(LOOP_LITER)), NULL, origbody,
			          (loopinfo->ordplain ? BlockList(idx, ORTCURRITER) : idx), NULL);
	else    /* Optimize original loop index recovery */
		return
			loop_normalize(Symbol(loopinfo->varname(LOOP_ITER)), 
			                 IdentName(loopinfo->varname(LOOP_FITER)), 
			                 Assignment(Identifier(var),
			                            ASS_eq,
			                            Castorigtype(e->spec->subtype, BinaryOperator(
			                              loopinfo->forps[0].incrop, 
			                              ast_expr_copy(loopinfo->forps[0].lb),
			                              BinaryOperator(BOP_mul,
			                                IdentName(loopinfo->varname(LOOP_FITER)),
			                                ast_expr_copy(loopinfo->forps[0].step))
			                              ))
			                 ),
			                 IdentName(loopinfo->varname(LOOP_LITER)), 
			                 Assignment(Identifier(var), 
			                            bop2assop(loopinfo->forps[0].incrop),
			                            ast_expr_copy(loopinfo->forps[0].step)),
			                 origbody, (loopinfo->ordplain ? ORTCURRITER:NULL), NULL);
#undef ORTCURRITER
}


void for_schedule_static_vk(fordata_t *loopinfo, foresult_t *code)
/**
 * Static schedule without a chunk size.
 *
 * Appends the common iteration declarations to code->decls and guards the
 * normalized loop with
 *   if (_vulkan_get_static_default_chunk(niters_, fiter_, liter_) != 0) { ... }
 * (variable names come from loopinfo->varname). The runtime helper presumably
 * assigns this invocation's [fiter_, liter_) range and returns 0 when it
 * gets no iterations; confirm against the Vulkan runtime.
 *
 * @param loopinfo the analyzed loop (nest) information
 * @param code     code parts; decls and mainpart are updated in place
 */
{
	astexpr hasiters;
	aststmt loopnest;

	code->decls = Block2(code->decls, for_iterdecls_vk(loopinfo));

	hasiters = parse_expression_string(
	             "_vulkan_get_static_default_chunk(%s, %s, %s) != 0",
	             loopinfo->varname(LOOP_NITERS),
	             loopinfo->varname(LOOP_FITER),
	             loopinfo->varname(LOOP_LITER));
	loopnest = loopinfo->mainpart_func(loopinfo, code->mainpart);

	code->mainpart = If(hasiters, Compound(loopnest), NULL);
}


void for_schedule_static_chunksize_vk(fordata_t *loopinfo,foresult_t *code)
/**
 * Static schedule with an explicit chunk size (chunks dealt round-robin).
 *
 * Appends to code->decls the common iteration declarations, an
 * "int CHUNKSIZE = <expr>;" when the chunk is not a constant, and
 * "int chid_, TN_ = omp_get_num_threads();". The normalized loop becomes:
 *
 *   for (chid_ = omp_get_thread_num(); ; chid_ += TN_) {
 *     fiter_ = chid_*(chunk);
 *     if (fiter_ >= niters_) break;
 *     liter_ = fiter_ + (chunk);
 *     if (liter_ > niters_) liter_ = niters_;
 *     <normalized loop>
 *   }
 *
 * (fiter_/liter_/niters_ names come from loopinfo->varname.) Within this
 * file it is only called when loopinfo->schedchunk is non-NULL.
 *
 * @param loopinfo the analyzed loop (nest) information
 * @param code     code parts; decls and mainpart are updated in place
 */
{
	astexpr chunkexpr = loopinfo->schedchunk;
	aststmt newdecls = for_iterdecls_vk(loopinfo), loopnest;
	char    *chunk;

	if (chunkexpr == NULL || chunkexpr->type != CONSTVAL)
	{
		/* Non-constant chunk: evaluate it once into a dedicated variable */
		chunk = CHUNKSIZE;
		newdecls = BlockList(newdecls,
		             Declaration(Usertype(Symbol("int")),
		               InitDecl(Declarator(NULL, IdentifierDecl(Symbol(chunk))),
		                        ast_expr_copy(chunkexpr))));
	}
	else
		chunk = chunkexpr->u.str;     /* constant: use its text directly */

	/* int chid_, TN_ = omp_get_num_threads(); */
	newdecls = BlockList(newdecls,
	             Declaration(Usertype(Symbol("int")),
	               DeclList(
	                 Declarator(NULL, IdentifierDecl(Symbol("chid_"))),
	                 InitDecl(Declarator(NULL, IdentifierDecl(Symbol("TN_"))),
	                          Call0_expr("omp_get_num_threads")))));
	code->decls = Block2(code->decls, newdecls);

	/* Loop over this thread's chunks */
	loopnest = loopinfo->mainpart_func(loopinfo, code->mainpart);
	code->mainpart =
		For(parse_blocklist_string("chid_ = omp_get_thread_num();"),
		    NULL,
		    parse_expression_string("chid_ += TN_"),
		    Compound(
		      BlockList(
		        parse_blocklist_string(
		          "%s = chid_*(%s);" "if (%s >= %s) break;"
		          "%s = %s + (%s);" "if (%s > %s) %s = %s;",
		          loopinfo->varname(LOOP_FITER), chunk,
		          loopinfo->varname(LOOP_FITER), loopinfo->varname(LOOP_NITERS),
		          loopinfo->varname(LOOP_LITER), loopinfo->varname(LOOP_FITER), chunk,
		          loopinfo->varname(LOOP_LITER), loopinfo->varname(LOOP_NITERS),
		          loopinfo->varname(LOOP_LITER), loopinfo->varname(LOOP_NITERS)),
		        loopnest)));
}


/**
 * Transform an OpenMP worksharing "for" construct for the Vulkan target.
 *
 * Possible clauses:
 * private, firstprivate, lastprivate, reduction, nowait, ordered, schedule,
 * collapse.
 *
 * The construct at *t is replaced by a compound statement holding, in order:
 * declarations, initializations, the prologue (niters_ computation plus the
 * _ort_entering_for / _ort_entering_doacross call), the schedule-specific main
 * part and the epilogue (_ort_leaving_for and lastprivate assignments).
 * Only static schedules are supported; "auto" is implemented as static and
 * any other schedule kind is reported as an error.
 *
 * @param t the OpenMP statement to transform; replaced in place
 */
void xform_for_vulkan(aststmt *t)
{
	xform_ompcon_body((*t)->u.omp);
	aststmt   s = (*t)->u.omp->body, parent = (*t)->parent, 
	          lasts = NULL, stmp, embdcls = NULL, arrsecxvars = NULL;
	forparts_t forps[MAXLOOPS];
	astexpr   expr, elems;
	symbol    itersym[MAXLOOPS];
	int       schedtype = OC_static /* default */, modifer = OCM_none,
	          static_chunk = 0, i = 0, collapsenum = 1, doacrossnum = 0, nestnum;
	bool      ispfor = ((*t)->u.omp->type == DCFOR_P);
	bool      haslast = 0, hasboth = 0, hasred = 0; /* set by xc_stored_vars_declarations() */
	astexpr   schedchunk = NULL;    /* the chunksize expression */
	char      *chsize = NULL,       /* the chunksize value or variable */
	          iterstr[128];
	ompclause nw  = xc_ompcon_get_clause((*t)->u.omp, OCNOWAIT),
	          sch = xc_ompcon_get_clause((*t)->u.omp, OCSCHEDULE),
	          ord = xc_ompcon_get_clause((*t)->u.omp, OCORDERED),
	          ordnum = xc_ompcon_get_clause((*t)->u.omp, OCORDEREDNUM),
	          col = xc_ompcon_get_clause((*t)->u.omp, OCCOLLAPSE);
	symtab    dvars;
	fordata_t info = { 0 };
	foresult_t code = { NULL };

	/*
	 * Preparations
	 */
	if (sch)
	{
		/* "auto" leaves the choice to the implementation: implement it as static */
		schedtype  = (sch->subtype == OC_auto) ? OC_static : sch->subtype;
		schedchunk = sch->u.expr;
		if (schedtype == OC_static && sch->subtype != OC_auto && schedchunk)
			static_chunk = 1;
		if (schedtype == OC_affinity && schedchunk)
			schedchunk = ast_expr_copy(schedchunk);
		/* Optimize: if schedchunk is a constant, don't use a variable for it;
		 * otherwise refer to the CHUNKSIZE variable, which is declared by
		 * for_schedule_static_chunksize_vk().
		 */
		if (schedchunk)
			chsize = (schedchunk->type == CONSTVAL) ?
			         strdup(schedchunk->u.str) :    /* memory leak */
			         CHUNKSIZE;
		modifer = HasModifier(sch, OCM_nonmonotonic) ? OCM_nonmonotonic :
		          HasModifier(sch, OCM_monotonic) ? OCM_monotonic : OCM_none;
	}

	if (ord && modifer == OCM_nonmonotonic)
		exit_error(1, "(%s, line %d) openmp error:\n\t"
		     "nonmonotonic schedules are not allowed along with ordered clauses.\n",
		     (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l);
	
	if (ord && ordnum)
		exit_error(1, "(%s, line %d) openmp error:\n\t"
		     "plain ordered clauses are not allowed in doacross loops.\n",
		     (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l);

	if (col)
	{
		if ((collapsenum = col->subtype) >= MAXLOOPS)
			exit_error(1, "(%s, line %d) ompi error:\n\t"
				"cannot collapse more than %d FOR loops.\n",
				(*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l,MAXLOOPS);
	}

	if (ordnum)
	{
		if ((doacrossnum = ordnum->subtype) >= MAXLOOPS)
			exit_error(1, "(%s, line %d) ompi error:\n\t"
				"doacross loop nests should have up to %d FOR loops.\n",
				(*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l,MAXLOOPS);
		if (doacrossnum < collapsenum)
			exit_error(1, "(%s, line %d) ompi error:\n\t"
		             "doacross loop collapse number cannot be larger "
		             "than its ordered number.\n",
		             (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l);
	}
	
	/* Collect all data clause vars - we need to check if any vars
	 * are both firstprivate and lastprivate
	 */
	dvars = xc_validate_store_dataclause_vars((*t)->u.omp->directive);

	/* Analyze the loop(s) */
	nestnum = (doacrossnum > collapsenum) ? doacrossnum : collapsenum;
	loopnest_analyze(s, nestnum, collapsenum, forps, *t, dvars, &embdcls);
	
	/* Prepare the loop info (haslast is filled in once it is known, below) */
	info.ordplain = (ord != NULL);
	info.collapsenum = collapsenum;
	info.doacrossnum = doacrossnum;
	info.schedtype = schedtype;
	info.schedchunk = schedchunk;
	info.forps = forps;
	info.itersym = itersym;
	info.mainpart_func = for_std_mainpart_vk;
	info.varname = for_varnames;
	info.monotonic = modifer == OCM_monotonic || ord ||
	                 (modifer == OCM_none && schedtype == OC_static);

	/* Remember the last loop and var; form normalized iteration variables */
	s = forps[collapsenum-1].s;
	for (i = 0; i < nestnum; i++)
	{
		int len = snprintf(iterstr, sizeof(iterstr), "iters_%s_", 
		                   forps[i].var->name);

		if (len < 0 || len >= (int) sizeof(iterstr))
			exit_error(1, "(%s, line %d) ompi error:\n\t"
			     "loop variable name '%s' is too long.\n",
			     (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l,
			     forps[i].var->name);
		itersym[i] = Symbol(iterstr); /* Remember the normalized iteration index */
	}
	
	/*
	 * Declarations and initializations
	 */
	
	/* get possibly new variables for array section parameters */
	arrsecxvars = red_arrayexpr_simplify((*t)->u.omp->directive);

	/* declarations from the collected vars (not the clauses!) */
	code.decls = verbit("/* declarations (if any) */");
	stmp = xc_stored_vars_declarations(&haslast, &hasboth, &hasred);
	info.haslast = haslast;   /* only valid after the call above */
	if (stmp)
		code.decls = Block2(code.decls, stmp);
	if (arrsecxvars)
		code.decls = Block2(arrsecxvars, code.decls);
	if (embdcls)
		code.decls = BlockList(code.decls, embdcls);

	/* initialization statements for firstprivate non-scalar vars */
	code.inits = verbit("/* initializations (if any) */");
	if ((stmp = xc_ompdir_fiparray_initializers((*t)->u.omp->directive)) != NULL)
		code.inits = Block2(code.inits, stmp);
	
	/* assignments for lastprivate vars */
	if (haslast)
		lasts = xc_ompdir_lastprivate_assignments((*t)->u.omp->directive);
	
	/*
	 * Prologue
	 */
	
	/* Append our new code: niters_ = ...; _ort_entering_for(...); */
	if (collapsenum == 1 && doacrossnum == 0)
		elems = Parenthesis(loop_iters(&forps[0]));
	else
		for (elems = Identifier(itersym[0]), i = 1; i < collapsenum; i++)
			elems = BinaryOperator(BOP_mul, elems, Identifier(itersym[i]));
	expr = elems;

	if (ordnum)               /* Need more info for doacross loops */
		stmp = Expression(      /* _ort_entering_doacross(nw,doacnum,collnum,...); */
	           FunctionCall(
	             IdentName("_ort_entering_doacross"),
	             Comma6(
	               numConstant(nw ? 1 : 0),
	               numConstant(doacrossnum),
	               numConstant(collapsenum),
	               numConstant(FOR_CLAUSE2SCHED(schedtype, static_chunk)),
	               schedchunk ? IdentName(chsize) : numConstant(-1),
	               IdentName(DOACCPARAMS)
	             )
	           )
	         );
	else
		stmp = Expression(      /* _ort_entering_for(nw,ord); */
	           FunctionCall(
	             IdentName("_ort_entering_for"),
	             Comma2(numConstant(nw ? 1 : 0), numConstant(ord ? 1 : 0))
	           )
	         );

	stmp = BlockList(
	         Expression(     /* niters_ = ... */
	           Assignment(IdentName(info.varname(LOOP_NITERS)), ASS_eq, expr)
	         ),
	         stmp
	       );
	if (hasboth)   /* a var is both fip & lap; this needs a barrier here :-( */
		stmp = BlockList(stmp, BarrierCall(OMPIBAR_OMPIADDED));
	
	code.prologue = stmp;    /* Guaranteed to be non-NULL */

	/*
	 * Main part
	 */
	
	/* Just leave the original body and let the schedules utilize it */
	code.mainpart = s->body;
	
	/*
	 * Epilogue
	 */
	
	/* Add a label that is used when canceling */
	code.epilogue = Expression(NULL);
	if (!ispfor || ord || ordnum)   /* Still need it if ordered clause exists */
		code.epilogue = BlockList(code.epilogue, Call0_stmt("_ort_leaving_for"));
	/* Add lastprivate assignments */
	if (lasts)
	{
		if (collapsenum > 1)
		{
			aststmt idx;
		
			idx = Expression(Assignment(Identifier(forps[0].var), 
			                            bop2assop(forps[0].incrop), 
			                            ast_expr_copy(forps[0].step)));
			for (i = 1; i < collapsenum; i++)
				idx = BlockList(
				        idx,
				        Expression(Assignment(Identifier(forps[i].var), 
				                              bop2assop(forps[i].incrop), 
				                              ast_expr_copy(forps[i].step))
				        )
				      );
			lasts = BlockList(idx, lasts);
		}

		code.epilogue = 
			BlockList(
			  code.epilogue,
			  If(
			    BinaryOperator(BOP_land,
			      IdentName(info.varname(LOOP_ITER)),
			      BinaryOperator(BOP_eqeq,
			        IdentName(info.varname(LOOP_ITER)),
			        IdentName(info.varname(LOOP_NITERS))
			      )
			    ),
			    lasts->type == STATEMENTLIST ?  Compound(lasts) : lasts,
			    NULL
			  )
			);
	}

	/*
	 * Get loop specific code and combine the parts
	 */
	
	/* schedule-specific actions */
	switch (schedtype)
	{
		case OC_static:
			if (schedchunk)
				for_schedule_static_chunksize_vk(&info, &code);
			else
				for_schedule_static_vk(&info, &code);
			break;
		default:
			/* Without a schedule the iteration variables are never declared and
			 * the body is never normalized, so the output would be broken.
			 */
			exit_error(1, "(%s, line %d) ompi error:\n\t"
			     "only static (or auto) schedules are supported for Vulkan loops.\n",
			     (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l);
			break;
	}
	
	(*t)->u.omp->body = NULL;     /* Make it NULL so as to free it easily */
	ast_free(*t);                 /* Get rid of the OmpStmt */
	*t = Block5(code.decls, code.inits, code.prologue, code.mainpart, 
	            code.epilogue);
	*t = Compound(*t);
	(*t)->parent = parent;
}
