/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_ROFF_PRIMARY

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <ctype.h>
#include <string.h>
#include <unistd.h>
#include <assert.h>
#include "config.h"
#include "ort.h"
#include "ort_prive.h"
#include "assorted.h"

#ifdef OMPI_REMOTE_OFFLOADING
#include "roff_config.h"
#include "remote/workercmds.h"
#include "remote/memory.h"
#include "remote/node_manager.h"
#include "remote/roff.h"
#include "remote/roff_prive.h"

/* Special values used when talking to the worker nodes.
 *
 * DUMMY_ARG: placeholder for a slot that carries no meaningful value.
 * LAST_ARG:  terminator of the send_to_node() variadic list. It carries the
 *            type that send_to_node() fetches with va_arg, so the sentinel is
 *            passed and compared at full width instead of as a plain int.
 * NO_TAG:    message tag for commands that are not tied to a device.
 */
#define DUMMY_ARG  -1
#define LAST_ARG   ((roff_datatype_t) -2)
#define NO_TAG     0

/* Per-device serialization helpers. Both macros expect a `dev' pointer (with
 * a global_device_id field) to be visible at the point of use.
 *
 * ROFF_LOCK declares a local `_cur_level', initialized from __MYCB->level,
 * and acquires the rofflocks[] entry of the device only when that level is
 * non-zero. The entry index is the device's global id minus
 * ort->num_local_devices. Since the macro introduces a declaration, it can be
 * used at most once per scope.
 */
#define ROFF_LOCK \
	int _cur_level = __MYCB->level; \
	if (_cur_level != 0) \
		ee_set_lock((ee_lock_t*) &(rofflocks[dev->global_device_id - ort->num_local_devices]))\

/* ROFF_UNLOCK releases the same rofflocks[] entry under the same condition.
 * It reads the `_cur_level' declared by ROFF_LOCK, so it must appear where
 * that declaration is still in scope.
 */
#define ROFF_UNLOCK \
	if (_cur_level != 0) \
		ee_unset_lock((ee_lock_t*) &(rofflocks[dev->global_device_id - ort->num_local_devices]))


roff_devinfo_t *devinfotab;           /* Array of per-device entries; looked up by (node_id, device_id_in_node) or by global_device_id */
int devinfotab_size = 0;              /* Number of entries in devinfotab; set to ort->num_remote_devices at initialization */
static volatile ee_lock_t *rofflocks; /* One lock per remote device: MPI calls must be serialized per device */
static char *nodes_array;         /* One flag byte per node (slot 0 stands for this process); non-zero while the node is running */

/* Reports a fatal error and terminates the process.
 *
 * Writes "[remote] " followed by the printf-style message built from `format'
 * and the variadic arguments to stderr, then calls exit(exitcode).
 * This function does not return.
 */
static
void exit_error(int exitcode, char *format, ...)
{
	va_list msgargs;

	fputs("[remote] ", stderr);

	va_start(msgargs, format);
	vfprintf(stderr, format, msgargs);
	va_end(msgargs);

	exit(exitcode);
}


/* Packs a list of values into one fixed-size buffer and ships it to a node
 * with a single Comm_Send call, instead of issuing one send per value.
 *
 * The variadic arguments are fetched one by one as roff_datatype_t until
 * either LAST_ARG is met or CMD_MAX_NUM_ARGS values have been stored. The
 * whole CMD_MAX_NUM_ARGS-element buffer is always transmitted, regardless of
 * how many values were supplied.
 *
 * node_id: the destination node
 * tag:     the message tag handed to Comm_Send
 * ...:     the values to send, terminated by LAST_ARG; since they are read
 *          back as roff_datatype_t, callers should pass them with that type
 */
void send_to_node(int node_id, int tag, ...)
{
	roff_datatype_t cmdbuf[CMD_MAX_NUM_ARGS] = { DUMMY_ARG };
	roff_datatype_t next;
	int nstored = 0;
	va_list ap;

	va_start(ap, tag);
	next = va_arg(ap, roff_datatype_t);
	while (nstored < CMD_MAX_NUM_ARGS && next != LAST_ARG)
	{
		cmdbuf[nstored++] = next;
		next = va_arg(ap, roff_datatype_t);
	}
	va_end(ap);

	Comm_Send(comminfo, node_id, COMM_UNSIGNED_LONG_LONG, cmdbuf, CMD_MAX_NUM_ARGS,
	          tag, NULL);
}


/* Transfers everything a device needs in order to launch a kernel on its own.
 *
 * The following messages are sent to `node_id', all tagged with `device_id':
 *   1. num_teams, num_threads, thread_limit, teamdims and thrdims, packed
 *      into a single send_to_node() message;
 *   2. the num_args array (NUM_ARGS_SIZE ints);
 *   3. the length of the kernel filename prefix (including the terminating
 *      NUL), followed by the string itself;
 *   4. the `args' pointer array as raw bytes, only if there is at least one
 *      argument; the count is num_args[0] + num_args[1] + num_args[2].
 *
 * Every value given to send_to_node() is converted to roff_datatype_t
 * explicitly, because that is the type send_to_node() reads with va_arg;
 * passing a plain int through the ellipsis would be a type mismatch.
 */
static
void send_offload_arguments(int node_id, int device_id, char *kernel_filename_prefix, int num_teams,
                            int num_threads, int thread_limit, 
                            roff_datatype_t teamdims, roff_datatype_t thrdims, 
                            int *num_args, void **args)
{
	int kernlen = strlen(kernel_filename_prefix) + 1; /* filename length, NUL included */
	int total_nargs = num_args[0] + num_args[1] + num_args[2];
	
	/* Sizes */
	send_to_node(node_id, device_id, 
	             (roff_datatype_t) num_teams, (roff_datatype_t) num_threads, 
	             (roff_datatype_t) thread_limit, teamdims, thrdims, 
	             (roff_datatype_t) LAST_ARG);
	
	/* Array holding the number of arguments */
	Comm_Send(comminfo, node_id, COMM_INT, num_args, NUM_ARGS_SIZE, device_id, NULL);
	
	/* Kernel filename (length + str) */
	_Comm_Send_1int(&kernlen, node_id, device_id);
	Comm_Send(comminfo, node_id, COMM_CHAR, kernel_filename_prefix, kernlen, device_id, NULL);
	
	/* Kernel arguments */
	if (total_nargs > 0)
		_Comm_Send_bytes(args, total_nargs * sizeof(void *), node_id, device_id);
}


/* Tells whether the given device is served by a CPU module, as decided by
 * IS_CPU_MODULE applied to the device's module name.
 */
static
bool is_cpu_device(roff_devinfo_t *dev)
{
	bool on_cpu = IS_CPU_MODULE(dev->modulename);

	return on_cpu;
}


static
roff_devinfo_t *devinfotab_get(int node_id, int device_id_in_node)
{
	int i;
	for (i = 0; i < devinfotab_size; i++)
		if ((devinfotab[i].device_id_in_node == device_id_in_node) 
		 && (devinfotab[i].node_id == node_id))
			return &(devinfotab[i]);
	return (roff_devinfo_t*) NULL;
}


roff_devinfo_t *devinfotab_get_from_gldevid(int global_device_id)
{
	int i;
	for (i = 0; i < devinfotab_size; i++)
		if (devinfotab[i].global_device_id == global_device_id)
			return &(devinfotab[i]);
	return (roff_devinfo_t*) NULL;
}


/** 
 * Initializes the module
 * 
 * The actual work is done only on the first call; it allocates the device
 * info table, creates the worker nodes, sets up one lock per remote device
 * and the node flags array, initializes the global variable tables of all
 * remote devices and (unless ROFF_WORKER_MULTITHREADING is defined) sends
 * an initialization command to every worker. Any later call simply reports
 * the outcome of that first call, so a failed initialization is never
 * reported as a success afterwards.
 * 
 * @param modname                   the name of the module (not used here)
 * @param global_id_of_first_device the global ID of the first device 
 *                                  (not used here)
 * @param init_lock_in_ignore       lock initialization function (ignored)
 * @param lock_in_ignore            lock acquisition function (ignored)
 * @param unlock_in_ignore          lock release function (ignored)
 * @param hyield_in_ignore          thread yield function (ignored)
 * @param argc                      pointer to main function's argc 
 *                                  (not used here)
 * @param argv                      pointer to main function's argv 
 *                                  (not used here)
 * 
 * @return 1 on success, 0 if the worker nodes could not be created
 */
int roff_initialize(char *modname, int global_id_of_first_device,
                            void (*init_lock_in_ignore)(void **lock, int type),
                            void (*lock_in_ignore)(void **lock),
                            void (*unlock_in_ignore)(void **lock),
                            int  (*hyield_in_ignore)(void), int *argc, char ***argv)
{
	static int primary_node_initialized = 0;
	static int init_status = 1;      /* Outcome of the first call */
	int i, num_nodes;
	
	if (primary_node_initialized) return init_status;
		
	primary_node_initialized = 1;

	DBGPRN((stderr, "[remote primary] >>> roff_initialize\n"));
	
	/* (1) Initialize devinfo table */
	devinfotab_size = ort->num_remote_devices;
	devinfotab = (roff_devinfo_t*) smalloc(devinfotab_size * sizeof(roff_devinfo_t));

	/* (2) Spawn/wake up node processes */
	if (roff_man_create_workers() == 0)
	{
		fprintf(stderr, "[ORT warning]: Failed to initialize remote offloading.\n");
		init_status = 0;             /* Remember the failure for later calls */
		return 0;
	}
	
	num_nodes = roff_man_get_num_nodes() + 1; /* + myself */
	
	/* (3) Initialize locks */
	rofflocks = smalloc(ort->num_remote_devices * sizeof(ee_lock_t));
	for (i = 0; i < ort->num_remote_devices; ++i)
		ee_init_lock((ee_lock_t*) &rofflocks[i], ORT_LOCK_NORMAL);
	
	/* (4) Initialize node flags array; all nodes are marked as running */
	nodes_array = smalloc(num_nodes);
	memset(nodes_array, 1, num_nodes);
	
	/* (5) Initialize global vars and send an initialization command to all nodes */
	for (i = ort->num_local_devices; i < ort->num_devices; i++)
		roff_alloctab_init_global_vars(i);

#ifndef ROFF_WORKER_MULTITHREADING	
	/* Node initialization is serialized; no mutex need here */

	#ifndef ROFF_MULTIPLE_WORKERS
		for (i = 0; i < num_nodes - 1; i++)
			send_to_node(i + 1, NO_TAG, CMD(initialize), LAST_ARG);
	#else
		/* NOTE(review): the destination formula below looks like it assumes 
		 * the same number of workers on every node; confirm for the case 
		 * ROFF_NUM_NODE_WORKER_PROCS == -1, where the count is per node.
		 */
		for (i = 0; i < num_nodes - 1; i++)
		{
			int w, numworkers = (ROFF_NUM_NODE_WORKER_PROCS == -1) ? 
			                    roff_config.nodes[i].total_num_devices 
								: ROFF_NUM_NODE_WORKER_PROCS;

			for (w = 0; w < numworkers; w++)
				send_to_node((i * numworkers) + w + 1, NO_TAG, CMD(initialize), LAST_ARG);
		}
	#endif
#endif

	return 1;
}


/** 
 * Finalizes the module
 * 
 * Only the first call does any work: it finalizes every device that is still
 * initialized, sends a shutdown command to all nodes that are still flagged
 * as running, destroys the per-device locks and releases the module's 
 * tables. The module-level pointers are reset after being freed, so that
 * later lookups in the device table find nothing instead of touching 
 * released memory.
 */
void roff_finalize(void)
{
	int i, num_nodes;
	static int roff_finalized = 0;

	DBGPRN((stderr, "[remote primary] >>> roff_finalize\n"));

	if (roff_finalized) return;

	for (i = 0; i < devinfotab_size; i++)
		roff_dev_end(&(devinfotab[i]));

	/* Shutdown all remaining active nodes */
	num_nodes = roff_man_get_num_nodes();
#ifndef ROFF_MULTIPLE_WORKERS
	for (i = 0; i < num_nodes; ++i)
	{
		if (nodes_array[i+1])
		{
			send_to_node(i+1, NO_TAG, CMD(shutdown), LAST_ARG);
			nodes_array[i+1] = 0;
		}
	}
#else
	for (i = 0; i < num_nodes; ++i)
	{
		if (nodes_array[i+1])
		{
			int w, numworkers = (ROFF_NUM_NODE_WORKER_PROCS == -1) ? 
								roff_config.nodes[i].total_num_devices 
								: ROFF_NUM_NODE_WORKER_PROCS;

			for (w = 0; w < numworkers; w++)
				send_to_node((i * numworkers) + w + 1, NO_TAG, CMD(shutdown), LAST_ARG);

			nodes_array[i+1] = 0;
		}
	}
#endif

	for (i = 0; i < ort->num_remote_devices; i++)
		ee_destroy_lock((ee_lock_t*) &(rofflocks[i]));
	free((void*) rofflocks);
	rofflocks = NULL;

	roff_alloctab_free_all();
	roff_man_finalize();
	Comm_Finalize(comminfo);

	/* Leave no dangling pointers behind; an empty table makes lookups fail */
	free(devinfotab);
	devinfotab = NULL;
	devinfotab_size = 0;
	free(nodes_array);
	nodes_array = NULL;

	roff_finalized = 1;
}


/**
 * Initializes a remote device
 *
 * The node id and the id of the device within that node are both extracted
 * from `dev_num' (through _uint_decode2) and used to find the corresponding
 * entry of the device info table. On success the entry is marked as 
 * initialized, with no shared address space.
 *
 * @param dev_num     the id of the device to initialize; it encodes both
 *                    the node id and the id of the device in that node
 * @param ort_icv     Pointer to struct with initial values for the device 
 *                    ICVs (not used here).
 * @param sharedspace (ret) if not NULL, it is set to 0 (false)
 *
 * @return device_info: a pointer to the device info table entry, which will
 *         be passed back in following calls. The entry is returned, and not
 *         the device info itself, as the actual device info does not live 
 *         in our address space.
 *         NULL is returned if no entry was found for the device.
 */
void *roff_dev_init(int dev_num, ort_icvs_t *ort_icv, int *sharedspace)
{
	unsigned short node_id, device_id_in_node;
	roff_devinfo_t *entry;

	if (sharedspace != NULL)
		*sharedspace = 0;

	_uint_decode2(dev_num, &device_id_in_node, &node_id);

	DBGPRN((stderr, "[remote primary] >>> roff_dev_init (node_id: %d, device_id: %d)\n", 
	                 node_id, device_id_in_node));

	entry = devinfotab_get((int) node_id, (int) device_id_in_node);
	if (entry != NULL)
	{
		entry->sharedspace = 0;
		entry->status = DEVICE_INITIALIZED;
		return (void*) entry;
	}

	fprintf(stderr, "[remote] failed to retrieve device info for device_id=%d.\n", 
	                (int) device_id_in_node);
	return NULL;
}


/**
 * Finalizes a remote device.
 *
 * Nothing happens if `device_info' is NULL or if the device is not in the
 * initialized state. Otherwise a finalize command is sent to the node of 
 * the device and the call blocks until the node replies. If the reply is 
 * anything other than FINALIZE_OK, the whole program is terminated with a
 * failure exit status.
 *
 * @param device_info the device to finalize
 */
void roff_dev_end(void *device_info)
{
	int finalize_result;
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;

	if (dev == NULL || dev->status != DEVICE_INITIALIZED) return;

	DBGPRN((stderr, "[remote primary] >>> roff_dev_end (node_id: %d, device_id: %d)\n", 
	                 dev->node_id, dev->device_id_in_node));
	
	/* Finalize the device */
	send_to_node(dev->node_id, dev->device_id_in_node, CMD(finalize), LAST_ARG);

	/* Wait for finalization to finish */            
	_Comm_Recv_1int(&finalize_result, dev->node_id, dev->device_id_in_node);

	/* FINALIZE_OK is the success reply, so it must not double as the exit
	 * code of a failed run.
	 */
	if (finalize_result != FINALIZE_OK)
		exit_error(EXIT_FAILURE, "finalization failed for device %d at node %d\n", 
		                         dev->device_id_in_node, dev->node_id);
	
	dev->status = DEVICE_UNINITIALIZED;

#if 0
	nodes_array[dev->node_id] = 0;
#endif
}


/**
 * Offloads and executes a kernel file.
 *
 * @param device_info         the device
 * @param host_func pointer   to offload function on host address space
 * @param devdata pointer     to a struct containing kernel variables
 * @param decldata pointer    to a struct containing globally declared variables
 * @param kernel_filename_prefix filename of the kernel (without the suffix)
 * @param num_teams           num_teams clause from "teams" construct
 * @param num_threads         form clause in combined parallel constructs
 * @param thread_limit        thread_limit clause from "teams" construct
 * @param teamdims            an unsigned long long that contains the
 *                            dimensions of the launched league, encoded as follows:
 *                            x: bits 0-20, y: bits 21-41, z: bits 42-62 
 * @param thrdims             an unsigned long long that contains the
 *                            dimensions of each thread team, encoded as follows:
 *                            x: bits 0-20, y: bits 21-41, z: bits 42-62 
 * @param num_args            an array that contains the number of declare variables, 
 *                            firstprivates and mapped variables
 * @param args                the addresses of all target data and target
 *                            declare variables
 *
 * NOTE: `teamdims' and `thrdims' can be decoded using the _ull_decode3 function.
 *
 * NOTES: 
 * 1. `teamdims' and `thrdims' can be decoded using the _ull_decode3 function.
 * 2. In MPI module, decldata variables are dealt with in dev_alloc
 * function, not here.
 */
int roff_offload(void *device_info, void *(*host_func)(void *), void *devdata,
                  void *decldata, char *kernel_filename_prefix, int num_teams,
                  int num_threads, int thread_limit, 
                  unsigned long long teamdims, unsigned long long thrdims, 
                  int *num_args, void **args)
{
	int kernel_id, exec_result;
	roff_datatype_t devdata_maddr = DUMMY_ARG, decldata_maddr = DUMMY_ARG;
	size_t devdata_len = 0, decldata_len = 0;
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;

	DBGPRN((stderr, "[remote primary] >>> roff_offload (node_id: %d, device_id: %d)\n", 
					dev->node_id, dev->device_id_in_node));
	
	kernel_id = ort_kernfunc_findbyname(kernel_filename_prefix);
	if (kernel_id < 0)
	{
		fprintf(stderr, "could not locate kernel function %s !?\n", kernel_filename_prefix);
		return 1;
	}

	ROFF_LOCK;

	/* devdata handling */
	if (devdata)
	{
		devdata_len = ort_mapped_get_size(devdata);
		devdata_maddr = roff_alloctab_register(dev->alloc_table, MAPPED_DEVDATA);
	}

	/* #declare data handling */
	if (decldata)
	{
		decldata_len = ort_mapped_get_size(decldata);
		decldata_maddr = roff_alloctab_register(dev->alloc_table, MAPPED_DECLDATA);
	}
	
	/* Offload */
	send_to_node(dev->node_id, dev->device_id_in_node, CMD(offload), kernel_id,
	             devdata_maddr, devdata_len, decldata_maddr, decldata_len, 
	             LAST_ARG);
		             
	/* Send the devdata struct */
	if (devdata)
		_Comm_Send_bytes(devdata, devdata_len, dev->node_id, dev->device_id_in_node);
		
	/* Non-CPU devices do not call the kernel function directly,
	 * thus we need to transfer all offloading arguments; some of them
	 * can be sent using send_to_node call, the others are sent manually.
	 */
	if (!is_cpu_device(dev))
	{
		/* Then send the decldata struct itself */
		if (decldata)
			_Comm_Send_bytes(decldata, decldata_len, dev->node_id, dev->device_id_in_node);
			
		/* Send all offload arguments */
		send_offload_arguments(dev->node_id, dev->device_id_in_node, kernel_filename_prefix, num_teams, 
		                       num_threads, thread_limit, teamdims, 
		                       thrdims, num_args, args);
	}
		
	/* Wait till the device finishes */
	_Comm_Recv_1int(&exec_result, dev->node_id, dev->device_id_in_node);
	ROFF_UNLOCK;

	if (exec_result != OFFLOAD_OK)
	{
		fprintf(stderr, "kernel execution failed at the end %d\n", kernel_id);
		return 1;
	}
	
	if (devdata)
	{
		ROFF_LOCK;
		roff_alloctab_unregister(dev->alloc_table, devdata_maddr, MAPPED_DEVDATA);
		ROFF_UNLOCK;
	}
	
	if (decldata)
	{
		ROFF_LOCK;
		roff_alloctab_unregister(dev->alloc_table, decldata_maddr, MAPPED_DECLDATA);
		ROFF_UNLOCK;
	}
	return 0;  /* All OK */
}


/**
 * Allocates memory on the device
 *
 * @param device_info the device
 * @param size        the number of bytes to allocate
 * @param map_memory  used in OpenCL, when set to 1 additionaly to the memory
 *                    allocation in shared virtual address space, the memory
 *                    is mapped with read/write permissions so the host cpu
 *                    can utilize it.
 * @param hostaddr    used in MPI to allocate #declare target link variables;
 *                    when we encounter such variables instead of
 *                    allocating new space, we should return the mediary
 *                    address of the original address.
 * @param map_type    the mapping type that triggered this allocation (to/from/tofrom/alloc)
 * @return hostaddr a pointer to the allocated space (mediary address)
 */
void *roff_dev_alloc(void *device_info, size_t size, int map_memory, void *hostaddr, int map_type)
{
	roff_datatype_t maddr;
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;

	DBGPRN((stderr, "[remote primary] >>> roff_dev_alloc (node_id: %d, device_id: %d)\n", 
					dev->node_id, dev->device_id_in_node));
	
	if (map_memory) /* Only needed for devdata/decldata structs */
		return ort_mapped_alloc(size);
	
	ROFF_LOCK;
	maddr = roff_alloctab_register(dev->alloc_table, NOT_MAPPED);
	send_to_node(dev->node_id, dev->device_id_in_node, CMD(dev_alloc), maddr, size, 
	             map_memory, LAST_ARG);
	ROFF_UNLOCK;
	
	return (void *) maddr;
}


/**
 * Allocates memory on the device for a global variable
 *
 * The variable is added to the allocation table of the device under its
 * global ID and the node is told to allocate and initialize it, all while 
 * holding the lock of the device. The global ID itself serves as the 
 * mediary address of the variable.
 *
 * @param device_info the device
 * @param initfrom    forwarded to the node as is, as part of the command
 * @param size        the number of bytes to allocate
 * @param global_id   the ID of the global variable
 * @param hostaddr    the host address recorded in the allocation table
 *                    for this variable
 * @return the mediary address of the variable, i.e. `global_id' as a pointer
 */
void *roff_dev_init_alloc_global(void *device_info, void *initfrom, size_t size, int global_id,
                                 void *hostaddr)
{
	roff_datatype_t maddr = (roff_datatype_t) global_id;
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;

	DBGPRN((stderr, "[remote primary] >>> roff_dev_init_alloc_global (node_id: %d, device_id: %d)\n", 
					dev->node_id, dev->device_id_in_node));

	ROFF_LOCK;
	roff_alloctab_add_global(dev->alloc_table, global_id, hostaddr);

	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(dev_init_alloc_global)), maddr, 
	             (roff_datatype_t) initfrom, (roff_datatype_t) size, 
	             (roff_datatype_t) global_id, (roff_datatype_t) LAST_ARG);
	ROFF_UNLOCK;

	return (void *) maddr;
}


/**
 * Frees data allocated with roff_dev_alloc
 *
 * When `unmap_memory' is set, the memory is released locally through
 * ort_mapped_free() and nothing is sent to the device. Otherwise the entry
 * is removed from the allocation table of the device and the node is told
 * to free the corresponding memory, while holding the lock of the device.
 *
 * @param device_info  the device
 * @param maddr        pointer to the memory that will be released
 * @param unmap_memory when set, `maddr' is handed to ort_mapped_free
 */
void roff_dev_free(void *device_info, void *maddr, int unmap_memory)
{
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;

	DBGPRN((stderr, "[remote primary] >>> roff_dev_free (node_id: %d, device_id: %d)\n", 
	                dev->node_id, dev->device_id_in_node));
	if (unmap_memory)
	{
		ort_mapped_free(maddr);
		return;
	}
		
	ROFF_LOCK;
	roff_alloctab_unregister(dev->alloc_table, (roff_datatype_t)maddr, NOT_MAPPED);

	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(dev_free)), (roff_datatype_t) maddr, 
	             (roff_datatype_t) unmap_memory, (roff_datatype_t) LAST_ARG);
	ROFF_UNLOCK;
}


/**
 * Frees data allocated with roff_dev_init_alloc_global
 *
 * This simply forwards to roff_dev_free, asking for a plain device-side
 * release. The third argument of roff_dev_free is a boolean "unmap" flag, 
 * so an explicit 0 is passed for it (and not an allocation table kind 
 * such as NOT_MAPPED).
 *
 * @param device_info  the device
 * @param maddr        pointer to the memory that will be released
 * @param global_id    the ID of the global variable that will be released
 *                     (not used here)
 */
void roff_dev_free_global(void *device_info, void *maddr, int global_id)
{
	(void) global_id;
	roff_dev_free(device_info, maddr, 0);
}


/**
 * Transfers data from a device to the host
 *
 * A fromdev command is sent to the node of the device and then `size' bytes
 * are received into the host buffer, all while holding the lock of the 
 * device.
 *
 * @param device_info the source device
 * @param hostaddr    the target memory
 * @param hostoffset  offset (in bytes) from hostaddr
 * @param maddr       the source memory mediary address
 * @param devoffset   offset forwarded to the node along with maddr
 * @param size        the size of the memory block
 */
void roff_fromdev(void *device_info, void *hostaddr, size_t hostoffset,
                  void *maddr, size_t devoffset, size_t size)
{
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;
	/* Byte arithmetic must go through char *; void * arithmetic is non-standard */
	char *hostbuf = (char *) hostaddr + hostoffset;

	ROFF_LOCK;
	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(fromdev)), (roff_datatype_t) maddr, 
	             (roff_datatype_t) devoffset, (roff_datatype_t) size, 
	             (roff_datatype_t) LAST_ARG);
	_Comm_Recv_bytes(hostbuf, size, dev->node_id, dev->device_id_in_node);
	ROFF_UNLOCK;
}


/**
 * Transfers data from the host to a device
 *
 * A todev command is sent to the node of the device, followed by `size' 
 * bytes taken from the host buffer, all while holding the lock of the 
 * device.
 *
 * @param device_info the device
 * @param hostaddr    the source memory
 * @param hostoffset  offset (in bytes) from hostaddr
 * @param maddr       the target memory mediary address
 * @param devoffset   offset forwarded to the node along with maddr
 * @param size        the size of the memory block
 */
void roff_todev(void *device_info, void *hostaddr, size_t hostoffset,
                void *maddr, size_t devoffset, size_t size)
{
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;
	/* Byte arithmetic must go through char *; void * arithmetic is non-standard */
	char *hostbuf = (char *) hostaddr + hostoffset;

	DBGPRN((stderr, "[remote primary] >>> roff_todev (node_id: %d, device_id: %d)\n", 
					dev->node_id, dev->device_id_in_node));

	ROFF_LOCK;
	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(todev)), (roff_datatype_t) maddr, 
	             (roff_datatype_t) devoffset, (roff_datatype_t) size, 
	             (roff_datatype_t) LAST_ARG);
	_Comm_Send_bytes(hostbuf, size, dev->node_id, dev->device_id_in_node);
	ROFF_UNLOCK;
}


/**
 * Given an internal mediary address, it returns a usable mediary address
 *
 * The translation is carried out by the node of the device: the internal
 * address is sent along with an imed2umed_addr command and the call blocks
 * until the node replies with the usable one. The exchange happens while 
 * holding the lock of the device.
 *
 * @param device_info the device
 * @param imedaddr    allocated memory from roff_dev_alloc
 *
 * @return usable mediary address to pass to a kernel
 */
void *roff_imed2umed_addr(void *device_info, void *imedaddr)
{
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;
	roff_datatype_t umedaddr;
	
	DBGPRN((stderr, "[remote primary] >>> roff_imed2umed_addr (node_id: %d, device_id: %d)\n", 
					dev->node_id, dev->device_id_in_node));

	ROFF_LOCK;
	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(imed2umed_addr)), (roff_datatype_t) imedaddr, 
	             (roff_datatype_t) LAST_ARG);
	_Comm_Recv_1ull(&umedaddr, dev->node_id, dev->device_id_in_node);
	ROFF_UNLOCK;
	
	return (void*) umedaddr; 
}


/**
 * Given a usable mediary address, it returns the internal mediary address
 *
 * The translation is carried out by the node of the device: the usable
 * address is sent along with a umed2imed_addr command and the call blocks
 * until the node replies with the internal one. The exchange happens while 
 * holding the lock of the device.
 *
 * @param device_info the device
 * @param umedaddr    a usable mediary address of the device
 *
 * @return internal mediary address to be used by ORT
 */
void *roff_umed2imed_addr(void *device_info, void *umedaddr)
{
	roff_devinfo_t *dev = (roff_devinfo_t*) device_info;
	roff_datatype_t imedaddr;

	DBGPRN((stderr, "[remote primary] >>> roff_umed2imed_addr (node_id: %d, device_id: %d)\n", 
				dev->node_id, dev->device_id_in_node));
				
	ROFF_LOCK;
	/* Every value is passed with the type send_to_node reads back */
	send_to_node(dev->node_id, dev->device_id_in_node, 
	             (roff_datatype_t) (CMD(umed2imed_addr)), (roff_datatype_t) umedaddr, 
	             (roff_datatype_t) LAST_ARG);
	_Comm_Recv_1ull(&imedaddr, dev->node_id, dev->device_id_in_node);
	ROFF_UNLOCK;
	
	return (void*) imedaddr; 
}


/* new-hostpart-func.sh:rofffuncdef */

#endif /* OMPI_REMOTE_OFFLOADING */
