/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_target_cu.c -- transform CUDA target constructs */

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "callgraph.h"
#include "ast_copy.h"
#include "ast_free.h"
#include "ast_print.h"
#include "ast_vars.h"
#include "ast_xform.h"
#include "ast_types.h"
#include "ast_assorted.h"
#include "ast_csource.h"
#include "ast_xformrules.h"
#include "ast_arith.h"
#include "x_target.h"
#include "x_target_cu.h"
#include "x_map.h"
#include "x_decltarg.h"
#include "x_clauses.h"
#include "x_task.h"
#include "x_kernels.h"
#include "x_parallel.h"
#include "x_teams.h"
#include "x_teams_cu.h"
#include "x_for.h"
#include "symtab.h"
#include "ompi.h"
#include "outline.h"
#include "str.h"
#include "builder.h"
#include "x_combine.h"
#include "cfg.h"

#ifdef DEVENV_DBG
#include "ast_show.h"
#endif

/* Set by xform_target_forcuda() to record whether the #target construct
 * currently being transformed carries a defaultmap clause.
 * A file-scope flag is a quick hack. It only works because #target constructs
 * cannot be nested, so at most one construct is "current" at any time.
 * A cleaner design would pass this information as an extra argument to the
 * *implicit*() functions.
 * NOTE(review): in this part of the file the flag is only written; confirm
 * that a reader still exists.
 */
static bool hasdefaultmap = false;

/* Prepends the CUDA device prologue to the body of the OpenMP construct *t:
 *
 *   if (omp_get_thread_num() == 0) _cuda_dev_init(1);
 *   _cuda_dev_init_ctlblock();
 *   _cuda_dev_syncthreads();
 *
 * It also sets XFORM_CURR_DIRECTIVE->iscombpar.
 */
void prepend_cuda_prologue(aststmt *t)
{
	/* Executed by thread 0 only */
	aststmt thr0init =
		If(BinaryOperator(BOP_eqeq, Call0_expr("omp_get_thread_num"), ZeroExpr()),
		   FuncCallStmt(IdentName("_cuda_dev_init"), numConstant(1)),
		   NULL);
	/* Executed unconditionally */
	aststmt ctlblkinit = FuncCallStmt(IdentName("_cuda_dev_init_ctlblock"), NULL);
	aststmt syncstmt   = FuncCallStmt(IdentName("_cuda_dev_syncthreads"), NULL);

	ast_stmt_prepend((*t)->u.omp->body, Block3(thr0init, ctlblkinit, syncstmt));
	XFORM_CURR_DIRECTIVE->iscombpar = 1;
}


/* Makes sure the offload parameters carry a number-of-threads expression.
 *
 * The expression lives at (*targetparams)->right->left. If it is already set
 * (the combined-parallel-region case), it is left untouched. Otherwise:
 *  - if the current directive has non-combined parallel regions
 *    (nparallel > 0), a fixed DEVICETHREADS_FIXED is used. This could also
 *    be passed as "-2" to the runtime.
 *  - if not, 0 is used. This covers target regions without parallel regions
 *    (supported) or whose parallel regions live inside called functions
 *    (not supported).
 */
void optimize_numthreads(astexpr *targetparams)
{
	astexpr *nthrslot = &((*targetparams)->right->left);

	if (*nthrslot != NULL)     /* already provided by a combined region */
		return;

	if (XFORM_CURR_DIRECTIVE->nparallel > 0)
		*nthrslot = numConstant(DEVICETHREADS_FIXED);
	else
		*nthrslot = ZeroExpr();
}


/* Transforms a combined #target parallel construct for the CUDA module.
 * It collects the offload parameters and fixes up the thread count. It then
 * prepends the CUDA prologue, transforms the body and finally offloads via
 * xform_target_forcuda().
 */
void xform_targetparallel_cuda(aststmt *t)
{
	targstats_t *stats = NULL;
	astexpr      offparams;

	/* CARS analysis runs on the body as written by the user, i.e. before the
	 * CUDA prologue gets prepended further below. */
	if (analyzeKernels)
		stats = cars_analyze_target((*t)->u.omp->body);

	offparams = xc_ompcon_search_offload_params((*t)->u.omp);
	optimize_numthreads(&offparams);

	prepend_cuda_prologue(t);
	xform_ompcon_body((*t)->u.omp);
	xform_target_forcuda(t, stats, offparams);
}


/* Transforms a #target parallel for construct for the CUDA module.
 * It first calls ccc_try_splitting() on the construct, then delegates to
 * the CUDA #target teams transformation.
 */
void xform_targparfor_cuda(aststmt *t)
{
	ccc_try_splitting(t);
	xform_targetteams_cuda(t);
}


/* Transforms a plain #target construct for the CUDA module.
 *
 * If the body contains no nested #parallel and no #distribute parallel for,
 * the body is rewritten with the master/worker scheme. When PARALLEL_SCHEME
 * is not SCHEME_MASTERWORKER, ifmaster_stmt() is used instead. Otherwise
 * the CUDA prologue is prepended. The construct is then offloaded via
 * xform_target_forcuda().
 * cur_parallel_line and cur_taskgroup_line are zeroed during the
 * transformation and restored before returning.
 */
void xform_target_cuda(aststmt *t)
{
	astexpr targetparams;
	int savecpl = cur_parallel_line;
	int savectgl = cur_taskgroup_line;
	cur_parallel_line = cur_taskgroup_line = 0;	

	/* (1) Find all the other offloading parameters */
	targetparams = xc_ompcon_search_offload_params((*(t))->u.omp);
	optimize_numthreads(&targetparams);
	
	/* (2) Apply the appropriate scheme to the body; the offload parameters
	 * were already found in step (1).
	 */
	if (!search_nested_construct((*t)->u.omp, DCPARALLEL)
	&& !search_nested_construct((*t)->u.omp, DCDISTPARFOR))
	{
		#if PARALLEL_SCHEME == SCHEME_MASTERWORKER
		masterworker_stmt(&((*t)->u.omp->body));
		XFORM_CURR_DIRECTIVE->ismasterworker = 1;
		#else
		ifmaster_stmt(&((*t)->u.omp->body), false);
		#endif
	}
    else
		prepend_cuda_prologue(t);
		
	/* NOTE(review): `ts` is not declared in this function. It presumably
	 * comes from the TARGET_PROLOGUE() macro (compare with
	 * xform_targetparallel_cuda()), so keep the two lines together; confirm. */
	TARGET_PROLOGUE(t);
	xform_target_forcuda(t, ts, targetparams);

	cur_parallel_line = savecpl;
	cur_taskgroup_line = savectgl;
}


/* Extra arguments appended to every _ort_offload_kernel() call. They are set
 * by xform_target_forcuda() right before outlining. The same expression
 * tree is used directly, not copied.
 */
static astexpr offcuda_callsite_xtraargs;

/* Outliner callback: produces the expression that calls the outlined
 * function, i.e.
 *   _ort_offload_kernel(<func>, [<funcargs>,] <offcuda_callsite_xtraargs>)
 */
astexpr offcuda_callsite_expr(symbol func, astexpr funcargs)
{
	astexpr allargs = offcuda_callsite_xtraargs;

	if (funcargs)
		allargs = CommaList(funcargs, offcuda_callsite_xtraargs);

	return FunctionCall(IdentName("_ort_offload_kernel"),
	                    CommaList(Identifier(func), allargs));
}

/* 
 * Functions that do the job 
 */
void xform_target_forcuda(aststmt *t, targstats_t *ts, astexpr params)
{
	astexpr    deviceexpr = NULL, ifexpr = NULL;
	aststmt    devicestm = NULL, *producedc, repstruct_pin, parent = (*t)->parent,
	           kernel_dims_decl;
	ompclause  c, deps;
	outcome_t  oc;
	bool       nowait = false, xformtask = false;
	target_list_t newtarget;
	set(vars)  devptrs = set_new(vars);
	setelem(cgfun) caf;
	stentry    e;
	setelem(xformrules) el;
	
	assert((el = set_get(ast_xfrules, XFRMOD(cuda))) != NULL); /* Sanity */

	/* 1) Preparations
	*/
	newtarget = (target_list_t) smalloc(sizeof(struct target_list_));
	newtarget->kernelfile = 
		(char *) smalloc((strlen(filename)+8+strlen("cuda")) * sizeof(char));

	/* <kernel_filename>_dXX */
	snprintf(newtarget->kernelfile, (strlen(filename) + 5), "%.*s_d%02d",
			(int)(strlen(filename) - 3), filename, el->value->vars->targetnum);

	A_str_truncate();
	str_printf(strA(), "\"%s\"", newtarget->kernelfile);

	/* Append suffix (-cuda.c) */
	strcat(newtarget->kernelfile, CUDA_KERNEL_SUFFIX);

	newtarget->ts = ts;                      /* For CARS */
	newtarget->calledfuncs = set_new(cgfun);

	/* Mark and store all the called functions.
	 * Notice that if e.g. the kernel has a #parallel or #task, then the outlined 
	 * function is not directly called from the kernel and thus it is not inclued
	 * here. However, it was included in the global symbol table by outline.c
	 */
	for (caf = cg_find_called_funcs(*t)->first; caf != NULL; caf = caf->next)
	{
		decltarg_add_calledfunc(caf->key);
		if ((e = symtab_get(stab, caf->key, FUNCNAME)) != NULL)
		{
			decltarg_bind_id(e);   /* Do it now in case it was previously analyzed */
			set_put(newtarget->calledfuncs, caf->key);
		}
	}
	
	/* 2) Check for device, if and other clauses
	 */
	if ((c = xc_ompcon_get_unique_clause((*t)->u.omp, OCDEVICE)) != NULL)
		deviceexpr = ast_expr_copy(c->u.expr);
	else
		deviceexpr = numConstant(AUTODEV_ID);
	if ((c = xc_ompcon_get_unique_clause((*t)->u.omp, OCIF)) != NULL)
		ifexpr = ast_expr_copy(c->u.expr);
	hasdefaultmap =
		(xc_ompcon_get_unique_clause((*t)->u.omp, OCDEFAULTMAP) != NULL);
	deps = xc_ompcon_get_every_clause((*t)->u.omp, OCDEPEND);
	nowait = (xc_ompcon_get_unique_clause((*t)->u.omp, OCNOWAIT) != NULL);
	get_and_check_device_ptrs((*t)->u.omp, devptrs);

	/* 3) Store device id in a variable to avoid re-evaluating the expression
	 */
	devicestm = device_statement(ifexpr, deviceexpr);
	deviceexpr = IdentName(currdevvarName);

	/* 4) Outline
	 */
	static outline_opts_t op =
	{
		/* structbased             */  true,                   
		/* functionName            */  "test",                 
		/* functionCall  (func)    */  offcuda_callsite_expr,  
		/* byvalue_type            */  BYVAL_bycopy,           
		/* byref_type              */  BYREF_pointer,          
		/* byref_copyptr (2 funcs) */  NULL, NULL,             
		/* global_byref_in_struct  */  true,                   
		/* structName              */  "__dev_struct",         
		/* structVariable          */  DEVENV_STRUCT_NAME,     
		/* structInitializer       */  NULL,                   
		/* implicitDefault (func)  */  xtarget_implicitDefault,
		/* deviceexpr              */  NULL,                   
		/* addComment              */  true,                   
		/* thestmt                 */  NULL,
		/* userType                */  NULL                      
	};

	sprintf(op.functionName, "_kernelFunc%d_cuda", el->value->vars->targetnum++);

	/* The NULL is replaced later with the declared variables struct */
	newtarget->decl_struct = NullExpr();

	//(void *) 0, "<kernelfilename>", <deviceexpr>
	offcuda_callsite_xtraargs = Comma4(
	                                newtarget->decl_struct,
	                                params,
	                                IdentName(A_str_string()),
	                                deviceexpr
	                            );
	//(struct __dev_struct *) _ort_devdata_alloc(sizeof(struct __dev_struct), <deviceexpr>)
	op.structInitializer =
	  CastedExpr(
	    Casttypename(
	      SUdecl(SPEC_struct, Symbol(op.structType), NULL, NULL),
	      AbstractDeclarator(Pointer(), NULL)
	    ),
	    FunctionCall(
	      IdentName("_ort_devdata_alloc"),
	      CommaList(
	        Sizeoftype(
	          Casttypename(
	            SUdecl(SPEC_struct, Symbol(op.structType), NULL, NULL),
	            NULL
	          )),
	        ast_expr_copy(deviceexpr)
	      )
	    )
	  );
	op.deviceexpr = deviceexpr;

	op.thestmt = *t;
	oc = outline_OpenMP(t, op);

	kernel_dims_decl = create_offloaddims_stmt();
	teamdims_list = thrdims_list = NULL; /* reset */

	if (oc.repl_befcall)
		ast_stmt_append(oc.repl_befcall, kernel_dims_decl);
	else
		ast_stmt_prepend(oc.repl_funcall, kernel_dims_decl);

	if (oc.func_struct)
		gpuize_struct(oc.func_struct, set_size(oc.usedvars[DCT_BYVALUE]));
	
	/* 5) Check if a struct was created and free it
	 *   -- do the same for the decldata struct (VVD)
	 */
	if (oc.func_struct)
		//_ort_devdata_free(DEVENV_STRUCT_NAME, <deviceexpr>);
		ast_stmt_append(oc.repl_aftcall ? oc.repl_aftcall : oc.repl_funcall,
		                 FuncCallStmt(
		                   IdentName("_ort_devdata_free"),
		                   CommaList(
		                     IdentName(op.structName),
		                     ast_expr_copy(deviceexpr)
		                   )
		                 )
		                );
	if (declvars_exist())
		//_ort_decldata_free(_decl_data, <deviceexpr>);
		ast_stmt_append(oc.repl_aftcall ? oc.repl_aftcall : oc.repl_funcall,
		                 FuncCallStmt(
		                   IdentName("_ort_decldata_free"),
		                   CommaList(
		                     Identifier(declstructVar),
		                     ast_expr_copy(deviceexpr)
		                   )
		                 )
		                );

	//In order to place it at the start of the generated code we have to go past
	//the commented directive and into the compound
	producedc = &oc.replacement->body->body;

	/* When there is no _dev_data struct, we need to remember where the
	 * offload statment is located so as to insert (possibly) the _decl_data
	 * struct just before it; in fact because xkn_produce_decl_var_code()
	 * places it right *after* rep_struct, we must actualy remember the
	 * statement right before the offload. Thus, when no _dev_data exists,
	 * we add an artificial comment to use as the spot after which the
	 * _decl_data struct will be placed, if needed.
	 */
	if (!oc.repl_struct)
		ast_stmt_prepend(*producedc, repstruct_pin = verbit("/* no_data_denv */"));
	else
		repstruct_pin = oc.repl_struct;

	/* 6) Create the code for the device data environment
	 */
#ifdef DEVENV_DBG
	fprintf(stderr, "[target env]:\n");
	ast_ompdir_show_stderr(op.thestmt->u.omp->directive);
#endif
	create_devdata_env(producedc, op.thestmt->u.omp,
	                   oc.usedvars, oc.usedvars[DCT_IGNORE], deviceexpr);

	prepare_gpu_wrapper(offcuda_callsite_xtraargs, oc.usedvars, oc.repl_struct,
	                    op.structType, devptrs, newtarget);

	/* 7) Prepare numargs and numteams declaration statements */
	if (numargs_declstmt)
	{
		if (oc.repl_befcall)
			ast_stmt_append(oc.repl_befcall, ast_stmt_copy(numargs_declstmt));
		else
			ast_stmt_prepend(oc.repl_funcall, ast_stmt_copy(numargs_declstmt));

		ast_stmt_free(numargs_declstmt);
	}

	if (argarray_declstmt)
	{
		if (oc.repl_befcall)
			ast_stmt_append(oc.repl_befcall, ast_stmt_copy(argarray_declstmt));
		else
			ast_stmt_prepend(oc.repl_funcall, ast_stmt_copy(argarray_declstmt));

		ast_stmt_free(argarray_declstmt);
	}

	/* 8) Now that clause xlitems were used, get rid of the OmpStmt
	 */
	ast_free(op.thestmt);          /* Get rid of the OmpStmt */

	/* 9) Prepare the task data environment, if needed.
	 */
	if (deps || nowait)
		xformtask = targettask(producedc, devicestm, deps, nowait, oc.usedvars);
	/* Insert the variable generated for the device id (and any tasking stuff) */
	ast_stmt_prepend(*producedc, devicestm);

	/* 10) Store the generated code
	 */
	xkn_kernel_add(&newtarget, "cuda");

	newtarget->rep_struct = repstruct_pin;
	newtarget->functionName = strdup(op.functionName);

	if (xformtask)
	{
		taskopt_e bak = taskoptLevel;

		taskoptLevel = OPT_NONE;
		ast_stmt_parent(parent, *t);
		targetTask = true;
		ast_stmt_xform(t);
		targetTask = false;
		taskoptLevel = bak;
	}
}