/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_RDEV_PRIMARY

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <ctype.h>
#include <string.h>
#include <unistd.h>
#include "config.h"
#include "ort.h"
#include "ort_prive.h"

#ifdef OMPI_REMOTE_OFFLOADING
#include "rdev_config.h"
#include "remotedev/workercmds.h"
#include "remotedev/memory.h"
#include "remotedev/node_manager.h"
#include "remotedev/rdev.h"
#include "remotedev/rdev_prive.h"

/* NOTE(review): these two file-scope variables are not referenced anywhere in
 * this file; the functions below receive argc/argv as parameters, which
 * shadow them. */
static int *argc;
static char ***argv;

/* Values used with send_to_node(): COMM_DUMMY_ARG marks an unused argument
 * slot and COMM_LAST_ARG terminates the variadic argument list.
 * They are typed as rdev_datatype_t on purpose: send_to_node() reads every
 * variadic argument, the terminator included, with
 * va_arg(..., rdev_datatype_t), and passing a plain int there is undefined
 * behavior (the sentinel would not reliably compare equal once read back). */
#define COMM_DUMMY_ARG ((rdev_datatype_t) -1)
#define COMM_LAST_ARG  ((rdev_datatype_t) -2)
#define COMM_NO_TAG     0

/* Per-device serialization of MPI traffic.
 * RDEV_LOCK declares `_cur_level' in the current scope and expects a
 * `dev' (rdevinfo_t *) to be in scope; the matching RDEV_UNLOCK must appear
 * later in the same scope. No lock is taken when the calling thread's
 * control block reports level 0 (presumably: not inside a parallel region,
 * hence no concurrent caller -- TODO confirm). */
#define RDEV_LOCK \
	int _cur_level = __MYCB->level; \
	if (_cur_level != 0) \
		ee_set_lock((ee_lock_t*) &(mpi_lock[dev->global_device_id - ort->num_local_devices]))

#define RDEV_UNLOCK \
	if (_cur_level != 0) \
		ee_unset_lock((ee_lock_t*) &(mpi_lock[dev->global_device_id - ort->num_local_devices]))


rdevinfo_t *devinfotab;           /* Array containing (node_id, device_id, devinfo) entries */
int devinfotab_size = 0;          /* Number of entries in devinfotab */
static volatile ee_lock_t *mpi_lock; /* MPI calls must be serialized per device; indexed by
                                        (global device id - number of local devices) */
static int running_devices;       /* Number of initialized, not yet finalized devices */
static char *nodes_array;         /* nodes_array[n] != 0 while node n has not been sent SHUTDOWN */

/* Writes "[remotedev] " followed by the printf-style message built from
 * `format' and the remaining arguments to stderr, then terminates the
 * program with status `exitcode'. Never returns. */
static
void exit_error(int exitcode, char *format, ...)
{
	va_list fmt_args;

	fputs("[remotedev] ", stderr);
	va_start(fmt_args, format);
	vfprintf(stderr, format, fmt_args);
	va_end(fmt_args);
	exit(exitcode);
}


/* Packs a COMM_LAST_ARG-terminated list of values into one fixed-size array
 * of COMM_MAX_NUM_ARGS slots and sends it to node `node_id' (message tag
 * `tag') with a single Comm_Send call. A whole command thus costs one message
 * instead of one per argument.
 *
 * Contract: every variadic argument, the terminating COMM_LAST_ARG included,
 * is read with va_arg(..., rdev_datatype_t). Callers must therefore pass
 * values of exactly that type (cast ints explicitly); anything else is
 * undefined behavior. At most COMM_MAX_NUM_ARGS values are packed.
 * NOTE(review): only slot 0 is pre-filled with COMM_DUMMY_ARG; the other
 * unused slots are zero-initialized.
 */
void send_to_node(int node_id, int tag, ...)
{
	rdev_datatype_t packed[COMM_MAX_NUM_ARGS] = { COMM_DUMMY_ARG };
	rdev_datatype_t next;
	int count = 0;
	va_list ap;

	va_start(ap, tag);
	for (next = va_arg(ap, rdev_datatype_t);
	     count < COMM_MAX_NUM_ARGS && next != COMM_LAST_ARG;
	     next = va_arg(ap, rdev_datatype_t))
		packed[count++] = next;
	va_end(ap);

	Comm_Send(comminfo, node_id, COMM_UNSIGNED_LONG_LONG, packed, COMM_MAX_NUM_ARGS,
	          tag, NULL);
}


/* Sends to a worker node the launch parameters that a non-CPU device needs
 * in order to start the kernel itself. Every message is tagged with
 * `device_id'. They are sent in this order (presumably consumed by the
 * worker in the same order):
 *   1. num_teams, num_threads, thread_limit, teamdims, thrdims (one send_to_node)
 *   2. the num_args array (NUM_ARGS_SIZE ints)
 *   3. the kernel filename length (including the terminating NUL), then the name
 *   4. the args array (num_args[0] + num_args[1] + num_args[2] pointers), if non-empty
 *
 * num_args holds the number of declare variables, firstprivates and mapped
 * variables, respectively.
 */
static
void send_offload_arguments(int node_id, int device_id, char *kernel_fname, int num_teams,
                            int num_threads, int thread_limit, 
                            rdev_datatype_t teamdims, rdev_datatype_t thrdims, 
                            int *num_args, void **args)
{
	int kernlen = strlen(kernel_fname) + 1; /* filename length incl. the NUL */
	int total_nargs = num_args[0] + num_args[1] + num_args[2];

	/* Sizes; send_to_node() reads each variadic value as rdev_datatype_t,
	 * so the int arguments (and the terminator) must be converted first. */
	send_to_node(node_id, device_id, (rdev_datatype_t) num_teams,
	             (rdev_datatype_t) num_threads, (rdev_datatype_t) thread_limit,
	             teamdims, thrdims, (rdev_datatype_t) (COMM_LAST_ARG));

	/* Array holding the number of arguments */
	Comm_Send(comminfo, node_id, COMM_INT, num_args, NUM_ARGS_SIZE, device_id, NULL);

	/* Kernel filename */
	_Comm_Send_1int(&kernlen, node_id, device_id);
	Comm_Send(comminfo, node_id, COMM_CHAR, kernel_fname, kernlen, device_id, NULL);

	/* Kernel arguments */
	if (total_nargs > 0)
		_Comm_Send_bytes(args, total_nargs * sizeof(void *), node_id, device_id);
}


/* Tells whether `dev' is handled by the CPU module of its node, judging by
 * the name of the module serving it (see IS_CPU_MODULE). */
static
bool is_cpu_device(rdevinfo_t *dev)
{
	bool served_by_cpu = IS_CPU_MODULE(dev->modulename);

	return served_by_cpu;
}


static
rdevinfo_t *devinfotab_get(int node_id, int device_id_in_node)
{
	int i;
	for (i = 0; i < devinfotab_size; i++)
		if ((devinfotab[i].device_id_in_node == device_id_in_node) 
		 && (devinfotab[i].node_id == node_id))
			return &(devinfotab[i]);
	return (rdevinfo_t*) NULL;
}

rdevinfo_t *devinfotab_get_from_gldevid(int global_device_id)
{
	int i;
	for (i = 0; i < devinfotab_size; i++)
		if (devinfotab[i].global_device_id == global_device_id)
			return &(devinfotab[i]);
	return (rdevinfo_t*) NULL;
}


static
int primary_node_initialize(int *argc, char ***argv)
{
	static int primary_node_initialized = 0;
	int i, j, num_nodes;
	
	if (primary_node_initialized) return -1;
		
	primary_node_initialized = 1;
	
	/* (1) Initialize devinfo table */
	devinfotab_size = ort->num_remote_devices;
	devinfotab = (rdevinfo_t*) smalloc(devinfotab_size * sizeof(rdevinfo_t));

	/* (2) Spawn/wake up node processes */
	if (rdev_man_create_workers() == 0)
	{
		fprintf(stderr, "[ORT warning]: Failed to initialize remotedev.\n");
		return 0;
	}
	
	num_nodes = rdev_man_get_num_nodes() + 1; /* + myself */
	
	/* (3) Initialize locks */
	mpi_lock = smalloc(ort->num_remote_devices * sizeof(ee_lock_t));
	for (i = 0; i < ort->num_remote_devices; ++i)
		ee_init_lock((ee_lock_t*) &mpi_lock[i], ORT_LOCK_NORMAL);
	
	/* (4) Initialize node flags array */
	nodes_array = smalloc(num_nodes);
	memset(nodes_array, 1, num_nodes);
	
	/* (5) Initialize global vars and send an initialization command to all nodes */
	for (i = ort->num_local_devices; i < ort->num_devices; i++)
		rdev_alloctab_init_global_vars(i, 1);
	
	/* Node initialization is serialized; no mutex need here */
	for (i = 0; i < num_nodes - 1; i++)
		send_to_node(i + 1, COMM_NO_TAG, CMD(INIT), COMM_LAST_ARG);

	return 1;
}


/**
 * Initializes a device
 *
 * @param device_id   encodes both the in-module device ID and the ID of the
 *                    node the device belongs to, as follows:
 *                    <node_id> (bits 17-32) <device_id> (bits 1-16);
 *                    it is decoded with _uint_decode2
 * @param ort_icv     pointer to struct with initial values for the device
 *                    ICVs (not used by this module)
 * @param argc        pointer to main function's argc (forwarded to
 *                    primary_node_initialize)
 * @param argv        pointer to main function's argv (forwarded to
 *                    primary_node_initialize)
 * @return            rdevinfo_t pointer (an entry of devinfotab) that will be
 *                    used in further calls; NULL only if initialization failed.
 */
void *rdev_initialize(int device_id, ort_icvs_t *ort_icv, int *argc, char ***argv)
{
	unsigned short device_id_in_node, node_id;
	rdevinfo_t *info;

	_uint_decode2(device_id, &device_id_in_node, &node_id);

	/* primary_node_initialize() performs the primary node set-up only on its
	 * first invocation; a return value of 0 means the set-up failed. */
	if (primary_node_initialize(argc, argv) == 0)
		return NULL;

	/* At this point we pass a pointer to the table entry, and not the device info 
	 * itself, as the actual device info does not live in our address space.
	 */
	info = devinfotab_get((int) node_id, (int) device_id_in_node);
	if (info == NULL)
	{
		fprintf(stderr, "[remotedev] failed to retrieve device info for device_id=%d.\n", 
		                (int) device_id_in_node);
		return NULL;
	}

	++running_devices;
	return (void*) info;
}


/**
 * Finalizes a device
 *
 * @param device_info the device to finalize
 */
void rdev_finalize(void *device_info)
{
	int i, j;
	rdevinfo_t *dev = (rdevinfo_t*) device_info;
	
	/* Shutdown the node the device belongs to */
	send_to_node(dev->node_id, COMM_NO_TAG, CMD(SHUTDOWN), COMM_LAST_ARG);
				            
	nodes_array[dev->node_id] = 0;
	
	/* This is the last device to be finalized */
	if (--running_devices == 0)
	{
		/* Shutdown all remaining active nodes */
		for (i = 0; i < rdev_man_get_num_nodes(); ++i)
		{
			if (nodes_array[i+1])
			{
				send_to_node(i+1, COMM_NO_TAG, CMD(SHUTDOWN), COMM_LAST_ARG);
				nodes_array[i+1] = 0;
			}
		}

		for (i = 0; i < ort->num_remote_devices; i++)
			ee_destroy_lock((ee_lock_t*) &(mpi_lock[i]));

		free((void*) mpi_lock);

		
		free(nodes_array);
		rdev_alloctab_free_all();
		free(devinfotab);
		rdev_man_finalize();
		Comm_Finalize(comminfo);
	}
}


/**
 * Offloads and executes a kernel file.
 *
 * @param device_info         the device
 * @param host_func pointer   to offload function on host address space
 * @param devdata pointer     to a struct containing kernel variables
 * @param decldata pointer    to a struct containing globally declared variables
 * @param kernel_filename_prefix filename of the kernel (without the suffix)
 * @param num_teams           num_teams clause from "teams" construct
 * @param num_threads         form clause in combined parallel constructs
 * @param thread_limit        thread_limit clause from "teams" construct
 * @param teamdims            an unsigned long long that contains the
 *                            dimensions of the launched league, encoded as follows:
 *                            x: bits 0-20, y: bits 21-41, z: bits 42-62 
 * @param thrdims             an unsigned long long that contains the
 *                            dimensions of each thread team, encoded as follows:
 *                            x: bits 0-20, y: bits 21-41, z: bits 42-62 
 * @param num_args            an array that contains the number of declare variables, 
 *                            firstprivates and mapped variables
 * @param args                the addresses of all target data and target
 *                            declare variables
 *
 * NOTE: `teamdims' and `thrdims' can be decoded using the _ull_decode3 function.
 *
 * NOTES: 
 * 1. `teamdims' and `thrdims' can be decoded using the _ull_decode3 function.
 * 2. In MPI module, decldata variables are dealt with in dev_alloc
 * function, not here.
 */
void rdev_offload(void *device_info, void *(*host_func)(void *), void *devdata,
                  void *decldata, char *kernel_fname, int num_teams,
                  int num_threads, int thread_limit, 
                  unsigned long long teamdims, unsigned long long thrdims, 
                  int *num_args, void **args)
{
	int kernel_id, exec_result;
	rdev_datatype_t devdata_maddr = COMM_DUMMY_ARG, decldata_maddr = COMM_DUMMY_ARG;
	size_t devdata_len = 0, decldata_len = 0;
	rdevinfo_t *dev = (rdevinfo_t*) device_info;
	
	kernel_id = ort_kernfunc_findbyname(kernel_fname);
	if (kernel_id < 0)
		exit_error(1, "could not locate kernel function %s !?\n", kernel_fname);

	RDEV_LOCK;

	/* devdata handling */
	if (devdata)
	{
		devdata_len = ort_mapped_get_size(devdata);
		devdata_maddr = rdev_alloctab_register(dev->alloc_table, MAPPED_DEVDATA);
	}

	/* #declare data handling */
	if (decldata)
	{
		decldata_len = ort_mapped_get_size(decldata);
		decldata_maddr = rdev_alloctab_register(dev->alloc_table, MAPPED_DECLDATA);
	}
	
	/* Offload */
	send_to_node(dev->node_id, dev->device_id_in_node, CMD(OFFLOAD), kernel_id, devdata_maddr, devdata_len,
	             decldata_maddr, decldata_len, COMM_LAST_ARG);
		             
	/* Send the devdata struct */
	if (devdata)
		_Comm_Send_bytes(devdata, devdata_len, dev->node_id, dev->device_id_in_node);
		
	/* Non-CPU devices do not call the kernel function directly,
	 * thus we need to transfer all offloading arguments; some of them
	 * can be sent using send_to_node call, the others are sent manually.
	 */
	if (!is_cpu_device(dev))
	{
		/* Then send the decldata struct itself */
		if (decldata)
			_Comm_Send_bytes(decldata, decldata_len, dev->node_id, dev->device_id_in_node);
			
		/* Send all offload arguments */
		send_offload_arguments(dev->node_id, dev->device_id_in_node, kernel_fname, num_teams, 
		                       num_threads, thread_limit, teamdims, 
		                       thrdims, num_args, args);
	}
		
	/* Wait till the device finishes */
	_Comm_Recv_1int(&exec_result, dev->node_id, dev->device_id_in_node);
	RDEV_UNLOCK;

	if (exec_result != OFFLOAD_OK)
		exit_error(OFFLOAD_OK, "kernel execution failed at the end %d\n", kernel_id);
	
	if (devdata)
	{
		RDEV_LOCK;
		rdev_alloctab_unregister(dev->alloc_table, devdata_maddr, MAPPED_DEVDATA);
		RDEV_UNLOCK;
	}
	
	if (decldata)
	{
		RDEV_LOCK;
		rdev_alloctab_unregister(dev->alloc_table, decldata_maddr, MAPPED_DECLDATA);
		RDEV_UNLOCK;
	}
}


/**
 * Allocates memory on the device
 *
 * @param device_info the device
 * @param size        the number of bytes to allocate
 * @param map_memory  used in OpenCL, when set to 1 additionaly to the memory
 *                    allocation in shared virtual address space, the memory
 *                    is mapped with read/write permissions so the host cpu
 *                    can utilize it.
 * @param hostaddr    used in MPI to allocate #declare target link variables;
 *                    when we encounter such variables instead of
 *                    allocating new space, we should return the mediary
 *                    address of the original address.
 * @return hostaddr a pointer to the allocated space (mediary address)
 */
void *rdev_alloc(void *device_info, size_t size, int map_memory, void *hostaddr)
{
	rdev_datatype_t maddr;
	rdevinfo_t *dev = (rdevinfo_t*) device_info;
	
	if (map_memory) /* Only needed for devdata/decldata structs */
		return ort_mapped_alloc(size);
	
	RDEV_LOCK;
	maddr = rdev_alloctab_register(dev->alloc_table, NOT_MAPPED);
	send_to_node(dev->node_id, dev->device_id_in_node, CMD(DEVALLOC), maddr, size, 
	             map_memory, COMM_LAST_ARG);
	RDEV_UNLOCK;
	
	return (void *) maddr;
}


/**
 * Allocates memory on the device for a global variable
 *
 * @param device_info the device
 * @param initfrom    host address whose value is forwarded to the worker
 *                    node together with the DEVINIT_ALLOC_GLOBAL command
 *                    (presumably the initial value source -- TODO confirm)
 * @param size        the number of bytes to allocate
 * @param global_id   the ID of the global variable
 * @param hostaddr    host address of the variable; recorded in the device's
 *                    allocation table together with global_id
 * @return the mediary address of the variable, which is global_id itself
 */
void *rdev_init_alloc_global(void *device_info, void *initfrom, size_t size, int global_id,
                             void *hostaddr)
{
	rdev_datatype_t maddr = (rdev_datatype_t) global_id;
	rdevinfo_t *dev = (rdevinfo_t*) device_info;

	RDEV_LOCK;
	rdev_alloctab_add_global(dev->alloc_table, global_id, hostaddr);
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(DEVINIT_ALLOC_GLOBAL)),
	             maddr, (rdev_datatype_t) initfrom, (rdev_datatype_t) size,
	             (rdev_datatype_t) global_id, (rdev_datatype_t) (COMM_LAST_ARG));
	RDEV_UNLOCK;

	return (void *) maddr;
}


/**
 * Frees data allocated with rdev_alloc
 *
 * @param device_info  the device
 * @param maddr        mediary address of the memory that will be released
 * @param unmap_memory when nonzero, maddr is a locally allocated
 *                     devdata/decldata struct and is released with
 *                     ort_mapped_free(); nothing is sent to the device
 */
void rdev_free(void *device_info, void *maddr, int unmap_memory)
{
	rdevinfo_t *dev = (rdevinfo_t*) device_info;

	if (unmap_memory)
	{
		ort_mapped_free(maddr);
		return;
	}

	RDEV_LOCK;
	rdev_alloctab_unregister(dev->alloc_table, (rdev_datatype_t)maddr, NOT_MAPPED);
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(DEVFREE)),
	             (rdev_datatype_t) maddr, (rdev_datatype_t) unmap_memory,
	             (rdev_datatype_t) (COMM_LAST_ARG));
	RDEV_UNLOCK;
}


/**
 * Frees data allocated with rdev_init_alloc_global
 *
 * @param device_info  the device
 * @param maddr        mediary address of the memory that will be released
 * @param global_id    the ID of the global variable (not used here)
 */
void rdev_free_global(void *device_info, void *maddr, int global_id)
{
	/* rdev_free()'s third argument is a boolean "unmap_memory" flag, not an
	 * allocation-table kind such as NOT_MAPPED. Globals were allocated on the
	 * device, so pass 0: rdev_free() then unregisters maddr and sends DEVFREE,
	 * instead of calling ort_mapped_free() on a device address. */
	rdev_free(device_info, maddr, 0);
}


/**
 * Transfers data from a device to the host
 *
 * @param device_info the source device
 * @param hostaddr    the target memory
 * @param hostoffset  offset (in bytes) from hostaddr
 * @param maddr       the source memory mediary address
 * @param devoffset   offset from maddr, forwarded to the worker node
 * @param size        the size of the memory block in bytes
 */
void rdev_fromdev(void *device_info, void *hostaddr, size_t hostoffset,
                  void *maddr, size_t devoffset, size_t size)
{
	rdevinfo_t *dev = (rdevinfo_t*) device_info;

	RDEV_LOCK;
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(FROMDEV)),
	             (rdev_datatype_t) maddr, (rdev_datatype_t) devoffset, (rdev_datatype_t) size,
	             (rdev_datatype_t) (COMM_LAST_ARG));
	/* Byte-wise offset via char *: arithmetic on void * is not ISO C */
	_Comm_Recv_bytes((char *) hostaddr + hostoffset, size, dev->node_id, dev->device_id_in_node);
	RDEV_UNLOCK;
}


/**
 * Transfers data from the host to a device
 *
 * @param device_info the device
 * @param hostaddr    the source memory
 * @param hostoffset  offset (in bytes) from hostaddr
 * @param maddr       the target memory mediary address
 * @param devoffset   offset from maddr, forwarded to the worker node
 * @param size        the size of the memory block in bytes
 */
void rdev_todev(void *device_info, void *hostaddr, size_t hostoffset,
                void *maddr, size_t devoffset, size_t size)
{
	rdevinfo_t *dev = (rdevinfo_t*) device_info;

	RDEV_LOCK;
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(TODEV)),
	             (rdev_datatype_t) maddr, (rdev_datatype_t) devoffset, (rdev_datatype_t) size,
	             (rdev_datatype_t) (COMM_LAST_ARG));
	/* Byte-wise offset via char *: arithmetic on void * is not ISO C */
	_Comm_Send_bytes((char *) hostaddr + hostoffset, size, dev->node_id, dev->device_id_in_node);
	RDEV_UNLOCK;
}


/**
 * Given an internal mediary address, it returns a usable mediary address
 *
 * The translation is performed by the worker node: the address is sent with
 * the I2UMEDADDR command and the answer is received back.
 *
 * @param device_info the device
 * @param imedaddr    allocated memory from rdev_alloc
 *
 * @return usable mediary address to pass to a kernel
 */
void *rdev_imed2umed_addr(void *device_info, void *imedaddr)
{
	rdevinfo_t *dev = (rdevinfo_t*) device_info;
	rdev_datatype_t umedaddr;

	RDEV_LOCK;
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(I2UMEDADDR)),
	             (rdev_datatype_t) imedaddr, (rdev_datatype_t) (COMM_LAST_ARG));
	_Comm_Recv_1ull(&umedaddr, dev->node_id, dev->device_id_in_node);
	RDEV_UNLOCK;

	return (void*) umedaddr; 
}


/**
 * Given a usable mediary address, it returns the internal mediary address
 *
 * The translation is performed by the worker node: the address is sent with
 * the U2IMEDADDR command and the answer is received back.
 *
 * @param device_info the device
 * @param umedaddr    allocated memory from rdev_alloc
 *
 * @return internal mediary address to be used by ORT
 */
void *rdev_umed2imed_addr(void *device_info, void *umedaddr)
{
	rdevinfo_t *dev = (rdevinfo_t*) device_info;
	rdev_datatype_t imedaddr;

	RDEV_LOCK;
	/* send_to_node() reads every variadic value as rdev_datatype_t */
	send_to_node(dev->node_id, dev->device_id_in_node, (rdev_datatype_t) (CMD(U2IMEDADDR)),
	             (rdev_datatype_t) umedaddr, (rdev_datatype_t) (COMM_LAST_ARG));
	_Comm_Recv_1ull(&imedaddr, dev->node_id, dev->device_id_in_node);
	RDEV_UNLOCK;

	return (void*) imedaddr; 
}

#undef DEBUG_PRIMARY_NODE
#endif /* OMPI_REMOTE_OFFLOADING */
