/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_decltarg.c */

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "ast_copy.h"
#include "ast_free.h"
#include "ast_print.h"
#include "ast_vars.h"
#include "ast_xform.h"
#include "ast_assorted.h"
#include "ast_types.h"
#include "symtab.h"
#include "ompi.h"
#include "str.h"
#include "x_decltarg.h"
#include "x_target.h"
#include "x_clauses.h"
#include "x_kernels.h"

// #define DBG

/* DBGPRN((stream, fmt, ...)) prints debugging info when DBG is defined.
 * Both variants expand to a single statement that requires the trailing ';'
 * supplied by the caller. This keeps constructs like
 * "if (c) DBGPRN((...)); else ..." well-formed whether DBG is on or off.
 */
#ifdef DBG
	#define DBGPRN(s) do { fprintf s; } while (0)
#else
	#define DBGPRN(s) do { } while (0)
#endif
#undef DBG

/* Declaration statement of the struct that has all #declare target variables
 * as (pointer) fields, together with the _decl_data pointer to it.
 * It is generated by decltarg_struct_code().
 */
aststmt structdecl = NULL;

/* Sets of #declare target functions and variables, filled mainly by
 * decltarg_bind_id(); variables may also come from decltarg_inject_newglobal().
 * In declare_funcproto the value ptr is meant to point to the function's
 * definition; decltarg_bind_id() initializes it to NULL.
 */
set(vars) declare_funcproto = NULL; /* Declared function prototypes */
set(vars) declare_variables = NULL; /* Declared variables */


/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                               *
 *     DISCOVER ALL IDS IN ALL DECLARE TARGET DIRECTIVES         *
 *                                                               *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */


/* The set of all #declare target identifiers; this includes ids from 
 * all v4.5 clauses as well as all discovered from v4.0 constructs
 * (plus functions recorded through decltarg_add_calledfunc()).
 * It is created by decltarg_find_all_directives().
 * The value fields are used as follows:
 *   ->clause to mark the way (OCLINK/OCTO)
 *   ->clsubt to mark a var, a func or unknown (OC_auto/OC_plus/OC_DontCare)
 */
static set(vars) decltarg_ids;


/**
 * Tells whether a symbol has been recorded as a #declare target id.
 * Besides ids given explicitly in #declare target directives, this also
 * covers functions recorded by decltarg_add_calledfunc() and globals added
 * by decltarg_inject_newglobal().
 * @param s the symbol to look up
 * @return true iff s is in the set of #declare target ids
 */
bool decltarg_id_isknown(symbol s)
{
	setelem(vars) entry = set_get(decltarg_ids, s);

	return entry ? true : false;
}


/**
 * Returns the clause type (OCTO or OCLINK) recorded for a #declare target id.
 * @param s the symbol; it MUST be a known id (see decltarg_id_isknown()),
 *          otherwise a NULL set element is dereferenced
 * @return the recorded clause type
 */
ompclt_e decltarg_id_clause(symbol s)
{
	setelem(vars) entry = set_get(decltarg_ids, s);

	return entry->value.clause;
}


/*
 * Next 2 functions discover all ids within an old-style (v4.0) construct
 */


static 
void decltarg40_decl_ids(astdecl t)
{
	symbol        s; 
	setelem(vars) v;
	
	if (!t) return;
	switch (t->type)
	{
		case DIDENT:
			s = t->u.id;
			v = set_get(decltarg_ids, s);
			if (v)
				if (v->value.clause != OCTO)    /* OpenMP v4.5, 2.10.6, p. 112 */
					exit_error(1, "#declare target variable '%s' was previously used "
					              "in a LINK clause", s->name);
			if (!v)
			{
				v = set_put(decltarg_ids, s);
				v->value.clause = OCTO;
				v->value.clsubt = OC_auto;  /* OC_auto to mark a variable (unused :-) */
			}
			break;
		case DFUNC:
			s = t->decl->u.id;
			v = set_get(decltarg_ids, s);
			if (v)
				if (v->value.clause != OCTO)    /* OpenMP v4.5, 2.10.6, p. 112 */
					exit_error(1, "#declare target function '%s' was previously used "
					              "in a LINK clause", s->name);
			if (!v)
			{
				v = set_put(decltarg_ids, s);
				v->value.clause = OCTO;
				v->value.clsubt = OC_plus;  /* OC_plus to mark a function (unused :-) */
			}
			break;
		case DLIST:
			decltarg40_decl_ids(t->u.next);
		default:
			decltarg40_decl_ids(t->decl);
	}
}


static 
void decltarg40_body_ids(aststmt t)
{
	if (!t) return;
	switch (t->type)
	{
		case DECLARATION:
			if (t->u.declaration.decl &&      /* avoid user types */
			    !speclist_getspec(t->u.declaration.spec, STCLASSSPEC, SPEC_typedef))
				decltarg40_decl_ids(t->u.declaration.decl);
			break;
		case FUNCDEF:
		{
			symbol        s = decl_getidentifier_symbol(t->u.declaration.decl);
			setelem(vars) v = set_get(decltarg_ids, s);
			
			if (v)
				if (v->value.clause != OCTO)    /* OpenMP v4.5, 2.10.6, p. 112 */
					exit_error(1, "#declare target function '%s' was previously used "
					              "in a LINK clause", s->name);
			if (!v)
			{
				v = set_put(decltarg_ids, s);
				v->value.clause = OCTO;
				v->value.clsubt = OC_plus;  /* OC_plus to mark a function (unused :-) */
			}
			break;
		}
		case STATEMENTLIST:
			decltarg40_body_ids(t->u.next);
			decltarg40_body_ids(t->body);
			break;
		default:
			break;
	}
}


/*
 * Next 2 functions discover all ids within a newer-style (v4.5) construct
 */


/**
 * Given the xlist of a TO or LINK clause, adds each id to the set of
 * #declare target ids, recording the clause type it was given with.
 * An id that was already recorded with a different clause type is an error
 * (OpenMP v4.5, 2.10.6, p. 112).
 * Array sections are not supported yet; a warning is issued and the whole
 * item is used.
 * @param ctype the clause type (OCTO or OCLINK)
 * @param xl    the list
 */
static 
void decltarg45_item_ids(ompclt_e ctype, ompxli xl)
{
	ompxli        item;
	setelem(vars) known;

	for (item = xl; item != NULL; item = item->next)
	{
		if (item->xlitype != OXLI_IDENT)
			warning("array sections are not yet handled in #declare target items; "
			        "using the whole array..\n");

		known = set_get(decltarg_ids, item->id);
		if (known == NULL)
		{
			known = set_put(decltarg_ids, item->id);
			known->value.clause = ctype;
			known->value.clsubt = OC_DontCare;  /* Don't know if it is var or func */
		}
		else if (known->value.clause != ctype)    /* OpenMP v4.5, 2.10.6, p. 112 */
			exit_error(1, "Illegal re-declaration of #declare target variable '%s'",
			              item->id->name);
	}
}


/**
 * Records the ids of all TO/LINK clauses of a v4.5 #declare target directive.
 * For a clause list, the rest of the list (u.list.next) is handled before
 * the list element itself.
 * @param t the clause or clause list; must not be NULL
 */
static 
void decltarg45_clause_ids(ompclause t)
{
	ompclause single = t;

	if (t->type == OCLIST)
	{
		ompclause rest = t->u.list.next;

		if (rest != NULL)
			decltarg45_clause_ids(rest);
		single = t->u.list.elem;
		assert(single != NULL);
	}
	assert(single->type == OCTO || single->type == OCLINK);  /* sanity check */
	decltarg45_item_ids(single->type, single->u.xlist);
}


/**
 * Walks a statement list and records the ids of every #declare target
 * directive found in it, in either the v4.5 (clauses) or the v4.0 (body)
 * style. Only STATEMENTLIST nodes are descended into.
 * @param t the statement (list) to scan; NULL is ignored
 */
static
void decltarg_scan_directives(aststmt t)
{
	if (!t) return;
	switch (t->type)
	{
		case OMPSTMT:
			if (t->u.omp->type != DCDECLTARGET)
				break;
			if (t->u.omp->directive->clauses)  /* OpenMP v4.5 style */
				decltarg45_clause_ids(t->u.omp->directive->clauses);
			else                               /* OpenMP v4.0 style */
				decltarg40_body_ids(t->u.omp->body);
			break;
		case STATEMENTLIST:
			decltarg_scan_directives(t->u.next);
			decltarg_scan_directives(t->body);
			break;
		default:
			break;
	}
}


/**
 * Discovers all #declare target directives and forms a set of all 
 * #declared identifiers; called before the transformation of the AST.
 * It also initializes the sets and struct-related symbols used later on by
 * this module. This is done once per call rather than once per visited
 * statement, as happened when the initialization lived in the recursion.
 * @param t the root of the AST (may be NULL)
 */
void decltarg_find_all_directives(aststmt t)
{
	/* Handy spot to initialize some sets which will be used later on */
	if (!declare_variables) declare_variables = set_new(vars);
	if (!declare_funcproto) declare_funcproto = set_new(vars);
	if (!decltarg_ids)      decltarg_ids      = set_new(vars);

	/* No better place found for these initializations... */
	declstructVar  = Symbol("_decl_data");
	declstructArg  = Symbol("__decl_data");
	declstructType = Symbol("__decl_struct");

	decltarg_scan_directives(t);
}


/**
 * OpenMP v5.0: records a function identifier referenced within a kernel as a
 * #declare target (TO) function, unless it is already known.
 * The actual binding is done by decltarg_bind_id(); this only keeps
 * decltarg_ids complete. Note that it makes decltarg_id_isknown(s) true.
 * @param s the function symbol
 */
void decltarg_add_calledfunc(symbol s)
{
	setelem(vars) fn;

	if (set_get(decltarg_ids, s) == NULL)   /* not known yet */
	{
		fn = set_put(decltarg_ids, s);
		fn->value.clause = OCTO;
		fn->value.clsubt = OC_plus;  /* OC_plus to mark a function (unused :-) */
	}
}


/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                               *
 *     TRANSFORMATION FUNCTIONS                                  *
 *                                                               *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */



/**
 * VVD--adds a totally global, compiler-generated variable (e.g. reduction
 * locks) as a #declare target (TO) variable.
 * It is safe to call more than once for the same symbol. The id entry is
 * reused rather than duplicated. The variable enters declare_variables only
 * once (as decltarg_bind_id() also ensures), so it never yields duplicate
 * struct fields or registrations.
 * @param s the symbol of the new global variable
 */
void decltarg_inject_newglobal(symbol s)
{
	setelem(vars) v = set_get(decltarg_ids, s);

	if (v == NULL)
		v = set_put(decltarg_ids, s);
	v->value.clause = OCTO;
	v->value.clsubt = OC_auto;     /* OC_auto marks a variable */
	set_put_unique(declare_variables, s);
}


/**
 * Binds a known #declare target id with its actual declaration in the 
 * program's symbol table and puts it in the appropriate set: variables go
 * to declare_variables, everything else to declare_funcproto. The entry is
 * then marked as being in the device environment due to #declare target.
 * Threadprivate variables are rejected (OpenMP v4.5, 2.10.6, p. 112).
 * @param e the symbol table entry
 */
void decltarg_bind_id(stentry e)
{
	symbol key = e->key;

	if (e->space != IDNAME)
	{
		/* Function: insert once, with no definition attached yet */
		if (set_get(declare_funcproto, key) == NULL)
		{
			DBGPRN((stderr, "[decltarg] binding function %s\n", key->name));
			set_put(declare_funcproto, key)->value.ptr = NULL;
		}
	}
	else
	{
		if (e->isthrpriv)  /* OpenMP v4.5, 2.10.6, p. 112 */
			exit_error(1,"threadprivate variables disallowed in #declare target\n");
		set_put_unique(declare_variables, key);
	}
	e->isindevenv = due2DECLTARG;
}


/**
 * Transforms a #declare target construct by replacing the whole OmpStmt:
 *   - v4.5 style (with clauses): by ompdir_commented() of the directive;
 *   - v4.0 style (no clauses):   by the enclosed body, which is kept intact.
 * @param t the statement to transform; replaced in place
 */
void xform_declaretarget(aststmt *t)
{
	aststmt replacement;

	if ((*t)->u.omp->directive->clauses == NULL)   /* Older (v4.0) style */
		replacement = (*t)->u.omp->body;
	else                                           /* Newer (v4.5) style */
		replacement = ompdir_commented((*t)->u.omp->directive);

	/* Detach the body so that freeing the OmpStmt leaves it alone */
	(*t)->u.omp->body = NULL;
	ast_free(*t);
	*t = replacement;
}


/* Symbols of the #declare target data struct: the struct pointer variable
 * (_decl_data), the kernel wrapper parameter (__decl_data) and the struct
 * tag (__decl_struct); they are set by decltarg_find_all_directives().
 */
symbol declstructVar, declstructArg, declstructType;


/**
 * This generates the struct and all related host code for the #declare target
 * variables; the struct declaration itself is stored in `structdecl`.
 * For every variable <var> in declare_variables it produces:
 *   - only if <var> had an initializer, the declaration
 *       static const <type> init_<var> = <original initializer>;
 *   - the registration call
 *       ort_decltarg_register(&<var>, sizeof(<var>), [&init_<var>|(void *)0], islink);
 *   - a pointer field of struct __decl_struct, and the field initialization
 *       _decl_data-><var> = (<type> *) ort_decltarg_host2med_addr(&<var>, __ompi_devID);
 * @param initvars   set to the declarations of the extra init_<var> variables
 *                   used as initializers (NULL if none)
 * @param regstmts   set to the statements registering the declared variables
 *                   (NULL if none)
 * @param structinit set to the allocation of _decl_data followed by the
 *                   field initialization statements
 */
void decltarg_struct_code(aststmt *initvars, aststmt *regstmts,
                          aststmt *structinit)
{
	setelem(vars) e;
	stentry       orig;
	symbol        initsym;
	aststmt       st, fieldinits = NULL;
	astdecl       dcl, fields = NULL;
	
	*initvars = NULL;
	*regstmts = NULL;
	for (e = declare_variables->first; e; e = e->next)
	{
		orig = symtab_get(stab, e->key, IDNAME);
		initsym = NULL;

		/* If the variable had an initializer, copy it as is into a new
		 * static const variable named init_<var>; its address is passed to
		 * the registration call below.
		 */
		if (orig->idecl)
		{
			/* Turn the name into a symbol right away, so that the registration
			 * call below does not rely on the A_str scratch buffer staying
			 * untouched in the meantime.
			 */
			A_str_truncate();
			str_printf(strA(), "init_%s", e->key->name);
			initsym = Symbol(A_str_string());

			/* static const <type> init_<original name> = <original initializer> */
			st = Declaration(
			        Speclist_right(
			          Speclist_right(
			            StClassSpec(SPEC_static),
			            StClassSpec(SPEC_const)),
			          ast_spec_copy_nosc(orig->spec)
			        ),
			        decl_rename(ast_decl_copy(orig->idecl), initsym)
			      );
			*initvars = *initvars ? BlockList(*initvars, st) : st;
		}

		/* Produce registration statements:
		 * ort_decltarg_register(&<var>,sizeof(<var>),[&init_<var>|(void *)0],lnk);
		 */
		st = FuncCallStmt(
		        IdentName("ort_decltarg_register"),
		        Comma4(
		          UOAddress(Identifier(e->key)),
		          Sizeof(Identifier(e->key)),
		          (orig->idecl ? UOAddress(Identifier(initsym)) : NullExpr()),
		          numConstant(decltarg_id_clause(e->key) == OCLINK)
		        )
		      );
		*regstmts = *regstmts ? BlockList(*regstmts, st) : st;

		/* Create the fields for the struct used to pass declared variables values.
		 * -- need care here and use xform_clone_declaration() (VVD)
		 */
		dcl = StructfieldDecl(ast_spec_copy_nosc(orig->spec),
		                      decl_topointer(xform_clone_declonly(orig)));
		fields = fields ? StructfieldList(fields, dcl) : dcl;

		/* Field initialization statements
		 * _decl_data-><var> = (<type> *) 
		 *      ort_decltarg_host2med_addr(&<var>, __ompi_devID); 
		 */
		dcl = decl_topointer(ast_decl_copy(orig->decl));
		st = AssignStmt(
		        PtrField(Identifier(declstructVar), e->key),
		        CastedExpr(
		          Casttypename(
		            ast_spec_copy_nosc(orig->spec),
		            ast_xt_concrete_to_abstract_declarator(dcl)
		          ),
		          FunctionCall(
		            IdentName("ort_decltarg_host2med_addr"),
		            CommaList(
		              UOAddress(Identifier(orig->key)),
		              IdentName("__ompi_devID")
		            )
		          )
		        )
		      );
		/* NOTE(review): only the outermost node of dcl is released here;
		 * whether its children are shared with the abstract declarator is not
		 * visible from this file.
		 */
		free(dcl);
		fieldinits = fieldinits ? BlockList(fieldinits, st) : st;
	}
	
	/* struct __decl_struct {
	 *   <fields>
	 * } *_decl_data;
	 */
	structdecl = Declaration(
	               SUdecl(SPEC_struct, declstructType, fields, NULL),
	               Declarator(Pointer(), IdentifierDecl(declstructVar))
	             );
	
	/* _decl_data = (struct __decl_struct *) 
	 *     ort_devdata_alloc(sizeof(struct __decl_struct), __ompi_devID);
	 */
	st = AssignStmt(
	        Identifier(declstructVar),
	        CastedExpr(
	          Casttypename(
	            SUdecl(SPEC_struct, declstructType, NULL, NULL),
	            AbstractDeclarator(Pointer(), NULL)
	          ),
	          FunctionCall(
	            IdentName("ort_devdata_alloc"),
	            CommaList(
	              Sizeoftype(
	                Casttypename(
	                  SUdecl(SPEC_struct, declstructType, NULL, NULL),
	                  NULL
	                )),
	              IdentName("__ompi_devID")
	            )
	          )
	        )
	      );
	*structinit = fieldinits ? BlockList(st, fieldinits) : st;
}

/**
 * Generates declarations for declare target variables (pointerized).
 */
aststmt decltarg_kernel_globals()
{
	setelem(vars) e;
	aststmt       var, globalvars = NULL;
	
	globalvars = NULL;
	for (e = declare_variables->first; e; e = e->next)
	{
		/* Copy the #declare variable declarations and turn them into pointers */
		var = xform_clone_declaration(e->key, NULL, true, NULL);
		DEVSPECit(var, DEVSPECQUAL);
		globalvars = globalvars ? BlockList(globalvars, var) : var;
	}
	return (globalvars);
}


/**
 * Generates the binding struct-related code for declare target variables
 */
aststmt decltarg_kernel_struct_code()
{
	setelem(vars) e;
	astexpr       tmpexpr;
	aststmt       inits = NULL, kstructinit, tmp;
	
	for (e = declare_variables->first; e; e = e->next)
	{
		/* Initialize the variable pointers:
		 *  var = devpart_med2dev_addr(_decl_data->var, sizeof(*(_decl_data->var)))
		 */
		tmpexpr = PtrField(Identifier(declstructVar), e->key);
		tmp = AssignStmt(Identifier(e->key),
		        FunctionCall(
		          IdentName("devpart_med2dev_addr"),
		          CommaList(
		            tmpexpr,
		            Sizeof(DerefParen(ast_expr_copy(tmpexpr)))
		          )
		        )
		      );
		inits = inits ? BlockList(inits, tmp) : tmp;
	}
	
	/* Cast the wrapper function parameter into the struct 
	 *   _decl_data = (struct __decl_struct  *) __decl_data;
	 */
	kstructinit = AssignStmt(
	                Identifier(declstructVar),
	                CastedExpr(
	                  Casttypename(
	                    SUdecl(SPEC_struct, declstructType, NULL, NULL),
	                    AbstractDeclarator(Pointer(), NULL)
	                  ),
	                  Identifier(declstructArg)
	                )
	              );
	if (inits)
		kstructinit = BlockList(kstructinit, inits);
	return (kstructinit);
}


/**
 * Generates binding statements for declare target variables in gpu kernels.
 * For every variable <var> of type <type>:
 *   <var> = (DEVSPEC <type> *) devpart_med2dev_addr(_dt_<var>, sizeof(*<var>));
 * where _dt_<var> is the kernel wrapper parameter of the same name produced
 * by decltarg_gpu_kernel_parameters().
 * @return the statements (NULL if there are no such variables)
 */
aststmt decltarg_gpu_kernel_varinits()
{
	setelem(vars) e;
	aststmt       inits = NULL, tmp;
	stentry       orig;
	char          *argname;
	size_t        len;

	for (e = declare_variables->first; e; e = e->next)
	{
		orig = symtab_get(stab, e->key, IDNAME);

		/* Build "_dt_<var>" in a buffer sized to fit the name; a fixed-size
		 * buffer would overflow on long identifiers.
		 */
		len = strlen(e->key->name) + sizeof("_dt_");
		if ((argname = malloc(len)) == NULL)
			exit_error(1, "out of memory");
		snprintf(argname, len, "_dt_%s", e->key->name);

		tmp = AssignStmt(Identifier(e->key),
		        CastedExpr(
		          Casttypename(
		            Speclist_right(
		              Usertype(Symbol(DEVSPEC)), ast_spec_copy_nosc(orig->spec)
		            ),
		            ast_xt_concrete_to_abstract_declarator(
		                                   decl_topointer(ast_decl_copy(orig->decl))
		            )
		          ),
		          FunctionCall(
		            IdentName("devpart_med2dev_addr"),
		            CommaList(
		              IdentName(argname),
		              Sizeof(DerefParen(Identifier(e->key)))
		            )
		          )
		        )
		      );
		/* The name has been turned into a symbol (Symbol() is also fed names
		 * from reused scratch buffers elsewhere), so the buffer can go.
		 */
		free(argname);
		inits = inits ? BlockList(inits, tmp) : tmp;
	}
	return inits;
}


/**
 * Turn declare target variables into arguments for the offloading
 */
astexpr decltarg_offload_arguments_withsize()
{
	setelem(vars) e;
	astexpr       args = NULL, ptrfield = NULL;

	args = numConstant(set_size(declare_variables));
	for (e = declare_variables->first; e; e = e->next)
	{
		ptrfield = PtrField(Identifier(declstructVar), e->key);
		args = CommaList(args, ptrfield);
	}
	return args;
}


astexpr decltarg_offload_arguments()
{
	setelem(vars) e;
	astexpr       args = NULL, ptrfield = NULL;

	for (e = declare_variables->first; e; e = e->next)
	{
		ptrfield = PtrField(Identifier(declstructVar), e->key);
		args = args ? CommaList(args, ptrfield) : ptrfield;
	}
	
	return args;
}


/**
 * Returns a constant expression with the number of #declare target
 * variables, i.e. the number of arguments decltarg_offload_arguments() yields.
 */
astexpr decltarg_num_offload_arguments()
{
	return numConstant(set_size(declare_variables));
}


/**
 * Generates parameters of gpu kernel wrapper out of declare target variables.
 * For every variable <var> of type <type>, a pointer parameter named
 * _dt_<var>, declared with DEVSPEC added to <type>'s specifiers.
 * @return the parameter list (NULL if there are no such variables)
 */
astdecl decltarg_gpu_kernel_parameters()
{
	setelem(vars) e;
	astdecl       params = NULL, tmp;
	stentry       orig;
	char          *argname;
	size_t        len;

	for (e = declare_variables->first; e; e = e->next)
	{
		orig = symtab_get(stab, e->key, IDNAME);

		/* Build "_dt_<var>" in a buffer sized to fit the name; a fixed-size
		 * buffer would overflow on long identifiers.
		 */
		len = strlen(e->key->name) + sizeof("_dt_");
		if ((argname = malloc(len)) == NULL)
			exit_error(1, "out of memory");
		snprintf(argname, len, "_dt_%s", e->key->name);

		tmp = ParamDecl(
		        Speclist_right(
		          Usertype(Symbol(DEVSPEC)),
		          ast_spec_copy_nosc(orig->spec)
		        ), 
		        decl_topointer(
		          decl_rename(ast_decl_copy(orig->decl), Symbol(argname))
		        )
		      );
		/* The name has been turned into a symbol (Symbol() is also fed names
		 * from reused scratch buffers elsewhere), so the buffer can go.
		 */
		free(argname);

		params = params ? DeclList(params, tmp) : tmp;
	}
	return params;
}
