/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_parallel_cu.c -- transform CUDA parallel constructs */

#include <string.h>
#include <assert.h>
#include "ast_xform.h"
#include "ast_copy.h"
#include "x_parallel.h"
#include "x_parallel_cu.h"
#include "x_target.h"
#include "x_clauses.h"
#include "omp.h"
#include "ompi.h"
#include "outline.h"
#include "ast_show.h"
#include "autoscope.h"
#include "builder.h"

/* Autoscoping results of the parallel region currently being transformed.
 * xform_parallel_cuda() copies them from autoscope_parreg_get_results() when
 * autoscoping is enabled.
 * NOTE(review): in this file the variable is only ever written, and since it
 * is static no other translation unit can read it. Confirm it is still needed.
 */
static autoshattr_t as;

/* Extra trailing arguments (<numthreads>, <combined>, <procbind>) that
 * xp_cuda_callsite_expr() appends to the generated _ort_execute_parallel()
 * call. xform_parallel_cuda() sets them right before calling outline_OpenMP().
 */
static astexpr xpcuda_callsite_xtraargs;
/* Call-site builder used by outline_OpenMP() (installed as the
 * functionCall callback of the outline options). It produces:
 *
 *   _ort_execute_parallel(<func>, [<funcargs>,] <numthreads>, <combined>, <procbind>)
 *
 * func     - the outlined function.
 * funcargs - arguments produced by the outliner; may be NULL.
 * The three trailing arguments are taken from xpcuda_callsite_xtraargs,
 * which xform_parallel_cuda() prepares before outlining.
 */
astexpr xp_cuda_callsite_expr(symbol func, astexpr funcargs)
{
	astexpr allargs = xpcuda_callsite_xtraargs;

	/* Outliner-supplied arguments (if any) come before the extra ones */
	if (funcargs != NULL)
		allargs = CommaList(funcargs, xpcuda_callsite_xtraargs);

	return FunctionCall(IdentName("_ort_execute_parallel"),
	                    CommaList(Identifier(func), allargs));
}


/* Transform a `parallel` construct for CUDA code generation.
 *
 * The function:
 *   - saves closest_parallel_scope, cur_parallel_line and cur_taskgroup_line
 *     and restores them before returning;
 *   - runs the autoscoping analysis (if enabled and not already done);
 *   - counts the region in the enclosing target construct, if any;
 *   - transforms the construct body;
 *   - decides the "combined" flag and delegates to xform_parallel_cuda().
 *
 * t - in/out: the construct statement; replaced by xform_parallel_cuda().
 */
void xform_parallel_cuda_wrapper(aststmt *t)
{
	int old_scope  = closest_parallel_scope;
	int old_parln  = cur_parallel_line;
	int old_tgln   = cur_taskgroup_line;
	int iscomb;
	stackelem(ompdirs) tgt;

	cur_parallel_line  = (*t)->l;
	cur_taskgroup_line = 0;

	/* Agelos: analyze the region only if it has not been analyzed yet */
	if (enableAutoscope && autoscope_parreg_get_results(cur_parallel_line) == NULL)
		autoscope(*t);

	closest_parallel_scope = stab->scopelevel;

	/* Count this parallel region in the enclosing target construct (if any) */
	if ((tgt = find_target_parent()) != NULL)
		tgt->data->nparallel++;

	xform_ompcon_body((*t)->u.omp);

	/* A target flagged ismasterworker forces 0, a target flagged iscombpar
	 * forces 1; otherwise the current directive decides. */
	iscomb = (tgt && tgt->data->ismasterworker) ? 0 :
	         (tgt && tgt->data->iscombpar)      ? 1 :
	         XFORM_CURR_DIRECTIVE->iscombpar;

	xform_parallel_cuda(t, iscomb);

	/* Use the saved line: *t has been replaced by now */
	if (enableAutoscope)
		autoscope_parreg_remove(cur_parallel_line);

	cur_taskgroup_line     = old_tgln;
	cur_parallel_line      = old_parln;
	closest_parallel_scope = old_scope;
}


/* Outline a `parallel` construct for CUDA code generation. The construct is
 * replaced by a call to _ort_execute_parallel(), which is built by
 * xp_cuda_callsite_expr().
 *
 *   t          in/out: the parallel construct statement; *t is replaced by
 *              the outlining result.
 *   iscombined passed as the "combined" argument of _ort_execute_parallel().
 *
 * Steps:
 *   (1) keep copies of the num_threads() and if() clause expressions;
 *   (2) retrieve the proc_bind type;
 *   (3) outline the region with outline_OpenMP();
 *   (4) with an if() clause, choose at runtime between the parallel call and
 *       _ort_execute_serial();
 *   (5) prepend copyin initialization (plus a barrier) to the new function;
 *   (6) emit the labeled _ort_taskwait(2) right before the function returns.
 */
void xform_parallel_cuda(aststmt *t, int iscombined)
{
	/* Sequence number giving each outlined function a unique name */
	static int thrnum_cuda = 0;

	/* Outlining options. The object is static, so every field that depends
	 * on the current construct is re-assigned below on each call.
	 */
	static outline_opts_t oo =
	{
		/* structbased             */  true,
		/* functionName            */  "",
		/* functionCall  (func)    */  xp_cuda_callsite_expr,
		/* byvalue_type            */  BYVAL_byname,
		/* byref_type              */  BYREF_copyptr,
		/* byref_copyptr (2 funcs) */  "_cuda_dev_shmem_push", "_cuda_dev_shmem_pop",
		/* global_byref_in_struct  */  false,
		/* structName              */  "__shvt__",
		/* structVariable          */  "_shvars",
		/* structInitializer       */  NULL,
		/* implicitDefault (func)  */  xp_implicitDefault,
		/* deviceexpr              */  NULL,
		/* addComment              */  true,
		/* thestmt                 */  NULL,
		/* userType                */  NULL
	};

	astexpr    numthrexpr = NULL, ifexpr = NULL;
	aststmt    copyininit;
	ompclause  c, def = xc_ompcon_get_clause((*t)->u.omp, OCDEFAULT);
	int        procbind_type;
	outcome_t  op;
	/* NOTE(review): 22 bytes truncates the label for line numbers >= 100000.
	 * The size is kept so the label stays identical to the one the cancel
	 * transformation jumps to; enlarge both together if needed.
	 */
	char       clabel[22];

	/* The name of the label used for canceling. We use the line number to
	 * avoid conflicts (there shouldn't be any since the code is outlined, but
	 * we use it anyway in case we inline the parallel code in the future).
	 */
	snprintf(clabel, sizeof(clabel), "CANCEL_parallel_%d", (*t)->l);

	/* oo is static: restore the default implicit-scoping callback on every
	 * call. Otherwise a previous default(auto) region would leave
	 * xp_implicitDefaultAuto installed for all subsequent regions.
	 */
	oo.implicitDefault = xp_implicitDefault;
	if (enableAutoscope) /* Agelos */
	{
		assert(autoscope_parreg_get_results((*t)->l) != NULL);
		as = *autoscope_parreg_get_results((*t)->l);
		if (def && def->subtype == OC_auto)
			oo.implicitDefault = xp_implicitDefaultAuto;
	}

	/* (1) Check for if and num_threads clauses and keep a copy
	 */
	if ((c = xc_ompcon_get_unique_clause((*t)->u.omp, OCNUMTHREADS)) != NULL)
		numthrexpr = ast_expr_copy(c->u.expr);
	if ((c = xc_ompcon_get_unique_clause((*t)->u.omp, OCIF)) != NULL)
		ifexpr = ast_expr_copy(c->u.expr);

	/* (2) Retrieve bind type
	 */
	procbind_type = xp_procbind((*t)->u.omp);

	/* (3) Call outline_OpenMP. The call-site callback only receives
	 * (func, funcargs), so its extra arguments are passed through the
	 * file-level xpcuda_callsite_xtraargs.
	 */
	snprintf(oo.functionName, sizeof(oo.functionName), "_cudathrFunc%d_",
	         thrnum_cuda++);
	xpcuda_callsite_xtraargs = Comma3(  /* <numthreads>, <combined>, <procbind> */
	                       numthrexpr ? numthrexpr : numConstant(NUMTHREADS_RUNTIME),
	                       numConstant(iscombined),
	                       numConstant(procbind_type)
	                     );
	/* Inside a target region, global shared variables go in the struct too */
	oo.global_byref_in_struct = in_target() ? true : false;
	oo.thestmt = *t;

	op = outline_OpenMP(t, oo);

	/* (4) Add if clause: the call statement is replaced in place by
	 *   if (<ifexpr>) <parallel call>
	 *   else _ort_execute_serial(<func>, (void *) &__shvt__  or  0);
	 * The call node itself is overwritten (rather than parent->body) because
	 * the call is not necessarily its parent's body, e.g. when the parent is
	 * an IF or a STATEMENTLIST.
	 */
	if (ifexpr)
	{
		aststmt parent = op.repl_funcall->parent;
		aststmt new = If(ifexpr, ast_stmt_copy(op.repl_funcall),
		                  FuncCallStmt(
		                    IdentName("_ort_execute_serial"),
		                    CommaList(
		                      IdentName(oo.functionName),
		                      CastVoidStar(
		                        op.func_struct ?
		                        UOAddress(IdentName(oo.structName)) :
		                        ZeroExpr()
		                      )
		                    )
		                  )
		                 );
		*(op.repl_funcall) = *new;
		ast_stmt_parent(parent, op.repl_funcall);  /* fix up parent links */
	}

	/* (5) Handle copyin variables
	 */
	copyininit = xp_handle_copyin(op.usedvars[DCT_BYREF], oo.structName);
	if (copyininit)
		ast_stmt_prepend(
		  op.func_regcode,
		  Block3(
		    verbit("/* copyin initialization(s) */"), copyininit, BarrierCall()
		  ));

	/* (6) Place the closing barrier (_ort_taskwait(2)) right before the
	 * function's return; it carries the cancellation label.
	 */
	ast_stmt_prepend(
	  op.func_return,
	  Labeled(
	    Symbol(clabel), /* label used for cancel */
	    FuncCallStmt(
	      IdentName("_ort_taskwait"),
	      numConstant(2)
	    )
	  )
	);
}

