/*
  OMPi OpenMP Compiler
  == Copyright since 2001, the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* KERNELS.C
 * Module and kernels support for ompicc
 */

/* 
 * May 2019:
 *   Created out of code in ompicc.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <ctype.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/time.h>
#include <errno.h>
#include <libgen.h>
#include "config.h"
#include "str.h"
#include "ompicc.h"
#include "mapper.h"
#include "assorted.h"
#include "set.h"
#ifdef OMPI_REMOTE_OFFLOADING
	#include "roff_config.h"
#endif

/* State shared by the kernel-makefile routines of this file */
static char current_file[PATHSIZE];   /* file whose kernels are being built */
static char cwd[PATHSIZE];            /* working directory captured by getcwd() */


/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                           *
 *        MODULES                                            *
 *                                                           *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */


char **modulenames;  /* Names of the requested modules (filled by modules_employ) */
int  nmodules;       /* Number of entries in modulenames[] */

#if defined(OMPI_REMOTE_OFFLOADING)

/* Return 1 if `modulename` is one of the locally requested modules
 * (modulenames[0..nmodules-1]), 0 otherwise. */
static int is_module_supported(char *modulename)
{
	int k = nmodules;

	while (k-- > 0)
	{
		if (strcmp(modulenames[k], modulename) == 0)
			return 1;
	}
	return 0;
}

#endif


/* Record the modules requested by the user.
 *
 * `modstring` is a list of module names separated by commas and/or
 * whitespace; the special value "all" selects MODULES_CONFIG instead.
 * The string is tokenized IN PLACE (separators are overwritten) and
 * modulenames[] ends up pointing into it, so it must outlive the module
 * list. On return, nmodules holds the number of names found (0 when
 * `modstring` is NULL or contains no names). Exits on allocation failure.
 */
void modules_employ(char *modstring)
{
	char *s;
	int  i;

	nmodules = 0;
	if (!modstring)
		return;
	if (strcmp(modstring, "all") == 0)
	{
		if ((modstring = strdup(MODULES_CONFIG)) == NULL)
		{
			fprintf(stderr, "cannot allocate memory\n");
			exit (1);
		}
	}

	/* Skip leading separators; nothing left means no modules */
	for (; isspace((unsigned char) *modstring) || *modstring == ','; modstring++)
		;
	if (*modstring == 0)
		return;

	/* Pass 1: count the names. The first separator after each name becomes
	 * the name's NUL terminator; further consecutive separators become ','.
	 * (isspace() needs an unsigned char value to avoid UB on negative chars.) */
	for (nmodules = 1, s = modstring; *s; s++)
	{
		if (isspace((unsigned char) *s) || *s == ',')
		{
			for (*s = 0, s++; isspace((unsigned char) *s) || *s == ','; s++)
				*s = ',';  /* all spaces become commas */
			if (*s)
				nmodules++;
			s--;
		}
	}

	/* Pass 2: point modulenames[] at the start of each name */
	if ((modulenames = malloc(nmodules * sizeof *modulenames)) == NULL)
	{
		fprintf(stderr, "cannot allocate memory\n");
		exit (1);
	}
	for (i = 0, s = modstring; i < nmodules; i++)
	{
		for (modulenames[i] = s++; *s; s++)   /* walk to the name's NUL */
			;
		if (i == nmodules-1)
			break;
		for (; *s == 0 || *s == ','; s++)     /* skip to the next name */
			;
	}
}


#ifdef OMPI_REMOTE_OFFLOADING

/* Return the id of the module named `modulename`, as listed in the
 * allmodules[] table (terminated by an entry with a NULL modname),
 * or -1 if no such module exists.
 *
 * Fix: the original returned on strcmp() != 0, i.e. on the first module
 * whose name DIFFERED, so nearly every name mapped to the same id.
 */
static int Moduleid(char *modulename)
{
	int i;

	for (i = 0; allmodules[i].modname != NULL; i++)
	{
		if (strcmp(allmodules[i].modname, modulename) == 0)
			return allmodules[i].modid;
	}

	return -1;
}

/* Query handed to handle_modules() through roff_config_iterate():
 * the list of module names to look for among the configured remote modules. */
struct searchquery_ {
	char **searchmods;   /* module names to search for */
	int nsearchmods;     /* number of entries in searchmods */
};

/* Pairs of module ids <-> node names, i.e. which node provides which module.
 * Keys are module ids (as returned by Moduleid()), values are node names. */
SET_TYPE_DEFINE(nodemodset, int, str, DEFAULT_HASHTABLESIZE);
SET_TYPE_IMPLEMENT(nodemodset);

/* Created and filled by remote_modules_employ(); released by
 * remote_modules_finalize(). */
set(nodemodset) nodemods;

/* Builds `nodemods` set */
/* roff_config_iterate() callback that builds `nodemods`: when the module's
 * name appears in the search query `info` (a struct searchquery_), record
 * the name of the node providing it, keyed by the module id. */
static void handle_modules(roff_config_module_t *mod, void *info)
{
	struct searchquery_ *query = info;
	int k = 0;

	while (k < query->nsearchmods && strcmp(mod->name, query->searchmods[k]) != 0)
		k++;

	if (k < query->nsearchmods)
		set_put(nodemods, Moduleid(query->searchmods[k]))->value = Str(mod->node->name);
}


/* roff_config_iterate() callback: rename a remote module to
 * "<name>_node<id>", where <id> is the id of the node providing it.
 * `data` is unused. Exits on allocation failure.
 *
 * The new name is built in a fresh buffer and the old one freed. This
 * avoids an unchecked temporary strndup() copy and a realloc() that
 * overwrote the only pointer to the old buffer. The full buffer size is
 * handed to snprintf().
 */
static void add_suffix_to_modulename(roff_config_module_t *module, void *data)
{
	size_t newsize = ROFFCONF_MAXNAME + 8;
	char   *newname = malloc(newsize);

	(void) data;
	if (newname == NULL)
	{
		perror("add_suffix_to_modulename():");
		exit(EXIT_FAILURE);
	}
	snprintf(newname, newsize, "%s_node%d", module->name, module->node->id);
	free(module->name);   /* was heap-allocated (the original realloc()ed it) */
	module->name = newname;
}


void remote_modules_employ(void)
{
	struct searchquery_ info;
	
	nodemods = set_new(nodemodset);

	roff_config_initialize(IGNORE_DISABLED_MODULES, portable_userprog);
	
	info.nsearchmods = roff_config.nuniquemodules;
	info.searchmods = roff_config.uniquemodnames;

	roff_config_iterate(handle_modules, (void *) &info);

	/* Append "_nodeX" suffix to each remote module */
	roff_config_iterate(add_suffix_to_modulename, NULL);
}


void remote_modules_finalize(void)
{
	setelem(nodemodset) e;
	for (e = nodemods->first; e; e = e->next)
	{
		str_free(e->value);
	}
	set_free(nodemods);
	roff_config_finalize();
}

#endif


/* Build the module-selection arguments for the ompi compiler.
 *
 * Returns "--usemod=<m1> --usemod=<m2> ..." for every local module and,
 * with remote offloading, for every unique remote module that was not
 * already requested locally. Returns "" when there are no modules.
 * NOTE(review): the returned buffer belongs to a str that is never freed.
 */
char *modules_argfor_ompi()
{
	int i;
	str modstr;

#ifdef OMPI_REMOTE_OFFLOADING
	if (!nmodules && !roff_config.nuniquemodules)
#else
	if (!nmodules)
#endif
		return ("");

	modstr = Strnew();  /* String for formatting module names */

	for (i = 0; i < nmodules; i++)
		str_printf(modstr, "%s--usemod=%s", i ? " " : "", modulenames[i]);

#ifdef OMPI_REMOTE_OFFLOADING
	/* Include remote modules as well, skipping those requested locally.
	 * Use an exact name match (like kernel_makefiles() does). A substring
	 * search over the argument string could wrongly skip a module whose
	 * name happens to occur inside another module's name. */
	for (i = 0; i < roff_config.nuniquemodules; i++)
	{
		if (!is_module_supported(roff_config.uniquemodnames[i]))
			str_printf(modstr, " --usemod=%s", roff_config.uniquemodnames[i]);
	}
#endif
	return ( str_string(modstr) );
}

#if 0  /* currently disabled */
/* Compile the bundled binaries file (<fname without extension>_bubins.c)
 * and return a malloc'd copy of the object file name
 * (<fname without extension>_bubins.o); the caller frees it.
 * 
 * When bundling sources, this should be called instead
 * of kernel_makefiles. When bundling binaries, it should presumably be
 * called after kernel compilation.
 *
 * `nkernels` is not used. Unless `keep` is set, the _bubins.c source is
 * removed after compiling. If the compiler fails, the process exits
 * (_exit) with the compiler's status.
 */
char *kernel_bubins_file_compile(char *fname, int nkernels)
{
	int res;
	str command = Strnew();
	char *fnamecopy = strdup(fname), *ext;
	char *objfile_;
	str outfile, objfile;

	if ((ext = strrchr(fnamecopy, '.')) != NULL)
		*ext = 0; /* remove extension */
	
	outfile = Str(fnamecopy);
	objfile = Str(fnamecopy);

	/* <filename>_bubins.{o,c} */
	str_printf(objfile, "_bubins.o");
	str_printf(outfile, "_bubins.c");

	str_printf(command, "%s \"%s\" -c", COMPILER, str_string(outfile));
	if (verbose)
		fprintf(stderr, "====> Compiling kernel bundle\n  [ %s ]\n",
						str_string(command));
	res = sysexec(str_string(command), NULL);
	if (!keep)
		unlink(str_string(outfile));
	if (res)
		_exit(res);   /* note: skips the cleanup below */

	str_free(command);
	str_free(outfile);
	free(fnamecopy);
	
	objfile_ = strdup(str_string(objfile));

	str_free(objfile);
	return objfile_;
}
#endif

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                           *
 *        CREATING & EXECUTING MAKEFILES FOR KERNELS         *
 *                                                           *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */


/* Append to `s` the expansion of the MakeKernel template variable at `var`.
 *
 * `var` points inside a recipe line at a token normally of the form
 * "@@NAME@@", and `maxlen` is the token's length; the token need not be
 * NUL-terminated at maxlen. Unknown tokens are copied verbatim.
 * Paths are always passed as arguments to "%s", never used as the format
 * itself, so a '%' in an install path cannot corrupt the output.
 */
static
void km_substitute(char *var, int maxlen, str s, char *modulename, int kernid)
{
	// FIXME: is goutfile set here?
	if (strncmp("@@OMPI_OUTPUT@@", var, maxlen) == 0)
		str_printf(s, "%s", (user_outfile.head && strlen(user_outfile.head->val) != 0) ?
		                    user_outfile.head->val : "a.out");
	else if (strncmp("@@OMPI_INSTALL_DIR@@", var, maxlen) == 0)
		str_printf(s, "%s", InstallPath);
	else if (strncmp("@@OMPI_LIBDIR@@", var, maxlen) == 0)
		str_printf(s, "%s", LibDir);
	else if (strncmp("@@OMPI_COMMLIBARG@@", var, maxlen) == 0)
	{
#ifdef OMPI_REMOTE_OFFLOADING
		str_printf(s, "-L%s/default_comm", LibDir);
#endif
		/* Without remote offloading this variable expands to nothing */
	}
	else if (strncmp("@@OMPI_MODULE@@", var, maxlen) == 0)
		str_printf(s, "%s", modulename);
	else if (strncmp("@@OMPI_CC@@", var, maxlen) == 0)
		str_printf(s, "%s", COMPILER);
	else if (strncmp("@@OMPI_CFLAGS@@", var, maxlen) == 0)
		str_printf(s, "%s", CFLAGS);
	else if (strncmp("@@OMPI_LDFLAGS@@", var, maxlen) == 0)
		str_printf(s, "%s", LDFLAGS);
	else if (strncmp("@@OMPI_KERNELID@@", var, maxlen) == 0)
		str_printf(s, "%02d", kernid);
	else if (strncmp("@@OMPI_KERNELFILE@@", var, maxlen) == 0)
		str_printf(s, "%s", get_basename(current_file));
	else if (strncmp("@@CUDA_OUTPUT_EXT@@", var, maxlen) == 0)
		str_printf(s, "%s", CUDA_KERNEL_EXTENSION);
	else
		str_printf(s, "%*.*s", maxlen, maxlen, var);   /* unknown: copy as is */
}


/* Append one MakeKernel recipe `line` to `s`, expanding every "@@NAME@@"
 * template variable through km_substitute().
 *
 * Leading blanks are copied as is. Lines whose first non-blank character
 * is '#' are comments and are copied verbatim. A lone '@', or a "@@..."
 * sequence interrupted by whitespace or by the end of the line, is copied
 * unchanged.
 */
static
void km_substitute_vars(char *line, str s, char *modulename, int kernid)
{
	char *p;

	/* isspace() requires an unsigned char value (UB for negative chars) */
	for (; isspace((unsigned char) *line); line++)  /* copy leading blanks */
		str_putc(s, *line);

	if (*line == '#')
	{
		str_printf(s, "%s", line);    /* comment line */
		return;
	}

	while (1)
	{
		/* Copy literal text up to the next '@' */
		for (; *line != '@'; line++)
		{
			if (*line == 0)
				return;
			str_putc(s, *line);
		}

		p = line+1;
		if (*p != '@')
			str_putc(s, '@');   /* lone '@' */
		else
		{
			/* Scan the variable name up to its first closing '@' */
			for (p++; *p != '@'; p++)
			{
				if (*p == 0)
				{
					str_printf(s, "%s", line);
					return;
				}
				if (isspace((unsigned char) *p))
				{
					/* Not a variable: copy "@@..." up to the blank. The '*'
					 * precision must be an int, not a ptrdiff_t. */
					str_printf(s, "%.*s", (int) (p - line), line);
					break;
				}
			}

			if (*p == '@')
			{
				/* Token spans line..p+1 (two closing chars). Guard against a
				 * "@@NAME@" that ends the input, which would otherwise make
				 * the scan step past the string terminator. */
				if (p[1] == 0)
				{
					str_printf(s, "%s", line);
					return;
				}
				km_substitute(line, (int) (p - line + 2), s, modulename, kernid);
				p += 2;
			}
		}
		line = p;
	}
}


/* Run the kernel makefile "<modulename>-makefile-<base>-<kernid>" located
 * in directory `path`, where <base> is the basename of the current file
 * without its extension.
 *
 * If `path` is not ".", the process changes into it and afterwards returns
 * to `cwd`. With remote offloading (and unless kernel sources are bundled),
 * a module that has a providing node in `nodemods` and is not requested
 * locally is built on that node through an "ssh <node> 'cd <dir> && ...'"
 * prefix.
 * Returns 0 on success, non-zero on any failure.
 */
static
int kernel_makefile_run(char *path, char *modulename, int kernid)
{
	str  cmd = Strnew(), sshprefix = Str("");
	int  res = 1, eflag = 0;
	char *currfile = strdup(current_file), *ext, *bname;

	if (currfile == NULL)
	{
		perror("strdup():");
		goto cleanup;
	}

#if defined(OMPI_REMOTE_OFFLOADING)
	if (bundleKernels != BUNDLE_SRCS)
	{
		setelem(nodemodset) e;
	
		if ((e = set_get(nodemods, Moduleid(modulename))))
			if ((!is_module_supported(modulename)))
				str_printf(sshprefix, "ssh %s 'cd %s && ", str_string(e->value), cwd);
	}
#endif

	if (strcmp(path, ".") != 0)
	{
		if (*cwd == 0)
		{
			perror("getcwd():");
			goto cleanup;
		}
		if (verbose)
			fprintf(stderr, "  ===> changing to directory %s\n", path);
		if (chdir(path) < 0)
		{
			perror("chdir():");
			goto cleanup;
		}

#if defined(OMPI_REMOTE_OFFLOADING)
		if (bundleKernels != BUNDLE_SRCS)
		{
			setelem(nodemodset) e;
			/* Pass the new path to the ssh prefix */
			str_truncate(sshprefix);
			if ((e = set_get(nodemods, Moduleid(modulename))) && (!is_module_supported(modulename)))
			{
				str_printf(sshprefix, "ssh %s ", str_string(e->value));
				if (path[0] == '/')
					str_printf(sshprefix, "'cd %s && ", path);
				else
					str_printf(sshprefix, "'cd %s/%s && ", cwd, path);
			}
		}
#endif
	}

	/* Drop the extension (and anything after it); the original dereferenced
	 * strrchr()'s result unconditionally and crashed for names without '.' */
	if ((ext = strrchr(currfile, '.')) != NULL)
		*ext = 0;
	bname = get_basename(currfile);

	str_printf(cmd, "%smake -f %s-makefile-%s-%02d", 
	                str_string(sshprefix), modulename, bname, kernid);

	if (str_tell(sshprefix))     /* close the quote opened by the ssh prefix */
		str_printf(cmd, "'");

	if (verbose)
		fprintf(stderr, "  ===> %s\n", str_string(cmd));
	res = sysexec(str_string(cmd), &eflag);
	res = (eflag || res);

	if (strcmp(path, ".") != 0)
	{
		if (verbose)
			fprintf(stderr, "  ===> returning to directory %s\n", cwd);
		if (chdir(cwd) < 0)
		{
			perror("chdir():");
			res = 1;
		}
	}

cleanup:
	free(currfile);
	str_free(sshprefix);
	str_free(cmd);
	return (res);
}


/* Create the makefile for kernel `kernid` of the current file and device
 * module `modulename` in directory `path`, run it, and (unless keep >= 2)
 * delete it afterwards.
 *
 * The makefile is generated from the installed recipe
 * <LibDir>/devices/<module>/MakeKernel.<module>. When usecarstats is set
 * and `mmod` is given, a mapper-selected "MakeKernel-<flavor>.<module>"
 * variant may be used instead. Template variables are expanded by
 * km_substitute_vars().
 * Returns 0 on success, non-zero on failure.
 */
int kernel_makefile_create(char *path, char *modulename, int kernid, 
                           mapmod_t mmod)
{
	static str s = NULL;    /* expanded makefile text, reused across calls */
	char   *flavor, *ext, *bname;
	FILE   *fpin, *fpout;
	char   currfile[PATHSIZE], basefile[PATHSIZE];
	str    line, filepath = Strnew();
	int    res;

	/* A "_dXX.c" suffix may be appended to currfile below, so copy it into a
	 * PATHSIZE buffer. The original used an exact-size strdup() copy, and
	 * the suffix overflowed the heap. */
	snprintf(currfile, sizeof(currfile), "%s", current_file);
	if ((ext = strrchr(currfile, '.')) != NULL)
		*ext = 0; /* remove extension */
	else
		ext = currfile + strlen(currfile);

	snprintf(basefile, sizeof(basefile), "%s", currfile);  /* filename without ext */

	if (!usecarstats || !mmod)
		str_printf(filepath, "%s/devices/%s/MakeKernel.%s", LibDir, modulename, modulename);
	else
	{
		/* Kernel source name: <file>_dXX.c */
		snprintf(ext, sizeof(currfile) - (size_t) (ext - currfile), "_d%02d.c", kernid);

		/* Mapper arguments */
		str_printf(filepath, "%s/%s", cwd, currfile);
		flavor = mapper_select_flavor(mmod, str_string(filepath));
		str_truncate(filepath);
		
		/* Check for actual flavor and form the kernel makefile installation path */
		if (flavor == NULL || strcmp(flavor, "devpart") == 0)
			str_printf(filepath, "%s/devices/%s/MakeKernel.%s", 
			           LibDir, modulename, modulename);
		else
		{
			char *flv = flavor;    /* skip "devpart" from returned flavor name */
			if (strlen(flavor) >= 7)
				flv = (flavor[7] == 0) ? flavor+7 : flavor+8;
			str_printf(filepath, "%s/devices/%s/MakeKernel-%s.%s", 
			           LibDir, modulename, flv, modulename);
		}
		if (verbose)
			fprintf(stderr, "  kernel makefile selected (mapper):\n\t%s\n", str_string(filepath));
	}

	if ((fpin = fopen(str_string(filepath), "r")) == NULL)
	{
		fprintf(stderr, "[***] cannot find 'MakeKernel.%s' recipe for creating "
		                "%s kernels.\n", modulename, modulename);
		str_free(filepath);
		return (1);
	}

	bname = get_basename(basefile);

	str_truncate(filepath);
	str_printf(filepath, "%s/%s-makefile-%s-%02d", 
			path, modulename, bname, kernid);

	if ((fpout = fopen(str_string(filepath), "w")) == NULL)
	{
		fprintf(stderr, "[***] cannot generate '%s/%s-makefile-%s-%02d' to create "
		                "kernels for device %s.\n", path, modulename, bname,
		                kernid, modulename);
		fclose(fpin);
		str_free(filepath);
		return (1);
	}

	if (s)
		str_truncate(s);
	else
		s = Strnew();

	line = Strnew();
	str_reserve(line, SLEN);
	while (fgets(str_string(line), SLEN, fpin) != NULL)
		km_substitute_vars(str_string(line), s, modulename, kernid);
	fprintf(fpout, "%s", str_string(s));

	fclose(fpin);
	fclose(fpout);

	if (verbose)
		fprintf(stderr, "  ( %s )\n", str_string(filepath));
	res = kernel_makefile_run(path, modulename, kernid);

	if (keep < 2)    /* Remove makefile */
		unlink(str_string(filepath));
	
	str_free(filepath);
	str_free(line);
	
	return (res);
}

	
/* Copy `src` into `dst` (a PATHSIZE buffer) without the extension of its
 * final path component. A leading dot (e.g. ".hidden") is not treated as
 * an extension, and dots in directory names are ignored. */
static 
void strip_ext(char *dst, const char *src)
{
	char *last_slash, *name, *last_dot;

	snprintf(dst, PATHSIZE, "%s", src);

	last_slash = strrchr(dst, '/');
	name = (last_slash != NULL) ? last_slash + 1 : dst;

	last_dot = strrchr(name, '.');
	if (last_dot != NULL && last_dot != name)
		*last_dot = '\0';
}


/* Append to `s` the space-separated names "<stem>_dXX<ext> " for every
 * kernel XX in [0, nkernels) and every extension in the NULL-terminated
 * `exts` array, where <stem> is `fname` without its extension. */
static void _kernels_print(str s, const char *fname, int nkernels, char *exts[])
{
	char stem[PATHSIZE];
	char **e;
	int  k;

	strip_ext(stem, fname);

	for (k = 0; k < nkernels; k++)
	{
		for (e = exts; *e != NULL; e++)
			str_printf(s, "%s_d%02d%s ", stem, k, *e);
	}
}


/* Remove the compiled kernel binaries of `fname` for kernels
 * 0..nkernels-1, for every supported device type.
 * 
 * TODO: Binary extensions should reside in one place.
 */
void kernel_binaries_remove(const char *fname, int nkernels)
{
	char *binexts[] = {
		"-proc2.out",
		"-cuda." CUDA_KERNEL_EXTENSION,
		"-opencl.out",
		"-vulkan.spv",
		NULL
	};
	str files = Strnew();

	_kernels_print(files, fname, nkernels, binexts);
	removefiles(str_string(files));
	str_free(files);
}


/* 
 * Remove the generated kernel sources of `fname` for kernels
 * 0..nkernels-1, for every supported source suffix/extension.
 *
 * TODO: Sources extensions should reside in one place.
 */
void kernel_sources_remove(const char *fname, int nkernels)
{
	char *srcexts[] = {
		".c",
		"-cuda.cu",
		"-opencl.cl",
		"-vulkan.comp",
		NULL
	};
	str files = Strnew();

	_kernels_print(files, fname, nkernels, srcexts);
	removefiles(str_string(files));
	str_free(files);
}


/* Generate and run one kernel makefile per (module, kernel) pair.
 *
 * `fname` is the transformed source file; its directory part ("." if it
 * has none) is where the makefiles are created and executed. `nkernels` is
 * the number of kernels produced for `fname`. Unless `keep` is set, the
 * kernel sources are removed afterwards.
 * Returns 0 on success, non-zero if any makefile failed.
 */
int kernel_makefiles(char *fname, int nkernels)
{
	char     filepath[PATHSIZE], *s;
	mapmod_t mmod;
	int      res = 0;

	/* Directory part of fname (bounded copy) */
	if ((s = strrchr(fname, '/')) == NULL)
		strcpy(filepath, ".");
	else if (s == fname)
		strcpy(filepath, "/");   /* file in the root directory */
	else
		snprintf(filepath, sizeof(filepath), "%.*s", (int) (s - fname), fname);

	if (!getcwd(cwd, PATHSIZE))
		*cwd = 0;
	
	/* Kernel makefile jobs
	 *
	 * For each module, we consider each kernel makefile execution
	 * a job, handled by a separate process. When the # of jobs > 1, 
	 * the overall kernel compilation overheads are reduced,
	 * as each job is executed concurrently.
	 * 
	 * The compilation load is statically distributed across processes,
	 * according to the requested number of jobs and the number of
	 * kernels to be compiled.
	 * 
	 * This feature can be leveraged with device options (see ompicc --devopt).
	 * It is particularly useful when dealing with CUDA kernels.
	 */

	/* NOTE(review): current_file keeps a trailing '\n'. When fname has an
	 * extension, kernel_makefile_create()/_run() cut it off along with the
	 * extension. */
	snprintf(current_file, sizeof(current_file), "%s\n", fname);
	if ((nmodules > 0) && (nkernels > 0))
	{
		JOB_START(kmmods, reqjobs, nmodules * nkernels)
		{
			JOB_LOOP(kmmods, c)
			{	
				int i = c / nkernels;
				int j = c % nkernels;
				mmod = usecarstats ? mapper_load_module(modulenames[i]) : NULL;

				res = kernel_makefile_create(filepath, modulenames[i], j, mmod);
				if (mmod)                 /* free before a possible break */
					mapper_free_module(mmod);
				if (res != 0)
					break;
			}
		}
		res |= JOB_FINISH(kmmods, res);
	}


	/* Create kernel makefiles for all remote nodes, as well */
#if defined(OMPI_REMOTE_OFFLOADING)
	if ((roff_config.nuniquemodules > 0) && (nkernels > 0))
	{
		JOB_START(r_kmmods, reqjobs, roff_config.nuniquemodules * nkernels)
		{
			JOB_LOOP(r_kmmods, c)
			{
				int i = c / nkernels;
				int j = c % nkernels;

				/* Locally requested modules were handled above */
				if (is_module_supported(roff_config.uniquemodnames[i])) 
					continue;
					
				res = kernel_makefile_create(filepath, roff_config.uniquemodnames[i], j, NULL);
				if (res != 0)
					break;
			}
		}
		res |= JOB_FINISH(r_kmmods, res);
	}
#endif

	if (!keep)
		kernel_sources_remove(fname, nkernels);

	return (res);
}


/* Generate and run a single makefile for each requested module, for the
 * kernel file `fname` (kernel id taken from _kernid). The makefile lives
 * in the directory part of `fname` ("." if it has none).
 * Returns 0 on success, non-zero if any makefile failed.
 */
int kernel_makefile_single(char *fname)
{
	char     filepath[PATHSIZE], *s;
	mapmod_t mmod;
	int      res = 0;

	/* Strip filename from the path (bounded copy) */
	if ((s = strrchr(fname, '/')) == NULL)
		strcpy(filepath, ".");
	else if (s == fname)
		strcpy(filepath, "/");   /* file in the root directory */
	else
		snprintf(filepath, sizeof(filepath), "%.*s", (int) (s - fname), fname);

	if (!getcwd(cwd, PATHSIZE))
		*cwd = 0;

	snprintf(current_file, sizeof(current_file), "%s\n", fname);

	JOB_START(kmfs, reqjobs, nmodules)
	{
		JOB_LOOP(kmfs, i)
		{
			/* NOTE(review): mmod is loaded but NULL is passed to
			 * kernel_makefile_create(), so no mapper flavor is selected
			 * here; confirm this is intended. */
			mmod = usecarstats ? mapper_load_module(modulenames[i]) : NULL;
			res = kernel_makefile_create(filepath, modulenames[i], _kernid, NULL);
			if (mmod)                 /* free before a possible break */
				mapper_free_module(mmod);
			if (res != 0)
				break;
		}
	}
	res |= JOB_FINISH(kmfs, res);
	return (res);
}
