/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* vulkan.c -- Vulkan device target for OpenMP transformations/code generation */

#include <string.h>
#include "ast.h"
#include "sem.h"
#include "codetargs.h"
#include "ast_xformrules.h"
#include "ast_xform.h"
#include "x_parallel.h"
#include "x_assorted_vk.h"
#include "x_distribute_vk.h"
#include "x_for_vk.h"
#include "x_parallel_vk.h"
#include "x_target_vk.h"
#include "x_teams_vk.h"
#include "vulkan.h"


/* Transformation rules of the Vulkan target, installed by
 * __codetarg_vulkan_init() through codetarg_set_xformrules().
 * Each entry pairs an OpenMP construct (a DC* directive code) with the
 * Vulkan-specific function that transforms it and an XFR_ACTION_* code
 * (presumably telling the generic transformer whether to treat a combined
 * construct as a whole (COMBINE), split it (SPLIT), or do nothing special
 * (NONE) -- confirm in ast_xformrules.h).
 * DCFOR_P and DCFOR share the same transformer. The table is terminated by
 * XFR_LASTRULE; entry order is kept as-is in case the rule lookup depends on it.
 */
static xfr_t _vulkan_xfr[] =  {
	{ DCTARGET,                xform_target_vulkan,                XFR_ACTION_COMBINE },
	{ DCTARGETPARALLEL,        xform_targetparallel_vulkan,        XFR_ACTION_SPLIT },
	{ DCTARGETPARFOR,          xform_targparfor_vulkan,            XFR_ACTION_SPLIT },
	{ DCTARGETTEAMS,           xform_targetteams_vulkan,           XFR_ACTION_COMBINE },
	{ DCTARGETTEAMSDISTPARFOR, xform_targetteamsdistparfor_vulkan, XFR_ACTION_SPLIT },
	{ DCDISTRIBUTE,            xform_distribute_vulkan,            XFR_ACTION_NONE },
	{ DCDISTPARFOR,            xform_distparfor_vulkan,            XFR_ACTION_NONE },
	{ DCCRITICAL,              xform_critical_vulkan,              XFR_ACTION_NONE },
	{ DCTASK,                  xform_task_vulkan,                  XFR_ACTION_NONE },
	{ DCFOR_P,                 xform_for_vulkan,                   XFR_ACTION_NONE },
	{ DCFOR,                   xform_for_vulkan,                   XFR_ACTION_NONE },
	{ DCSINGLE,                xform_single_vulkan,                XFR_ACTION_NONE },
	{ DCPARALLEL,              xform_parallel_vulkan,              XFR_ACTION_NONE },
	{ DCERROR,                 xform_error_vulkan,                 XFR_ACTION_NONE },
	XFR_LASTRULE
};

/*
 * Options
 *
 * Command-line options of the Vulkan target. optnames[] is indexed by
 * option_t, so both lists must be kept in the same order:
 *   - a NULL entry is a placeholder that optid() skips (dummy1, lastoption);
 *   - a name that starts with 'V' is an option taking a value: optid()
 *     matches the text after the 'V' (which ends in '=') as a prefix of the
 *     argument and hands back the remainder as the value;
 *   - any other name must match the argument exactly.
 * optid() scans indices [0, OPTION(lastoption)), so the trailing NULL entry
 * is never read.
 */
/* OPTNAME(x) -> "vulkan-x"; OPTNAME_V(x) -> "Vvulkan-x="; OPTION(x) -> OPT_x */
#define OPTNAME(opt)   "vulkan-" #opt
#define OPTNAME_V(opt) "Vvulkan-" #opt "="
#define OPTION(opt)    OPT_##opt

typedef enum {
	OPTION(unknown) = -1, /* unknown option */

	OPTION(dummy1),  /* Requirements */
	OPTION(require), OPTION(require_V),

	OPTION(lastoption)    /* dummy */
} option_t;

/* One name per option_t value, starting at OPTION(dummy1) == 0 */
static char *optnames[] = {
	NULL,
	OPTNAME(require), OPTNAME_V(require),

	NULL
};

/* Optional GLSL integer-width extensions the generated shader must enable.
 * They are raised either by --vulkan-require=int8|int16|int64 or while
 * converting kernel types (_check_and_convert_types).
 */
static bool _vk_requires_int8  = false;
static bool _vk_requires_int16 = false;
static bool _vk_requires_int64 = false;

/* Raised when a long double was met (and demoted to double) during conversion */
static bool _vk_used_long_double = false;

static option_t optid(char *arg, char **val)
{
	int i;

	for (i = 0; i < OPTION(lastoption); i++)
	{
		if (!optnames[i])   /* Skip dummy options */
			continue;
		if (optnames[i][0] == 'V')     /* Option with value */
		{
			if (strncmp(optnames[i]+1, arg, strlen(optnames[i])-1) == 0)
			{
				*val = arg + strlen(optnames[i]) - 1;
				return ((option_t) i);
			}
		}
		else
			if (strcmp(optnames[i], arg) == 0)
				return ((option_t) i);
	}
	return ( OPTION(unknown) );
}


/* When ompi.c sees an --vulkan-arg[=value] argument, it passes the
 * vulkan-arg[=value] part to this handler.
 */
int _vulkan_cmdarg_handler(char *arg)
{
	char *val;
	switch ( optid(arg, &val) )
	{
		case OPTION(require):
			fprintf(stderr, "[OMPi error]: expected value for require option.\n");
			return (1);
		case OPTION(require_V):
			if (strcmp(val, "int8") == 0)
				_vk_requires_int8 = true;
			else 
				if (strcmp(val, "int16") == 0)
					_vk_requires_int16 = true;
				else 
					if (strcmp(val, "int64") == 0)
						_vk_requires_int64 = true;
					else
					{
						fprintf(stderr, "[OMPi error]: unknown require option '%s' (try 'int8'/'int16'/'int64').\n", 
						                val);
						return (1);
					}
			break;
		default:
			fprintf(stderr, "[OMPi error]: unknown option '--%s'.\n", arg);
			return (1);
	}
	return (0);
}


/* Code-target id of the Vulkan target; every codetarg_set_*() call in
 * __codetarg_vulkan_init() is keyed on it. -1 is only an initial value; the
 * real id is assigned automatically (presumably by the codetargs framework
 * before __codetarg_vulkan_init() runs -- confirm in codetargs.h).
 */
int CODETARGID(vulkan) = -1;   /* We will get an id automatically */

/* Header text of every generated Vulkan kernel file (registered through
 * codetarg_set_kernelfiles_header()). The embedded comment looks like a
 * replacement marker for the device part -- confirm against its consumer.
 */
static char *_vulkan_kernel_header = "\n/* <repl:devpart.comp> */\n\n";


/* Functions for which _vulkan_filterfunc() emits no prototype into the
 * kernel file; the list is NULL-terminated and scanned linearly, so every
 * name appears exactly once. (The previous version listed "cos", "exp" and
 * "fabs" twice.) The "others" group is presumably provided or handled by the
 * runtime / transformation code rather than by the kernel -- confirm.
 *
 * NOTE(review): several double-precision functions (log, log2, log10, tan,
 * fmod, round, trunc, ...) are absent although their float counterparts are
 * listed -- confirm whether that is intended.
 */
static const char *const skipped_funcs[] = {
	/* math */
	"cos", "sin", "pow", "fabs", "sqrt", "exp",
	"acos", "acosh", "acospi", "asin", "asinh", "asinpi", "atan", "atan2",
	"atanh", "atanpi", "atan2pi", "cbrt", "ceil", "copysign", "cosh",
	"cospi", "erfc", "erf", "exp2", "exp10", "expm1", "fdim",
	"floor", "fma", "fmax", "fmin", "fabsf", "fdimf", "floorf", "fmaf", "fmaxf",
	"fminf", "fmodf", "frexpf", "hypotf", "expf", "exp2f", "logf", "log2f",
	"log10f", "powf", "sinf", "cosf", "tanf", "asinf", "acosf", "atanf",
	"atan2f", "ceilf", "truncf", "roundf", "copysignf", "sqrtf",
	/* others */
	"malloc", "memcpy", "_ort_execute_teams", "_ort_execute_parallel",
	"printf",
	NULL
};

/* Function filter of the Vulkan target: decides which function prototypes
 * are carried into kernel files. Returns NULL (no prototype) for any name
 * listed in skipped_funcs; otherwise returns a clone of the function's
 * declaration, produced by xform_clone_funcdecl().
 */
aststmt _vulkan_filterfunc(symbol fsym)
{
	size_t k = 0;

	while (skipped_funcs[k] != NULL)
	{
		if (!strcmp(skipped_funcs[k], fsym->name))
			return NULL;
		k++;
	}
	return xform_clone_funcdecl(fsym);
}


/* Adjuster registered for ADJ_SHARED_STRUCT: adds the user-type specifier
 * "shared" to declaration/definition `t` (presumably so that it lands in
 * GLSL workgroup-shared storage).
 */
void _vulkan_shared_adjust(aststmt t)
{
	astspec qualifier = Usertype(Symbol("shared"));

	ast_stmt_declordef_addspec(t, qualifier);
}

/* Some signs/sizes/types are not supported by Vulkan shaders, so they are
 * rewritten, in place, into shader-specific type names. When a rewrite needs
 * an optional GLSL extension, the matching _vk_requires_* flag is raised;
 * long double is demoted to double and _vk_used_long_double is raised.
 */

/* Make speclist `s` read as the single user type `tname`. Its tail is
 * replaced by an empty user type (a hack) so that the remaining specifiers
 * of the original list vanish from the output.
 */
static void _vk_speclist_settype(astspec s, char *tname)
{
	s->body   = Usertype(Symbol(tname));
	s->u.next = Usertype(Symbol(""));
}

/* Overwrite the lone specifier `s` in place with user type `tname`. */
static void _vk_spec_settype(astspec s, char *tname)
{
	*s = *(Usertype(Symbol(tname)));
}

/* Convert a specifier list (e.g. "unsigned char"). The first matching rule
 * wins, so the order of the checks matters (long double before long).
 */
static void _vk_convert_speclist(astspec s)
{
	if (speclist_basetype(s) == CHAR_T)                     /* [un]signed char */
	{
		_vk_speclist_settype(s, (speclist_sign(s) == UNSIGNED_T) ? "uint8_t" : "int8_t");
		_vk_requires_int8 = true;
		return;
	}
	if (speclist_basetype(s) == DOUBLE_T && speclist_size(s) == LONG_T)
	{                                                      /* long double: unsupported */
		_vk_speclist_settype(s, "double");
		_vk_used_long_double = true;
		return;
	}
	if (speclist_basetype(s) == UBOOL_T)                    /* _Bool */
	{
		_vk_speclist_settype(s, "bool");
		return;
	}
	if (speclist_size(s) == SHORT_T)                        /* [un]signed short */
	{
		_vk_speclist_settype(s, (speclist_sign(s) == UNSIGNED_T) ? "uint16_t" : "int16_t");
		_vk_requires_int16 = true;
		return;
	}
	if (speclist_size(s) == LONG_T || speclist_size(s) == LONGLONG_T)
	{                                                      /* [un]signed long [long] */
		_vk_speclist_settype(s, (speclist_sign(s) == UNSIGNED_T) ? "uint64_t" : "int64_t");
		_vk_requires_int64 = true;
		return;
	}
	if (speclist_basetype(s) == INT_T)                      /* [un]signed int */
		_vk_speclist_settype(s, (speclist_sign(s) == UNSIGNED_T) ? "uint" : "int");
	/* any other list is left untouched */
}

/* Convert a lone specifier (e.g. plain "unsigned" or "char"). */
static void _vk_convert_spec(astspec s)
{
	switch (s->subtype)
	{
		case SPEC_unsigned:                 /* unsigned -> uint */
			_vk_spec_settype(s, "uint");
			break;
		case SPEC_signed:                   /* signed -> int */
			*s = *(Declspec(SPEC_int));
			break;
		case SPEC_char:                     /* char -> int8_t */
			_vk_spec_settype(s, "int8_t");
			_vk_requires_int8 = true;
			break;
		case SPEC_short:                    /* short -> int16_t */
			_vk_spec_settype(s, "int16_t");
			_vk_requires_int16 = true;
			break;
		case SPEC_long:                     /* long -> int64_t */
			_vk_spec_settype(s, "int64_t");
			_vk_requires_int64 = true;
			break;
		default:                            /* everything else is kept */
			break;
	}
}

/* Traversal callback used by _vk_typecheck_convert() for both specifiers
 * and specifier lists; `state` and `ignore` are not used.
 */
static void _check_and_convert_types(astspec s, void *state, int ignore)
{
	if (s->type == SPECLIST)
		_vk_convert_speclist(s);
	else if (s->type == SPEC)
		_vk_convert_spec(s);
}


/* Walk statement tree `t` and run _check_and_convert_types() on every
 * specifier and specifier list, converting types in place. The traversal
 * starts from a no-op setup (travopts_init_noop) and uses PREVISIT visiting.
 */
static void _vk_typecheck_convert(aststmt t)
{
	travopts_t opts;

	travopts_init_noop(&opts);
	opts.specc.spec_c     = _check_and_convert_types;
	opts.specc.speclist_c = _check_and_convert_types;
	opts.when             = PREVISIT;
	ast_stmt_traverse(t, &opts);
}


/* Adjuster registered for ADJ_TOPCOMMENT: writes the GLSL preamble of a
 * kernel file into `s`. The preamble is "#version 450" followed by the
 * #extension lines the kernel needs.
 *
 * The kernel tree `t` is first converted in place to shader-supported types
 * (_vk_typecheck_convert). That conversion is also what raises the
 * _vk_requires_* / _vk_used_long_double flags.
 *
 * Those flags are file-global. Without care, a requirement detected while
 * converting one kernel would leak into the preamble of every later kernel
 * file, and the long double warning would be repeated for kernels that never
 * used long double. Therefore:
 *   - the integer-width flags are restored on exit to their values on entry,
 *     so only what was requested on the command line (--vulkan-require=...)
 *     persists across kernels;
 *   - the long double flag is cleared.
 */
void _vulkan_topcom_adjust(aststmt t, str s)
{
	bool keep_int8  = _vk_requires_int8,
	     keep_int16 = _vk_requires_int16,
	     keep_int64 = _vk_requires_int64;

	str_printf(s, "#version 450\n");
	_vk_typecheck_convert(t);     /* may raise the requirement flags */
	if (_vk_used_long_double)
		warning("[warning]: Vulkan targets do not support long double types.\n");
	if (_vk_requires_int8)
	{
		str_printf(s, "#extension GL_EXT_shader_8bit_storage : require\n");
		str_printf(s, "#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n");
	}
	if (_vk_requires_int16)
		str_printf(s, "#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require\n");
	if (_vk_requires_int64)
		str_printf(s, "#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require\n");

	/* Per-kernel detections must not carry over to the next kernel file */
	_vk_requires_int8    = keep_int8;
	_vk_requires_int16   = keep_int16;
	_vk_requires_int64   = keep_int64;
	_vk_used_long_double = false;
}


/* This state is used for declvars: the buffer-binding state passed to
 * produce_shader_buffer_decl() by _vulkan_global_adjust(), which resets its
 * binding_num to 0 whenever a new kernel file starts. The initializer
 * presumably sets binding_num = 0 and the binding kind to BINDING_DECLTARG
 * (field order as declared in vulkan.h -- confirm).
 */
static vk_bindingstate_t st = { 0, BINDING_DECLTARG }; 

/* Adjuster registered for ADJ_GLOBALS: wraps a global variable in a buffer
 * declaration. Declaration `t` is replaced in place by whatever
 * produce_shader_buffer_decl() builds for its declared symbol. Binding
 * numbers come from the file-level state `st` and restart at 0 whenever
 * xformingKernelID changes, i.e. when we are in another kernel file.
 */
void _vulkan_global_adjust(aststmt t)
{
	static int last_kernel = -1;    /* kernel whose globals were seen last */
	symbol     var;

	assert(t->type == DECLARATION);
	var = decl_getidentifier_symbol(t->u.declaration.decl);

	if (xformingKernelID != last_kernel)
	{
		st.binding_num = 0;
		last_kernel = xformingKernelID;
	}
	*t = *produce_shader_buffer_decl(var, &st);
}

/* NULL-terminated list of suffixes of the kernel binaries produced for this
 * target (presumably the compiled SPIR-V module of each kernel).
 */
static char *_vulkan_kbinsuffixes[] = {
	"-vulkan.spv",
	NULL
};

/* Target initializer (called automatically). It hooks the Vulkan target into
 * the code-target framework:
 *   - command-line handler and transformation rules;
 *   - reduction style and function filter;
 *   - AST adjusters;
 *   - naming and header of the generated kernel files.
 */
void __codetarg_vulkan_init()
{
	/* Function-pointer type the adjusters are cast to for registration */
	typedef void (*vk_adjfn_t)(void);

	/* Command line and OpenMP-construct handling */
	codetarg_set_cmdarg_handler(CODETARGID(vulkan), _vulkan_cmdarg_handler);
	codetarg_set_xformrules(CODETARGID(vulkan), _vulkan_xfr);
	codetarg_set_reduction_style(CODETARGID(vulkan), REDCODE_RTLIB);
	codetarg_set_filterfunc(CODETARGID(vulkan), _vulkan_filterfunc);

	/* AST adjusters applied while producing kernel code */
	codetarg_set_adjuster(CODETARGID(vulkan), ADJ_SHARED_STRUCT,
	                      (vk_adjfn_t) _vulkan_shared_adjust);
	codetarg_set_adjuster(CODETARGID(vulkan), ADJ_TOPCOMMENT,
	                      (vk_adjfn_t) _vulkan_topcom_adjust);
	codetarg_set_adjuster(CODETARGID(vulkan), ADJ_GLOBALS,
	                      (vk_adjfn_t) _vulkan_global_adjust);

	/* Generated kernel sources and their binaries */
	codetarg_set_kernelfiles_suffix(CODETARGID(vulkan), "-vulkan.comp");
	codetarg_set_kernelfiles_header(CODETARGID(vulkan), _vulkan_kernel_header);
	codetarg_set_kernelbins_suffixes(CODETARGID(vulkan), _vulkan_kbinsuffixes);
}
