/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdarg.h>
#include <limits.h>
#include "omp.h"
#include "str.h"
#include "config.h"
#include "assorted.h"
#include "ort_prive.h"
#ifdef OMPI_REMOTE_OFFLOADING
	#include "rdev_config.h"
	#include "remotedev/rdev.h"
#endif

#ifdef PORTABLE_BUILD
  /* Installation prefix used by open_module() to build the module path;
   * ort_discover_modules() overwrites it with the first entry of modnames.
   */
  char *InstallPath = "./";     /* dummy initialization */
#endif
/* Initialized at the end of ort_discover_modules(); taken by ort_get_device()
 * around device initialization and by ort_illegal_device() around its
 * one-time warning.
 */
static volatile ee_lock_t mod_lock; /* Lock for functions in this file */

/**
 * Appends a new, not yet initialized, device descriptor to ort->ort_devices
 * and advances ort->num_devices. The caller must have made sure that the
 * array has room for the new entry.
 *
 * @param module       the module the device belongs to; may be NULL (the
 *                     build without dlopen() adds the host device this way)
 * @param id_in_module the index of the device within its module
 * @return             the new device descriptor
 */
static ort_device_t *add_device(ort_module_t *module, int id_in_module)
{
#ifdef OMPI_REMOTE_OFFLOADING
	/* Remember the node of the previously added remote device; the per-node
	 * counter restarts whenever the node changes, so the devices of one node
	 * are expected to be added back-to-back.
	 */
	static int prevnode = -1, node_devs = 0;
#endif
	ort_device_t *dev = &(ort->ort_devices[ort->num_devices]);

	dev->id           = ort->num_devices++;  /* advance the counter too */
	dev->id_in_module = id_in_module;
	dev->module       = module;
	dev->initialized  = false;
	dev->device_info  = NULL;
	dev->lock         = NULL;                /* allocated when the device is set up */

	/* The first device of a module fixes the module's first global id.
	 * There is nothing to record when no module is given.
	 */
	if (module != NULL && id_in_module == 0)
		module->first_global_devid = dev->id;

#ifdef OMPI_REMOTE_OFFLOADING
	dev->is_cpudev    = 0;

	if (module)
	{
		/* Store the intranode device ID */
		if (module->remote)
		{
			if (prevnode == module->nodeid)
				node_devs++;
			else
			{
				node_devs = 0;
				prevnode = module->nodeid;
			}
			dev->id_in_node = node_devs;
		}
	}
	else
		dev->is_cpudev = 1;

#endif
	return (dev);
}


/* Number of host "devices" to create. This is 1, except when remote
 * offloading is compiled in and this process is a worker: then the device
 * count of the first CPU module listed for this node in the remote-device
 * configuration is used (still 1 if the node lists no CPU module).
 */
static int get_num_hostdevs(void)
{
#ifdef OMPI_REMOTE_OFFLOADING
	if (node_role == ROLE_WORKER)
	{
		/* Node ids are 1-based; the configuration array is 0-based */
		rdev_config_node_t *self = &(rdev_config.nodes[rdev_man_get_my_id() - 1]);
		int k;

		for (k = 0; k < self->num_modules; k++)
			if (IS_CPU_MODULE(self->modules[k].name))
				return self->modules[k].num_devices;
	}
#endif

	return 1;
}

/* Registers `numdevs` host devices under the host module and marks each one
 * as already initialized, giving it its own freshly initialized lock.
 */
static void setup_host_moddev(int numdevs)
{
	int n = 0;

	while (n < numdevs)
	{
		ort_device_t *hostdev = add_device(hostdev_get_module(), n);

		hostdev->initialized = true;
#ifdef OMPI_REMOTE_OFFLOADING
		hostdev->is_cpudev = true;
#endif
		/* Every device carries its own lock */
		hostdev->lock = (volatile void *)ort_alloc(sizeof(ee_lock_t));
		ee_init_lock((ee_lock_t *) hostdev->lock, ORT_LOCK_NORMAL);

		n++;
	}
}

#ifdef OMPI_REMOTE_OFFLOADING

/* Fills the function table of a remote module with the rdev_XXX entry points
 * and flags it as a CPU module according to its name. Always returns true.
 */
static bool load_remote_functions(ort_module_t *m)
{
	/* Module properties; both should be retrieved by the remote module */
	m->sharedspace     = 0;
	m->unified_medaddr = 0;

	/* Device lifecycle and kernel launch */
	m->initialize = rdev_initialize;
	m->finalize   = rdev_finalize;
	m->offload    = rdev_offload;

	/* Device memory management */
	m->dev_alloc             = rdev_alloc;
	m->dev_init_alloc_global = rdev_init_alloc_global;
	m->dev_free              = rdev_free;
	m->dev_free_global       = rdev_free_global;

	/* Data transfers and address translation */
	m->todev          = rdev_todev;
	m->fromdev        = rdev_fromdev;
	m->imed2umed_addr = rdev_imed2umed_addr;
	m->umed2imed_addr = rdev_umed2imed_addr;

	m->is_cpumodule = IS_CPU_MODULE(m->name);

	return true;
}


/* Loads the remote-device configuration, adds the remote module/device
 * counts to the running totals and grows the ort modules and devices arrays
 * to the new totals.
 *
 * nModules, nDevices: (in/out) current totals; on return they also include
 *                     the remote modules and devices.
 */
static void discover_remote_modules(int *nModules, int *nDevices)
{
#ifdef IGNORE_REMOTE_DEVICES_SNAPSHOT
	rdev_config_initialize(DONT_IGNORE_DISABLED_MODULES);
#else
	rdev_config_initialize_from_hex(ompi_remote_devices, ompi_remote_devices_size, DONT_IGNORE_DISABLED_MODULES);
#endif

	/* Remember how much of the totals is remote */
	ort->num_remote_modules = rdev_config.num_modules;
	ort->num_remote_devices = rdev_config.num_devices;

	*nModules += rdev_config.num_modules;
	*nDevices += rdev_config.num_devices;

	/* Make room for the remote entries after the local ones */
	ort->modules = ort_realloc(ort->modules, *nModules * sizeof(ort_module_t));
	ort->ort_devices = ort_realloc(ort->ort_devices, *nDevices * sizeof(ort_device_t));
}


/* This function initializes all discovered remote modules
 */
static void setup_remote_modules()
{
	rdev_config_node_t *node;
	rdev_config_module_t *mod;
	int i = 0, j;
	int index = ort->num_local_modules;
	
	for (i = 0; i < rdev_config.num_nodes; i++)
	{
		node = &(rdev_config.nodes[i]);
		for (j = 0; j < node->num_modules; j++, index++)
		{
			mod = &(node->modules[j]);
			ort->modules[index].name = strdup(mod->name);
			ort->modules[index].rdev_name = smalloc(128 * sizeof(char));
			snprintf(ort->modules[index].rdev_name, 127, "%s_node%d", mod->name, i);
			
			ort->modules[index].handle = NULL; 
			ort->modules[index].nodeid = i + 1;
			ort->modules[index].node_name = strdup(node->name);
			ort->modules[index].initialized = true;
			ort->modules[index].initialized_successful = 
				load_remote_functions(&(ort->modules[index]));
			ort->modules[index].number_of_devices = mod->num_devices;
			ort->modules[index].remote = true;
		}
	}
}

#endif /* OMPI_REMOTE_OFFLOADING */


/**
 * Maps a device id to its descriptor, initializing the device on first use.
 * AUTODEV_ID is replaced by the current default device. An out-of-range id
 * is handed to ort_illegal_device() and a device whose device_info is NULL
 * after initialization is replaced by the host device.
 * @param devid the device id
 * @return      the device descriptor (ort_device_t)
 */
ort_device_t *ort_get_device(int devid)
{
	if (devid == AUTODEV_ID)  /* Use the default device */
		devid = omp_get_default_device();

	/* Out-of-range ids are dealt with uniformly by ort_illegal_device() */
	if (devid < 0 || devid >= ort->num_devices)
	{
		devid = ort_illegal_device("device", devid);
		return &(ort->ort_devices[devid]);
	}

	/* Initialize on first use, under the file-wide lock */
	ee_set_lock((ee_lock_t *) &mod_lock);
	if (!ort->ort_devices[devid].initialized)
		ort_init_device(devid);
	ee_unset_lock((ee_lock_t *) &mod_lock);

	/* If the device has failed to initialize correctly fall back to host */
	if (ort->ort_devices[devid].device_info == NULL)
		devid = HOSTDEV_ID;

	return &(ort->ort_devices[devid]);
}


/**
 * Uniform handling of an illegal device id across the runtime: reports an
 * error through ort_error() when the target-offload ICV is OFFLOAD_MANDATORY,
 * otherwise warns (only the first time it is reached) and yields the
 * fallback device id.
 * @param reason a message describing the failing device
 * @param devid  the failing device id
 * @return       the fallback device id (HOSTDEV_ID)
 */
int ort_illegal_device(char *reason, int devid)
{
	static bool warned = false;

	if (ort->icvs.targetoffload == OFFLOAD_MANDATORY)
		ort_error(1, "Invalid %s value (%d); mandatory exiting.\n", reason, devid);

	/* Somebody has already issued the warning */
	if (warned)
		return (HOSTDEV_ID);

	ee_set_lock((ee_lock_t *) &mod_lock);
	if (!warned)   /* re-check now that we hold the lock */
	{
		ort_warning("Invalid %s value (%d); falling back to device 0 (host).\n",
		            reason, devid);
		warned = true;
		FENCE;
	}
	ee_unset_lock((ee_lock_t *) &mod_lock);

	return (HOSTDEV_ID);
}


#ifdef HAVE_DLOPEN

#include <dlfcn.h>


/* Tries to dlopen() the host part of module `name`, returning the first
 * handle obtained from, in order:
 *   1. ./<name>.so (current folder)
 *   2. devices/<name>/hostpart.so under ompi's library folder
 *   3. <name>.so (bare file name)
 * `type` is passed to dlopen() as its mode. Returns NULL if all three fail.
 */
static void *open_module(char *name, int type)
{
	str path = Strnew();
	void *lib;

	/* 1. Current folder */
	str_printf(path, "./%s.so", name);
	lib = dlopen(str_string(path), type);

	/* 2. ompi's library folder */
	if (lib == NULL)
	{
		str_truncate(path);
#ifdef PORTABLE_BUILD
		/* Use InstallPath as passed by ompicc */
		str_printf(path, "%slib/ompi/devices/%s/hostpart.so", InstallPath, name);
#else
		/* Use hard-coded LibDir as provided by configure */
		str_printf(path, "%s/devices/%s/hostpart.so", LibDir, name);
#endif
		lib = dlopen(str_string(path), type);
	}

	/* 3. Bare file name; where to look is left to dlopen() */
	if (lib == NULL)
	{
		str_truncate(path);
		str_printf(path, "%s.so", name);
		lib = dlopen(str_string(path), type);
	}

	str_free(path);
	return lib;
}


/* Looks up `sym` in the shared object `module`.
 * If dlerror() reports an error after the lookup and show_warn is non-zero,
 * a warning mentioning moduleName is printed and NULL is returned. In every
 * other case the value produced by dlsym() is returned as is.
 * dlerror() is not cleared beforehand; callers do that before the first
 * lookup.
 */
static inline void *load_symbol(void *module, char *moduleName, char *sym, int show_warn)
{
	void *addr = dlsym(module, sym);
	char *errmsg = dlerror();

	if (errmsg == NULL || !show_warn)
		return addr;

	ort_warning("module: %s, symbol: %s, %s\n", moduleName, sym, errmsg);
	return NULL;
}


/**
 * Builds the module and device tables of the runtime.
 *
 * Every named module is opened briefly, only to ask it (through
 * hm_get_num_devices) how many devices it has; it is then closed again and
 * left marked as not initialized, so that it gets loaded for real when one of
 * its devices is first initialized (see ort_init_device()). Modules that
 * cannot be queried are recorded as failed, with 0 devices.
 * The host device(s) are added first, followed by the devices of each module
 * in module order. With remote offloading, the primary node (when not in
 * embed mode) also appends the remote modules before devices are added.
 * Finally, the lock protecting the functions of this file is initialized.
 *
 * @param nModules the number of module names in modnames (under
 *                 PORTABLE_BUILD an extra, uncounted, leading entry holds
 *                 the installation path)
 * @param modnames the module names; the pointers are stored, not copied
 */
void ort_discover_modules(int nModules, char **modnames)
{
	int  i = 0, j, modidx = 0, nhostdevs = get_num_hostdevs(),
	     nDevices = nhostdevs;  /* we also count the hostdevs in nDevices */
	int  (*get_num_devices)(void);
	
#ifdef PORTABLE_BUILD
	/* The installation path has sneaked in as 1st argument (not counted) */
	InstallPath = modnames[modidx++];
#endif

	ort->modules = ort_alloc(nModules * sizeof(ort_module_t));
	for (i = 0; i < nModules; i++)
	{
		ort->modules[i].name = modnames[modidx++];

#ifdef OMPI_REMOTE_OFFLOADING
		ort->modules[i].rdev_name = ort->modules[i].node_name = NULL; /* not a remote device */
		ort->modules[i].nodeid = -1; /* ditto */
		ort->modules[i].remote = false;
#endif
		ort->modules[i].handle = NULL;
		
#ifdef OMPI_REMOTE_OFFLOADING
		/* If it's a remote module, do not insert it to the module list.
		 * This is a temporary fix; _ompi should not include remote modules to the
		 * variadic arguments of `_ort_init', only local ones. For this reason, we should 
		 * somehow separate remote modules from local ones, e.g. by passing them through
		 * a new _ompi argument, instead of --usemod.
		 */
		if ((node_role == ROLE_PRIMARY) && (!contains_word(MODULES_CONFIG, ort->modules[i].name)))
			goto INITIALIZATION_FAIL;
#endif

		/* Open lazily: the module is only queried here, not really loaded */
		void *module = open_module(ort->modules[i].name, RTLD_LAZY);
		if (!module)
			ort_warning("Failed to open module \"%s\"\n", ort->modules[i].name);
		else
		{
			/* Clear dlerror */
			dlerror();

			get_num_devices = load_symbol(module, ort->modules[i].name,
			                                "hm_get_num_devices", 1);
			if (get_num_devices != NULL)
			{
				/* Success: record the device count and defer the actual
				 * initialization of the module until a device is first used
				 */
				ort->modules[i].initialized = false;
				ort->modules[i].number_of_devices = get_num_devices();
				nDevices += ort->modules[i].number_of_devices;
				dlclose(module);
				continue;
			}
			dlclose(module);
		}

INITIALIZATION_FAIL:
		/* If we reached here we failed to get the number of devices; the
		 * module is flagged as initialized (unsuccessfully) and contributes
		 * no devices
		 */
		ort->modules[i].initialized = true;
		ort->modules[i].initialized_successful = false;
		ort->modules[i].number_of_devices = 0;
	}

	/* Room for the host device(s) plus all devices of the local modules */
	ort->ort_devices = ort_alloc(nDevices * sizeof(ort_device_t));

	/* The host "module" and "device" 0; call it here to get id 0 */
	setup_host_moddev(nhostdevs);
	
	ort->num_local_modules = nModules;
	ort->num_local_devices = nDevices;
	
#ifdef OMPI_REMOTE_OFFLOADING
	/* This may increase nModules/nDevices and reallocate both arrays */
	if (!ort->embedmode && node_role == ROLE_PRIMARY)
	{
		discover_remote_modules(&nModules, &nDevices);
		setup_remote_modules();
	}
#endif
	
	/* Global device ids get assigned here, in module order */
	for (i = 0; i < nModules; i++)
	{
		for (j = 0; j < ort->modules[i].number_of_devices; j++)
			add_device(&(ort->modules[i]), j);
	}

	ort->num_modules = nModules;

	ee_init_lock((ee_lock_t *) &mod_lock, ORT_LOCK_NORMAL);
}


/* Resolves every symbol the runtime needs from an opened module (m->handle)
 * and stores it in the module descriptor. hm_sharedspace and
 * hm_unified_medaddr are read as int variables; the remaining symbols are
 * stored as function pointers. The module is also given its own name and the
 * runtime's lock, yield and str_printf functions.
 * Stops at the first missing symbol (load_symbol() issues the warning).
 * Returns true only if all symbols were found.
 */
static bool load_functions(ort_module_t *m)
{
	int *flag;
	void (*register_ee_calls)(void (*)(omp_lock_t *, int), void (*)(omp_lock_t *),
	                          void (*)(omp_lock_t *), int  (*hyield_in)(void));
	void (*register_str_printf)(int  (*str_printf_in)(str, char *, ...));
	void (*set_module_name)(char*);

	/* Module properties, exported as plain variables */
	flag = load_symbol(m->handle, m->name, "hm_sharedspace", 1);
	if (flag == NULL)
		return false;
	m->sharedspace = *flag;

	flag = load_symbol(m->handle, m->name, "hm_unified_medaddr", 1);
	if (flag == NULL)
		return false;
	m->unified_medaddr = *flag;

	/* Entry points stored in the module descriptor */
	m->initialize = load_symbol(m->handle, m->name, "hm_initialize", 1);
	if (m->initialize == NULL)
		return false;

	m->finalize = load_symbol(m->handle, m->name, "hm_finalize", 1);
	if (m->finalize == NULL)
		return false;

	m->offload = load_symbol(m->handle, m->name, "hm_offload", 1);
	if (m->offload == NULL)
		return false;

	m->dev_alloc = load_symbol(m->handle, m->name, "hm_dev_alloc", 1);
	if (m->dev_alloc == NULL)
		return false;

	m->dev_init_alloc_global = load_symbol(m->handle, m->name,
	                                       "hm_dev_init_alloc_global", 1);
	if (m->dev_init_alloc_global == NULL)
		return false;

	m->dev_free = load_symbol(m->handle, m->name, "hm_dev_free", 1);
	if (m->dev_free == NULL)
		return false;

	m->dev_free_global = load_symbol(m->handle, m->name, "hm_dev_free_global", 1);
	if (m->dev_free_global == NULL)
		return false;

	m->todev = load_symbol(m->handle, m->name, "hm_todev", 1);
	if (m->todev == NULL)
		return false;

	m->fromdev = load_symbol(m->handle, m->name, "hm_fromdev", 1);
	if (m->fromdev == NULL)
		return false;

	m->imed2umed_addr = load_symbol(m->handle, m->name, "hm_imed2umed_addr", 1);
	if (m->imed2umed_addr == NULL)
		return false;

	m->umed2imed_addr = load_symbol(m->handle, m->name, "hm_umed2imed_addr", 1);
	if (m->umed2imed_addr == NULL)
		return false;

	/* Registration hooks: each one is called as soon as it is found */
	set_module_name = load_symbol(m->handle, m->name, "hm_set_module_name", 1);
	if (set_module_name == NULL)
		return false;
	set_module_name(m->name);

	register_ee_calls = load_symbol(m->handle, m->name, "hm_register_ee_calls", 1);
	if (register_ee_calls == NULL)
		return false;
	register_ee_calls(ort_prepare_omp_lock, omp_set_lock, omp_unset_lock,
	                  sched_yield);

	register_str_printf = load_symbol(m->handle, m->name,
	                                  "hm_register_str_printf", 1);
	if (register_str_printf == NULL)
		return false;
	register_str_printf(str_printf);

	return true;
}


/* Loads a module for real: opens its shared object with RTLD_NOW and
 * resolves all required symbols. The module is flagged as initialized in any
 * case; initialized_successful tells whether it is usable. When symbol
 * resolution fails the shared object is closed again.
 */
static void initialize_module(ort_module_t *m)
{
	void *lib;

	/* Record the attempt first, assuming failure until proven otherwise */
	m->initialized = true;
	m->initialized_successful = false;

	lib = open_module(m->name, RTLD_NOW);
	m->handle = lib;
	if (lib == NULL)
	{
		ort_warning("Failed to initialize module \"%s\"\n", m->name);
		return;
	}

	/* Clear dlerror */
	dlerror();

	m->initialized_successful = load_functions(m);
	m->is_cpumodule    = false;

	if (m->initialized_successful)
		return;

	ort_warning("Failed to initialize module \"%s\" functions\n", m->name);
	dlclose(m->handle);
}


/**
 * Initializes a device, loading its module first if that has not been
 * attempted yet. Does nothing if the device is already marked initialized.
 *
 * On return the device is marked initialized and owns an initialized lock.
 * device_info holds whatever the module's initialize function returned, or
 * NULL if the module could not be loaded.
 *
 * The lock is created even when the module failed, because
 * ort_finalize_devices() destroys and frees the lock of every initialized
 * device.
 *
 * No range check is done on device_id and no locking is done here; the
 * caller in this file (ort_get_device) validates the id and holds mod_lock.
 *
 * @param device_id the global device id
 */
void ort_init_device(int device_id)
{
	ort_device_t *d = &(ort->ort_devices[device_id]);
	int devid;

	if (d->initialized)
		return;

	d->initialized = true;

	/* Check if module is initialized */
	if (!d->module->initialized)
		initialize_module(d->module);

	if (!d->module->initialized_successful)
		d->device_info = NULL;
	else
	{
		devid = d->id_in_module;

#ifdef OMPI_REMOTE_OFFLOADING
		/* If the device belongs to a remote module, encode in-module device ID
		 * and node ID in `devid`
		 */
		if (d->module->remote)
		{
			bool encoded = false;

			if ((d->module->nodeid <= USHRT_MAX) && (d->id_in_node <= USHRT_MAX)
				&& (d->module->nodeid >= 0) && (d->id_in_node >= 0))
			{
				unsigned int enc_devid = _uint_encode2(d->id_in_node, d->module->nodeid);

				if (enc_devid <= INT_MAX)
				{
					devid = (int) enc_devid;
					encoded = true;
				}
			}

			/* Either value out of range, or the encoding does not fit an int:
			 * keep the plain in-module id
			 */
			if (!encoded)
				fprintf(stderr, "[ORT warning] either devid or nodeid are off range;"
				                "using default devid.");
		}
#endif /* OMPI_REMOTE_OFFLOADING */

		/* Call initialize function for device */
		d->device_info = MODULE_CALL(d->module, initialize, (devid, &(ort->icvs),
		                             ort->argc, ort->argv));
	}

	/* Initialize device lock, whether or not the module is usable */
	d->lock = (volatile void *)ort_alloc(sizeof(ee_lock_t));
	ee_init_lock((ee_lock_t *) d->lock, ORT_LOCK_NORMAL);
	SFENCE; /* 100% initialized, before been assigned to "lock" */
}


void ort_finalize_devices(void)
{
	int i;

	for (i = 0; i < ort->num_devices; i++)
		if (ort->ort_devices[i].initialized)
		{
			if (ort->ort_devices[i].device_info)
				MODULE_CALL(ort->ort_devices[i].module, 
				            finalize, (ort->ort_devices[i].device_info)
				);

			/* Deinitialize and free device lock */
			ee_destroy_lock((ee_lock_t *) ort->ort_devices[i].lock);
			free((ee_lock_t *) ort->ort_devices[i].lock);
		};

	for (i = 0; i < ort->num_modules; i++)
	{
#ifdef OMPI_REMOTE_OFFLOADING
		if (ort->modules[i].remote) 
		{
			free(ort->modules[i].name);	
			free(ort->modules[i].rdev_name);
			free(ort->modules[i].node_name);
		}
		else
#endif
		{
			if (ort->modules[i].initialized && ort->modules[i].initialized_successful)
				dlclose(ort->modules[i].handle);
		}
	}

	ort_kernfunc_cleanup();
}


/**
 * @brief Counts the devices of all modules that carry the given name
 *
 * Every entry of the module table whose name equals module_name contributes
 * its device count to the sum.
 *
 * @param module_name The module name
 * @return the total number of devices (0 if no module has that name)
 */
int ompx_get_module_num_devices(char *module_name)
{
	int k, count = 0;

	for (k = 0; k < ort->num_modules; k++)
	{
		if (strcmp(module_name, ort->modules[k].name) != 0)
			continue;
		count += ort->modules[k].number_of_devices;
	}

	return count;
}


/**
 * @brief Returns the usable ID of a device of a specified index within a module
 *
 * The index runs over the devices of all modules that carry the given name,
 * in module-table order.
 *
 * @param module_name The module name
 * @param index       The index (0 .. #devices-1)
 * @return a usable device ID; 0 (host), after a warning, if no module with
 *         that name offers any device or if the index is out of range
 */
int ompx_get_module_device(char *module_name, int index)
{
	int i, total_devices = ompx_get_module_num_devices(module_name);

	/* Check this first: with no devices under that name, any index would
	 * otherwise be reported as being outside the range "0 to -1".
	 */
	if (total_devices <= 0)
	{
		ort_warning("Invalid module name %s; returning device 0 (host).\n",
		            module_name);
		return 0;
	}

	if (index < 0 || index >= total_devices)
	{
		ort_warning("Request for invalid device %d in module %s "
		            "(expected a value between 0 and %d); returning device 0 (host).\n",
		            index, module_name, total_devices - 1);
		return 0;
	}

	/* Walk the same-named modules, consuming the index as we go */
	for (i = 0; i < ort->num_modules; i++)
		if (!strcmp(module_name, ort->modules[i].name))
		{
			if (index < ort->modules[i].number_of_devices)
				return ort->modules[i].first_global_devid + index;
			index -= ort->modules[i].number_of_devices;
		}

	/* Not reached as long as the module table matches the total computed above */
	return 0;
}


/**
 * @brief Given a device ID, return the name of its module
 *
 * Device 0 is always reported as "host" and an ID outside the device table
 * as "invalid device"; neither of those two strings may be modified or freed.
 *
 * @param device_id The device ID
 * @return the module name of the given device ID
 */
char *ompx_get_device_module_name(int device_id)
{
	char *modname;

	if (device_id == 0)
		modname = "host";
	else if (device_id < 0 || device_id >= ort->num_devices)
		modname = "invalid device";
	else
		modname = ort->ort_devices[device_id].module->name;

	return modname;
}


/**
 * @brief Finds a module at a specific node and returns the number of its
 *        devices and (optionally) its first usable device ID
 *
 * The node name "host" selects the local modules. Any other name is looked
 * up among the remote modules, when remote offloading is compiled in.
 * first_dev_id is written only when the module is found.
 *
 * @param node_name    The hostname of the node
 * @param module_name  The module to search for in the given node
 * @param first_dev_id (ret) The first usable device ID of the module; may be
 *                     NULL if not needed
 * @return the number of devices of the module if it is found, 0 (after a
 *         warning) if it is not, and -1 if node_name or module_name is NULL
 */
int ompx_get_module_node_info(char *node_name, char *module_name, int *first_dev_id)
{
	int i;

	/* Invalid arguments; must be checked before either string is touched */
	if ((node_name == NULL) || (module_name == NULL))
		return -1;

	/* Search locally for the given module */
	if (!strcmp(node_name, "host"))
	{
		for (i = 0; i < ort->num_local_modules; i++)
			if (!strcmp(module_name, ort->modules[i].name))
			{
				if (first_dev_id != NULL)
					*first_dev_id = ort->modules[i].first_global_devid;
				return ort->modules[i].number_of_devices;
			}
	}
#ifdef OMPI_REMOTE_OFFLOADING
	/* Search in a particular node; remote modules follow the local ones */
	else
	{
		for (i = ort->num_local_modules; i < ort->num_modules; i++)
			if ((!strcmp(module_name, ort->modules[i].name)) &&
			   (!strcmp(node_name, ort->modules[i].node_name)))
			{
				if (first_dev_id != NULL)
					*first_dev_id = ort->modules[i].first_global_devid;
				return ort->modules[i].number_of_devices;
			}
	}
#endif

	ort_warning("Device type %s not available on node %s; returning device 0 (host).\n",
			   module_name, node_name);
	return 0;
}

#ifdef OMPI_REMOTE_OFFLOADING
// These are currently available only internally
int ompx_devid_to_node_id(int device_id)
{
	return ort->ort_devices[device_id].module->nodeid;
}
/* Returns the global id of the first device of a node: the node's first
 * remote device id, offset by the number of local devices.
 * Node ids are 1-based; an id outside 1..rdev_config.num_nodes yields 0.
 */
int ompx_get_node_first_devid(int node_id)
{
	/* Reject ids on both sides so the node table is never read out of range */
	if (node_id < 1 || node_id > rdev_config.num_nodes)
		return 0;
	return rdev_config.nodes[node_id-1].first_remote_devid+ort->num_local_devices;
}
#endif

#else   /* not HAVE_DLOPEN */


void ort_discover_modules(int nmods, char **modnames)
{
	ort->ort_devices = ort_alloc(1 * sizeof(ort_device_t));
	add_device(NULL, 0)->initialized = true;   /* Host */
	ee_init_lock((ee_lock_t *) &mod_lock, ORT_LOCK_NORMAL);
}

/* Empty in builds without dlopen(): the only device is the host, which
 * ort_discover_modules() already marks as initialized.
 */
void ort_init_device(int device_id) {}

/* Empty in builds without dlopen(): no modules were loaded and the host
 * device was given no lock, so nothing is released here.
 */
void ort_finalize_devices() {}


#endif  /* HAVE_DLOPEN */
