/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* opencl.c -- device targets for OpenMP transformations/code generation */


#include <string.h>
#include "codetargs.h"
#include "ast_xformrules.h"
#include "x_target_ocl.h"
#include "x_parallel_ocl.h"
#include "x_teams_ocl.h"
#include "x_sections_ocl.h"
#include "x_assorted_ocl.h"
#include "x_parallel.h"
#include "x_distribute_cu.h"  /* Use the same transformation as CUDA */
#include "ast_xform.h"
#include "sem.h"

/* Transformation rules for the OpenCL code target; installed through
 * codetarg_set_xformrules() in __codetarg_opencl_init().
 * Each entry lists a directive code (DC*), the function that transforms it
 * and an XFR_ACTION_* value.
 * NOTE(review): the exact meaning of the XFR_ACTION_* values is defined
 * outside this file (presumably ast_xformrules.h) -- confirm there.
 * Most handlers are OpenCL-specific; DCDISTPARFOR reuses the CUDA
 * transformation and DCPARFOR uses xform_parallel.
 */
static xfr_t _opencl_xfr[] = {
	{ DCTARGET,                xform_target_opencl,          XFR_ACTION_COMBINE },
	{ DCTARGETPARALLEL,        xform_targetparallel_opencl,  XFR_ACTION_NONE },
	{ DCTARGETPARFOR,          xform_targparfor_opencl,      XFR_ACTION_NONE },
	{ DCTARGETTEAMS,           xform_targetteams_opencl,     XFR_ACTION_COMBINE },
	{ DCTARGETTEAMSDIST,       xform_targetteamsdist_opencl, XFR_ACTION_COMBSPLIT },
	{ DCTARGETTEAMSDISTPARFOR, xform_targetteamsdistparfor_opencl, XFR_ACTION_SPLIT },
	{ DCDISTPARFOR,            xform_distparfor_cuda,        XFR_ACTION_NONE},
	{ DCPARFOR,                xform_parallel,               XFR_ACTION_SPLIT },
	{ DCPARALLEL,              xform_parallel_opencl,        XFR_ACTION_NONE },
	{ DCSECTIONS,              xform_sections_opencl,        XFR_ACTION_NONE },
	{ DCCRITICAL,              xform_critical_opencl,        XFR_ACTION_NONE },
	{ DCATOMIC,                xform_atomic_opencl,          XFR_ACTION_NONE },
	{ DCTASK,                  xform_task_opencl,            XFR_ACTION_NONE },
	{ DCERROR,                 xform_error_opencl,           XFR_ACTION_NONE },
	XFR_LASTRULE   /* presumably the end-of-table marker -- confirm */
};


/*
 * Command-line user options
 */

/* Helpers for building option names and ids:
 *   OPTNAME(x)   expands to "opencl-x"   (plain flag)
 *   OPTNAME_V(x) expands to "Vopencl-x=" (option that takes a value);
 *                optid() treats the leading 'V' as a marker and does not
 *                match it against the argument
 *   OPTION(x)    expands to the enum constant OPT_x
 */
#define OPTNAME(opt)   "opencl-" #opt
#define OPTNAME_V(opt) "Vopencl-" #opt "="
#define OPTION(opt)    OPT_##opt

/* Option ids. No real options are defined at the moment, so
 * OPTION(lastoption) evaluates to 0 and the search loop in optid()
 * never iterates.
 */
typedef enum {
	OPTION(unknown) = -1, /* unknown option */
	OPTION(lastoption)    /* dummy */
} option_t;

/* Option names, indexed by option_t; optid() skips NULL entries. */
static char *optnames[] = {
	NULL
};

static option_t optid(char *arg, char **val)
{
	int i;

	for (i = 0; i < OPTION(lastoption); i++)
	{
		if (!optnames[i])   /* Skip dummy options */
			continue;
		if (optnames[i][0] == 'V')     /* Option with value */
		{
			if (strncmp(optnames[i]+1, arg, strlen(optnames[i])-1) == 0)
			{
				*val = arg + strlen(optnames[i]) - 1;
				return ((option_t) i);
			}
		}
		else
			if (strcmp(optnames[i], arg) == 0)
				return ((option_t) i);
	}
	return ( OPTION(unknown) );
}


/* When ompi.c sees an --opencl-arg[=value] argument, it passes the
 * opencl-arg[=value] part to this handler.
 */
int _opencl_cmdarg_handler(char *arg)
{
	char *val;
	switch ( optid(arg, &val) )
	{
		default:
			fprintf(stderr, "[OMPi error]: unknown option '--%s'.\n", arg);
			return (1);
	}
	return (0);
}


/* 
 * Code to check for dangerous types
 */
 
/* Set to true by _typecheck_declaration() / _typecheck_cast() when a
 * DOUBLE_T base type or a LONGLONG_T size specifier is seen, and read by
 * _opencl_topcom_adjust(). Nothing in this file sets them back to false.
 */
static bool _ocl_hasdouble, _ocl_haslonglong;

/* Declaration visitor (installed by _typecheck_ocl): records whether the
 * specifiers of declaration t have long long size or double base type.
 * Only the specifier list is examined, not the declarator part.
 * The other two arguments are not used.
 */
static void _typecheck_declaration(aststmt t, void *ignore, int vistime)
{
	/* The flags are only ever raised here, never lowered */
	_ocl_haslonglong |= (speclist_size(t->u.declaration.spec) == LONGLONG_T);
	_ocl_hasdouble   |= (speclist_basetype(t->u.declaration.spec) == DOUBLE_T);
}

/* Cast-expression visitor (installed by _typecheck_ocl): records whether
 * the type named in cast e has long long size or double base type.
 * Only the specifier list is examined, not the declarator part.
 * The other two arguments are not used.
 */
static void _typecheck_cast(astexpr e, void *ignore, int vistime)
{
	/* Without a type or a specifier list there is nothing to inspect.
	 * Both checks below must be covered by this guard; previously only the
	 * first one was, so the second could dereference a NULL dtype.
	 */
	if (!e->u.dtype || !e->u.dtype->spec)
		return;

	if (speclist_size(e->u.dtype->spec) == LONGLONG_T)
		_ocl_haslonglong = true;
	if (speclist_basetype(e->u.dtype->spec) == DOUBLE_T)
		_ocl_hasdouble = true;
}

/* Check for declared variables (and casts) of type long long / double.
 * t is assumed to be a function definition; its body is traversed and the
 * results are left in _ocl_haslonglong / _ocl_hasdouble.
 * Nothing happens if t or its body is NULL.
 */
static void _typecheck_ocl(aststmt t)
{
	/* Traversal options are built on the first call and reused afterwards */
	static travopts_t *decltrops = NULL;

	if (t == NULL)
		return;
	t = t->body;
	if (t == NULL)
		return;

	if (!decltrops)
	{
		decltrops = (travopts_t *) smalloc(sizeof(travopts_t));
		travopts_init_noop(decltrops);
		decltrops->stmtc.declaration_c = _typecheck_declaration;
		decltrops->exprc.castexpr_c = _typecheck_cast;
		decltrops->when = PREVISIT;
	}

	decltrops->starg = NULL;
	ast_stmt_traverse(t, decltrops);
}


/*
 * Declarator adjustments
 */

/* ADJ_KERNEL_FUNC adjuster: adds the `__kernel` user-type specifier to the
 * declaration/definition t.
 */
void _opencl_kernel_adjust(aststmt t)
{
	ast_stmt_declordef_addspec(t, Usertype(Symbol("__kernel")));
}

/* ADJ_SHARED_STRUCT adjuster: currently a no-op; t is left untouched. */
void _opencl_shared_adjust(aststmt t)
{
	/* Disabled: adding a `__local` specifier to t */
	//ast_stmt_declordef_addspec(t, Usertype(Symbol("__local")));
}

/* ADJ_GLOBALS adjuster: adds the `__global` user-type specifier to the
 * declaration/definition t.
 */
void _opencl_global_adjust(aststmt t)
{
	ast_stmt_declordef_addspec(t, Usertype(Symbol("__global")));
}

/* ADJ_VARDECL_STMT adjuster: adds the `__global` user-type specifier to the
 * declaration/definition statement t.
 */
void _opencl_structfield_adjust(aststmt t) { 
	ast_stmt_declordef_addspec(t, Usertype(Symbol("__global")));
}

/* ADJ_DECLARATION adjuster: same as the statement-level adjusters, but
 * operates on a declaration node (astdecl) through ast_decl_addspec(),
 * adding the `__global` user-type specifier.
 */
void _opencl_decl_adjust(astdecl t) { 
	ast_decl_addspec(t, Usertype(Symbol("__global")));
}


/* For OpenCL we need to check if unsupported types were used.
 * ADJ_TOPCOMMENT adjuster: type-checks the function found at t->u.next,
 * warns if long long was seen and appends a "$OCL_info:<double>,<longlong>"
 * comment line to s.
 * NOTE(review): the two flags are never reset in this file, so they
 * accumulate over successive calls -- confirm that this is intended.
 */
void _opencl_topcom_adjust(aststmt t, str s) 
{
	/* Because the wrapper was added @ transformation time, the actual 
	 * function definition is the first node in the blocklist
	 */
	aststmt funcdef = t->u.next;

	_typecheck_ocl(funcdef);
	if (_ocl_haslonglong)
	{
		warning("[warning]: OpenCL targets do not support long long types.\n");
	}
	str_printf(s, "/* $OCL_info:%d,%d */\n", _ocl_hasdouble, _ocl_haslonglong);
}


/* Functions to skip when generating prototypes, due to possible 
 * conflicting declarations with built-in ones.
 * The table is read-only, NULL-terminated, and lists each name once.
 */
static const char *const skipped_funcs[] = {
	/* math */
	"cos", "sin", "pow", "fabs", "sqrt", "exp", 
	"acos", "acosh", "acospi", "asin", "asinh", "asinpi", "atan", "atan2",
	"atanh", "atanpi", "atan2pi", "cbrt", "ceil", "copysign", "cosh",
	"cospi", "erfc", "erf", "exp2", "exp10", "expm1", "fdim",
	"floor", "fma", "fmax", "fmin",
	/* others */
	"malloc", "memcpy", "_ort_execute_teams", "_ort_execute_parallel",
	"printf", 
	NULL
};

/* Prototype filter for the OpenCL target (installed through
 * codetarg_set_filterfunc).
 * Returns NULL for functions listed in skipped_funcs, a hand-written
 * prototype for a few functions that need OpenCL-specific signatures, and
 * a clone of the function's own declaration for everything else.
 */
aststmt _opencl_filterfunc(symbol fsym)
{
	aststmt proto;
	int     k = 0;

	/* Functions that must not get a prototype at all */
	while (skipped_funcs[k] != NULL)
	{
		if (strcmp(skipped_funcs[k], fsym->name) == 0)
			return NULL;
		k++;
	}

	if (fsym == Symbol("_dev_med2dev_addr"))
		proto = verbit("__global char *_dev_med2dev_addr(__global void *, "
		                                               "unsigned long); ");
	else if (fsym == Symbol("_ort_critical_begin"))
		proto = verbit("void _ort_critical_begin(__global int *); ");
	else if (fsym == Symbol("_ort_critical_end"))
		proto = verbit("void _ort_critical_end(__global int *); ");
	else if (fsym == Symbol("_ort_reduce_add"))
		proto = verbit("void _ort_reduce_add(int, void *, __global void *, int); ");
	else if (fsym == Symbol("memset"))
		/* The standard prototype of memset is not good for us */
		proto = verbit("void memset(char *, int, int); ");
	else
		/* All other cases: just return a copy of the prototype */
		proto = xform_clone_funcdecl(fsym);

	return proto;
}


/* Id of the OpenCL code target; starts out as -1. */
int CODETARGID(opencl) = -1;   /* We get an id automatically */

/* Text registered as the kernel-files header for this target (see
 * codetarg_set_kernelfiles_header() in __codetarg_opencl_init()).
 * It declares two runtime functions and defines three macros: the
 * _ocl_dev_shmem_push/pop macros forward to _dev_shmem_push/pop, and
 * _ocl_dev_execute_parallel(a,b,c,d,e) expands to the plain call a(b).
 */
static char *_opencl_kernel_header =
	"void _ort_set_local_mem(__local void *);\n"
	"void _ort_set_xtrainfo(__global int *);\n"
	"#define _ocl_dev_shmem_push(x,y) _dev_shmem_push(x,y)\n"
	"#define _ocl_dev_shmem_pop(x,y)  _dev_shmem_pop(x,y)\n"
	"#define _ocl_dev_execute_parallel(a,b,c,d,e) a(b)\n"
;

/* NULL-terminated list registered as the kernel-binary suffixes. */
static char *_opencl_kbinsuffixes[] = { "-opencl.out", NULL };

/* Registers everything the OpenCL code target provides: command-line
 * handler, transformation rules, reduction style, prototype filter,
 * kernel file header/suffixes and the declaration adjusters.
 * This is called automatically.
 */
void __codetarg_opencl_init()
{
	typedef void (*adjfunc_t)(void);

	/* Adjusters, listed in the order they get installed */
	struct { int what; adjfunc_t func; } adjusters[] = {
		{ ADJ_KERNEL_FUNC,   (adjfunc_t) _opencl_kernel_adjust },
		{ ADJ_SHARED_STRUCT, (adjfunc_t) _opencl_shared_adjust },
		{ ADJ_GLOBALS,       (adjfunc_t) _opencl_global_adjust },
		{ ADJ_VARDECL_STMT,  (adjfunc_t) _opencl_structfield_adjust },
		{ ADJ_DECLARATION,   (adjfunc_t) _opencl_decl_adjust },
		{ ADJ_TOPCOMMENT,    (adjfunc_t) _opencl_topcom_adjust }
	};
	int i, nadjusters = (int) (sizeof(adjusters) / sizeof(adjusters[0]));

	codetarg_set_cmdarg_handler(CODETARGID(opencl), _opencl_cmdarg_handler);
	codetarg_set_xformrules(CODETARGID(opencl), _opencl_xfr);
	codetarg_set_reduction_style(CODETARGID(opencl), REDCODE_RTLIB);
	codetarg_set_filterfunc(CODETARGID(opencl), _opencl_filterfunc);
	codetarg_set_kernelfiles_header(CODETARGID(opencl), _opencl_kernel_header);
	codetarg_set_kernelfiles_suffix(CODETARGID(opencl), "-opencl.cl");
	codetarg_set_kernelbins_suffixes(CODETARGID(opencl), _opencl_kbinsuffixes);

	/* Install adjusters */
	for (i = 0; i < nadjusters; i++)
		codetarg_set_adjuster(CODETARGID(opencl), 
		                      adjusters[i].what, adjusters[i].func);
}
