/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_target_vk.c -- transform Vulkan target constructs */
#include <string.h>
#include "ast_free.h"
#include "ast_xform.h"
#include "ast_show.h"
#include "ast_copy.h"
#include "x_target.h"
#include "x_target_vk.h"
#include "x_decltarg.h"
#include "x_clauses.h"
#include "outline.h"
#include "x_combine.h"
#include "codetargs.h"
#include "set.h"
#include "cfg.h"

#ifdef DEVENV_DBG
#include "ast_show.h"
#endif

/* Vulkan flavour of the if-master transformation.
 * Builds the verbatim declaration of the `_im_thrmask' variable and passes 
 * it, along with the "barrier" string, to the generic ifmaster_stmt().
 *   t      -- the statement to transform (forwarded untouched)
 *   infunc -- selects the mask initializer; also forwarded to ifmaster_stmt()
 */
static void ifmaster_stmt_vulkan(aststmt *t, bool infunc)
{
	aststmt thrmask_decl;

	if (infunc)  /* mask is also set whenever omp_in_parallel() is true */
		thrmask_decl = verbit("int _im_thrmask = int((_im_myid == 0) || omp_in_parallel());");
	else         /* mask is set only when _im_myid is 0 */
		thrmask_decl = verbit("int _im_thrmask = int(_im_myid == 0);");

	/* The generic transformation does the actual work */
	ifmaster_stmt(t, infunc, thrmask_decl, "barrier");
}


/* Builds the shader buffer block for variable `s'; the result looks like:
 *
 *   layout(set = X, binding = Y) buffer _ompi_shader_buf_ABC {
 *      <origdecl>
 *   };
 *
 * `state' must point to a vk_bindingstate_t; its binding_num field gives 
 * the binding number to use and is incremented before returning.
 * The symbol must be present in the symbol table (asserted).
 */
aststmt produce_shader_buffer_decl(symbol s, void *state)
{
	vk_bindingstate_t *bstate = (vk_bindingstate_t *) state;
	stentry  entry = symtab_get(stab, s, IDNAME);
	int      binding = bstate->binding_num, set_idx = 0;
	astspec  spec;
	astdecl  decl;
	aststmt  header, member;

	assert(entry != NULL);

	/* Specifier: structs go through Usertype() with the struct's name, 
	 * everything else gets a copy of the original specifier.
	 */
	if (entry->spec->type == SUE && entry->spec->subtype == SPEC_struct)
		spec = Usertype(entry->spec->name);
	else
		spec = ast_spec_copy(entry->spec);

	/* Declarator: arrays and pointers are both declared as unsized arrays */
	if (entry->isarray || decl_ispointer(entry->decl))
		decl = ArrayDecl(IdentifierDecl(s), NULL, NULL);
	else
		decl = IdentifierDecl(s);

	member = Declaration(spec, Declarator(NULL, decl));

	/* Bindings requested from a target construct are shifted by the number
	 * of declared variables, if there are any (presumably because those
	 * occupy the first binding slots -- confirm against the host side).
	 */
	if (bstate->from == BINDING_TARGET && declvars_exist())
		binding += set_size(declare_variables);

	header = verbit("layout(set = %d, binding = %d) buffer _ompi_shader_buf_%s {", 
	                set_idx, binding, s->name);

	bstate->binding_num++;   /* the next buffer takes the next number */
	return Block3(header, member, verbit("};\n"));
}

/* Builds a shader buffer block that holds a single unsigned long named 
 * after `s'; the result looks like:
 *
 *   layout(set = X, binding = Y) buffer _ompi_shader_off_ABC {
 *      unsigned long <varname>
 *   };
 *
 * `state' must point to a vk_bindingstate_t; its binding_num field gives 
 * the binding number to use and is incremented before returning.
 */
aststmt produce_shader_offset_decl(symbol s, void *state)
{
	vk_bindingstate_t *bstate = (vk_bindingstate_t *) state;
	int      binding = bstate->binding_num, set_idx = 0;
	astspec  ulongspec;
	aststmt  header, member;

	/* The sole member of the block: unsigned long <s> */
	ulongspec = Speclist_right(Declspec(SPEC_unsigned), Declspec(SPEC_long));
	member = Declaration(ulongspec, Declarator(NULL, IdentifierDecl(s)));

	/* Same binding shift as in produce_shader_buffer_decl() */
	if (bstate->from == BINDING_TARGET && declvars_exist())
		binding += set_size(declare_variables);

	header = verbit("layout(set = %d, binding = %d) buffer _ompi_shader_off_%s {", 
	                set_idx, binding, s->name);

	bstate->binding_num++;   /* the next buffer takes the next number */
	return Block3(header, member, verbit("};\n"));
}


/* Produces one shader buffer declaration per element of `vars' (in set 
 * order) and appends them to the statement list *bufstmt; if *bufstmt is 
 * NULL, the first declaration produced becomes the list.
 * `state' is forwarded to produce_shader_buffer_decl() for every variable,
 * so it must point to a vk_bindingstate_t.
 */
static
void _produce_all_shader_buffers(set(vars) vars, aststmt *bufstmt, void *state)
{
	setelem(vars) elem = vars->first;

	while (elem != NULL)
	{
		aststmt decl = produce_shader_buffer_decl(elem->key, state);

		if (*bufstmt == NULL)
			*bufstmt = decl;
		else
			*bufstmt = BlockList(*bufstmt, decl);
#if 0
		if (elem->value.ismap) /* produce offset, too */
			*bufstmt = BlockList(*bufstmt, 
				produce_shader_offset_decl(targstruct_offsetname(elem->key), state));
#endif
		elem = elem->next;
	}
}


/* Builds the Vulkan-specific entry point that wraps the actual kernel 
 * function:
 * 
 *   void main() {
 *     _kernel_globals_init();
 *     <kfuncname>();
 *   }
 *
 * If the code target being transformed for provides an ADJ_KERNEL_FUNC 
 * adjuster, it is applied to the new function before it is returned.
 * Note that `emptyde' is not used here.
 */
aststmt produce_shader_wrapper(char *kfuncname, bool emptyde)
{
	void (*adjust_kernel)(aststmt) = 
	        codetarg_get_adjuster(xformingFor, ADJ_KERNEL_FUNC);
	aststmt calls, mainfunc;

	/* The two calls that make up the body of main() */
	calls = BlockList(
	          FuncCallStmt("_kernel_globals_init", NULL), 
	          FuncCallStmt(kfuncname, NULL)
	        );

	/* void main() { <calls> } */
	mainfunc = FuncDef(
	             Declspec(SPEC_void), 
	             Declarator(NULL, FuncDecl(IdentifierDecl(Symbol("main")), NULL)),
	             NULL,
	             Compound(calls)
	           );

	if (adjust_kernel != NULL)
		adjust_kernel(mainfunc);
	return mainfunc;
}


/**
 * Implicit-default callback for Vulkan targets: when the construct has a 
 * defaultmap clause every variable is treated as map(tofrom:); otherwise 
 * the decision is delegated to xtarget_implicitDefault().
 * @param s   The variable under consideration (only forwarded).
 * @param arg Must point to a struct whose first member is the 
 *            `bool hasdefaultmap' flag; also forwarded as is.
 * @return The decided mapping attribute (i.e. the corresponding set to join).
 */
vartype_t xtarget_implicitDefault_vk(setelem(vars) s, void *arg)
{
	struct { bool hasdefaultmap; } *opts = arg;

	return opts->hasdefaultmap ? DCT_MAPTOFROM 
	                           : xtarget_implicitDefault(s, arg);
}


/* 
 * Functions that do the job 
 *
 * _omp_target_vulkan() outlines the target construct *t into a kernel 
 * function named _kernelFunc<kid>_vulkan, generates one shader buffer 
 * declaration for every variable used by the outlined code, and stores
 * everything (buffers, kernel code, main() wrapper) plus the kernel 
 * function name in the kernel entry of the construct for the code target 
 * currently being transformed (xformingFor).
 *
 *   t  -- the target construct; handed to outline_OpenMP()
 *   ts -- not referenced in this function
 */
static void _omp_target_vulkan(aststmt *t, targstats_t *ts)
{
	aststmt    gpu_wrapper, bufstmt = NULL;
	outcome_t  oc;
	bool       emptyde = false;   /* never changed; passed to the wrapper */
	ompcon     ompc = (*t)->u.omp;
	kernel_t  *kernel = codetargs_get_kernel_from_copy(ompc, xformingFor);
	bool hasdefaultmap =
		(xc_ompcon_get_clause((*t)->u.omp, OCDEFAULTMAP) != NULL);
	/* Argument block read by xtarget_implicitDefault_vk() */
	struct { bool hasdefaultmap; } impdefargs = { hasdefaultmap };
	/* Binding numbering starts at 0; every buffer produced bumps it */
	vk_bindingstate_t state = { 0, BINDING_TARGET }; // from target

	/* 1) Outline
	 *
	 * The options object is static: the initializer below applies once and
	 * the per-call fields are overwritten right after it on every call.
	 */
	static outline_opts_t op =
	{
		/* structbased             */  false,                   
		/* functionName            */  "test",                 
		/* functionCall  (func)    */  NULL,  
		/* byvalue_type            */  BYVAL_bycopy,           
		/* byref_type              */  BYREF_pointer,          
		/* byref_copyptr (2 funcs) */  NULL, NULL,             
		/* global_byref_in_struct  */  true,                   
		/* structName              */  "__dev_struct",         
		/* structVariable          */  DEVENV_STRUCT_NAME,     
		/* structInitializer       */  NULL,                   
		/* implicitDefault (func)  */  xtarget_implicitDefault_vk,
		/* implicitDefault (args)  */  NULL,
		/* deviceexpr              */  NULL,                   
		/* addComment              */  false,                   
		/* thestmt                 */  NULL,
		/* userType                */  NULL,
		/* usePointers             */  false,
		/* makeReplCode            */  false,
		/* makeWrapper             */  false,
		/* wrapperType             */  WRAPPER_none
	};
	/* Replaces the "test" placeholder with the real kernel function name.
	 * NOTE(review): assumes functionName is a writable char array that is
	 * large enough for the formatted name -- confirm in outline.h.
	 */
	sprintf(op.functionName, "_kernelFunc%d_vulkan", kernel->kid);
	op.structInitializer = NullExpr();
	op.deviceexpr = numConstant(DFLTDEV_ALIAS);  /* dummy, just != NULL */
	/* NOTE(review): impdefargs is a local, so this pointer is only valid 
	 * during the current call; it is re-assigned on every call.
	 */
	op.implicitDefault_args = &impdefargs;
	op.thestmt = *t;

	/* The options are passed by value; the outcome carries the sets of 
	 * used variables (usedvars[]) consumed in step 3.
	 */
	oc = outline_OpenMP(t, op);

	/* 2) Produce the shader wrapper, i.e. the main() function that calls 
	 *    the kernel function
	 */
	gpu_wrapper = produce_shader_wrapper(op.functionName, emptyde);

	/* 3) Store the generated code
	 *
	 * kernel->kfuncstmt[xformingFor] is not assigned in this function before
	 * this point; presumably it was filled in during outlining -- confirm.
	 */
	ast_parentize(kernel->kfuncstmt[xformingFor]);
	analyze_pointerize_decltarg_varsfuncs(kernel->kfuncstmt[xformingFor]);

	/* One buffer per used variable. All calls share `state', so binding
	 * numbers are handed out in exactly this category order.
	 */
	_produce_all_shader_buffers(oc.usedvars[DCT_BYVALUE],   &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_MAPALLOC],  &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_MAPTO],     &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_MAPFROM],   &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_MAPTOFROM], &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_DDENV],     &bufstmt, &state);
	_produce_all_shader_buffers(oc.usedvars[DCT_ZLAS],      &bufstmt, &state);
	
	/* Final layout: buffer declarations, kernel code, main() wrapper.
	 * bufstmt stays NULL if no variables were used at all.
	 */
	if (bufstmt)
		ast_stmt_in_place_prepend(kernel->kfuncstmt[xformingFor], bufstmt);
	kernel->kfuncstmt[xformingFor] = BlockList(kernel->kfuncstmt[xformingFor], gpu_wrapper);
	/* The kernel entry gets its own copy of the name, since op is reused */
	kernel->kfuncname[xformingFor] = strdup(op.functionName);
}


/* Entry point for the construct handled as "targparfor" on Vulkan devices:
 * the construct first goes through ccc_try_splitting() and whatever is left 
 * in *t is then transformed by xform_target_vulkan().
 * The construct body is not transformed separately here (no call to 
 * xform_ompcon_body() is made).
 */
void xform_targparfor_vulkan(aststmt *t)
{
	ccc_try_splitting(t);
	xform_target_vulkan(t);
}

/* Transforms a construct handled as "target parallel" on Vulkan devices.
 * Unlike xform_target_vulkan(), the body is not passed through the 
 * if-master transformation; the construct goes straight to 
 * _omp_target_vulkan().
 * The globals cur_parallel_line, cur_taskgroup_line and cppLineNo are 
 * zeroed/disabled for the duration of the transformation and restored 
 * before returning.
 */
void xform_targetparallel_vulkan(aststmt *t)
{
	/* Remember the current values so that they can be restored at the end */
	int savecpl = cur_parallel_line;
	int savectgl = cur_taskgroup_line;
	bool prev_cppLineNo = cppLineNo;
	cur_parallel_line = cur_taskgroup_line = 0;	

	cppLineNo = false;

	// ccc_try_splitting(t);
	/* This is called just to set the clause vars */
	xc_ompcon_search_offload_params((*(t))->u.omp);

	/* TARGET_PROLOGUE is a macro defined elsewhere; `ts' is not declared 
	 * anywhere else in this function, so presumably the macro introduces it.
	 */
	TARGET_PROLOGUE(t);
	_omp_target_vulkan(t, ts);

	/* Restore the saved state */
	cur_parallel_line = savecpl;
	cur_taskgroup_line = savectgl;
	cppLineNo = prev_cppLineNo;
}

/* Transforms a target construct for Vulkan devices.
 * If the construct contains no nested parallel, distribute or 
 * distribute-parallel-for construct, its body first goes through the 
 * if-master transformation; the construct is then handed to 
 * _omp_target_vulkan().
 * The globals cur_parallel_line, cur_taskgroup_line and cppLineNo are 
 * zeroed/disabled for the duration of the transformation and restored 
 * before returning.
 */
void xform_target_vulkan(aststmt *t)
{
	/* Remember the current values so that they can be restored at the end */
	int savecpl = cur_parallel_line;
	int savectgl = cur_taskgroup_line;
	bool prev_cppLineNo = cppLineNo;
	
	cur_parallel_line = cur_taskgroup_line = 0;	
	cppLineNo = false;

	/* This is called just to set the clause vars */
	xc_ompcon_search_offload_params((*(t))->u.omp);

	/* (1) Apply the if-master scheme to the body, but only when none of 
	 * the constructs searched for below is nested in the target region 
	 * (infunc is false since this is the target body itself).
	 */
	if (!search_nested_construct((*t)->u.omp, DCPARALLEL) &&
	    !search_nested_construct((*t)->u.omp, DCDISTRIBUTE) &&
	    !search_nested_construct((*t)->u.omp, DCDISTPARFOR))
		ifmaster_stmt_vulkan(&((*t)->u.omp->body), false);

	/* (2) Outline and store the kernel.
	 * TARGET_PROLOGUE is a macro defined elsewhere; `ts' is not declared 
	 * anywhere else in this function, so presumably the macro introduces it.
	 */
	TARGET_PROLOGUE(t);
	_omp_target_vulkan(t, ts);

	/* Restore the saved state */
	cur_parallel_line = savecpl;
	cur_taskgroup_line = savectgl;

	cppLineNo = prev_cppLineNo;
}
