/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* remotedev/workercmds.c
 * 
 * All worker (remote; with rank != 0) MPI processes execute
 * these functions. Their purpose is to respond to the
 * primary node requests. Note that the primary and the worker nodes do
 * **NOT** run in a shared memory environment. Communication is
 * achieved with MPI.
 */

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_RDEV_WORKERS

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include "ort_prive.h"
#include "../../rt_common.h"

#ifdef OMPI_REMOTE_OFFLOADING
#include "remotedev/workercmds.h"
#include "remotedev/workers.h"
#include "remotedev/memory.h"
#include "remotedev/node_manager.h"
#include "remotedev/rdev_prive.h"
#include "rdev_config.h"
#include "targdenv.h"

/* Lists the worker command handlers implemented in this file.
 * DEFINE_WORKER_COMMANDS is not defined here (presumably it comes from
 * remotedev/workercmds.h).
 * NOTE(review): the handlers are named here before they are defined, so the
 * macro presumably also declares them; the position of each name in the list
 * probably has to match the command codes used by the primary node -- confirm
 * before reordering, inserting or removing entries.
 */
DEFINE_WORKER_COMMANDS(
	init,    
	shutdown, 
	fromdev, 
	todev,    
	devalloc, 
	devinit_alloc_global,
	devfree, 
	devfree_global,
	offload, 
	i2umedaddr, 
	u2imedaddr
);

/* Primary node performs INIT (node_id)
 *
 * Calls rdev_alloctab_init() with the number of devices known to ort and
 * then rdev_alloctab_init_global_vars(<device>, 0) once per device.
 * Neither `tag' nor `data' is used by this handler.
 */
static void init(int tag, rdev_datatype_t *data)
{
	int devid = 0;

	rdev_alloctab_init(ort->num_devices);

	while (devid < ort->num_devices)
	{
		rdev_alloctab_init_global_vars(devid, 0);
		devid++;
	}
}


/* Primary node performs SHUTDOWN
 *
 * Releases the allocation tables and the kernel function table, finalizes
 * the communication layer and terminates the worker process; this handler
 * never returns. Neither `tag' nor `data' is used.
 */
static void shutdown(int tag, rdev_datatype_t *data)
{
	DBGPRN((stderr, "[remotedev worker %d] >>> SHUTDOWN cmd; exiting...\n",
	             getpid()));

	/* Tear down worker-side state first... */
	rdev_alloctab_free_all();
	ort_kernfunc_cleanup();

	/* ...then leave the communication layer and quit successfully */
	Comm_Finalize(comminfo);
	exit(EXIT_SUCCESS);
}


/* Primary node performs FROMDEV (medaddress, offset, size)
 *
 * Sends `size' bytes, starting `offset' bytes into the device memory that
 * the allocation table maps to the given mediary address, to the primary
 * node.
 *   data[1]: mediary address
 *   data[2]: offset (in bytes)
 *   data[3]: number of bytes to send
 */
static void fromdev(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1];
	int              offset = data[2], nbytes = data[3];
	void            *devaddr, *buf;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);

	DBGPRN((stderr, "[remotedev worker %d] Device %d: GET cmd from %p (offset:%d, size:%d)\n", 
	             getpid(), device_id, (void*)maddr, offset, nbytes));

	devaddr = rdev_alloctab_get(table, maddr);
	if (d->is_cpudev)
	{
		/* CPU device; send the contents of devaddr directly. No staging
		 * buffer is needed here, so none is allocated (it used to be
		 * allocated unconditionally and leaked on this path). */
		_Comm_Send_bytes((char *) devaddr + offset, nbytes, PRIMARY_NODE, device_id);
		return;
	}

	/* Not a CPU device; store the device data to a buffer
	 * and send it to the primary node. */
	buf = smalloc(nbytes);
	DEVICE_READ(d, buf, 0, devaddr, offset, nbytes);
	_Comm_Send_bytes(buf, nbytes, PRIMARY_NODE, device_id);
	free(buf);
}


/* Primary node performs TODEV (medaddress, offset, size)
 *
 * Receives `size' bytes from the primary node and stores them `offset'
 * bytes into the device memory that the allocation table maps to the given
 * mediary address.
 *   data[1]: mediary address
 *   data[2]: offset (in bytes)
 *   data[3]: number of bytes to receive
 */
static void todev(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1];
	int              offset = data[2], nbytes = data[3];
	void            *devaddr, *buf;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);

	devaddr = rdev_alloctab_get(table, maddr);

	/* maddr is an integer; cast it so that it matches the %p conversion */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: PUT cmd to %p, devaddr=%p (offset:%d, size:%d)\n", 
	        getpid(), device_id, (void*)maddr, devaddr, offset, nbytes));

	if (d->is_cpudev)
	{
		/* CPU device; write to devaddr directly. No staging buffer is
		 * needed here, so none is allocated (it used to be allocated
		 * unconditionally and leaked on this path). */
		_Comm_Recv_bytes((char *) devaddr + offset, nbytes, PRIMARY_NODE, device_id);
		return;
	}

	/* Not a CPU device; store the received data to a buffer 
	 * and then pass it to the device */
	buf = smalloc(nbytes);
	_Comm_Recv_bytes(buf, nbytes, PRIMARY_NODE, device_id);
	DEVICE_WRITE(d, buf, 0, devaddr, offset, nbytes);
	free(buf);
}


/* Primary node performs DEVALLOC (medaddress, size, map_memory)
 *
 * Adds an entry of `size' bytes to the device's allocation table under the
 * given mediary address.
 *   data[1]: mediary address
 *   data[2]: number of bytes
 *   data[3]: map_memory value, passed through to rdev_alloctab_add()
 */
static void devalloc(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1];
	int              nbytes = data[2], 
	                 map_memory = data[3];
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);

	rdev_alloctab_add(table, maddr, nbytes, map_memory);

	/* maddr is an integer; cast it so that it matches the %p conversion */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: ALLOC cmd for %d bytes --> maddr: %p\n", 
	                 getpid(), device_id, nbytes, (void*)maddr));
}


/* Primary node performs DEVINIT_ALLOC_GLOBAL (medaddress, initfrom, size, glid)
 *
 * Registers a global variable in the device's allocation table under the
 * given mediary address. For a CPU device the host address of the global
 * is registered as is; otherwise the device module allocates/initializes
 * the global and the address it returns is registered.
 *   data[1]: mediary address
 *   data[2]: address to initialize from (handed to the device module)
 *   data[3]: number of bytes
 *   data[4]: id of the global variable
 */
static void devinit_alloc_global(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1], 
	                 initfrom = data[2];
	int              nbytes = data[3], 
	                 glid  = data[4];
	tditem_t         item;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);
	void            *devaddr;

	item = tdenv_global_get_by_id(device_id, glid); 

	if (d->is_cpudev)
	{
		rdev_alloctab_add_global(table, maddr, item->hostaddr);
		return;
	}

	devaddr = DEVICE_INIT_ALLOC_GLOBAL(d, (void*)initfrom, nbytes, glid, item->hostaddr);

	/* Report the address that is actually registered (devaddr); initfrom is
	 * an integer, so cast it to match the %p conversion. */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: ALLOC GLOBAL cmd for %d bytes, --> devaddr: %p, initfrom %p (glid = %d)\n", 
	             getpid(), device_id, nbytes, devaddr, (void*)initfrom, glid));
	rdev_alloctab_add_global(table, maddr, devaddr);
}


/* Primary node performs DEVFREE (medaddress, unmap_memory)
 *
 * Removes the entry with the given mediary address from the device's
 * allocation table.
 *   data[1]: mediary address
 *   data[2]: unmap_memory value, passed through to rdev_alloctab_remove()
 */
static void devfree(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1];
	int              unmap_memory = data[2];
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);

	/* maddr is printed as a number; use a conversion that matches its width
	 * (the lookup must happen before the entry is removed) */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: FREE cmd for %llu at %p\n", getpid(), 
	                device_id, (unsigned long long) maddr, rdev_alloctab_get(table, maddr)));
	rdev_alloctab_remove(table, maddr, unmap_memory);
}


/* Primary node performs DEVFREE_GLOBAL (medaddress)
 *
 * Removes the entry of a global variable from the device's allocation
 * table. Nothing is done for CPU devices.
 *   data[1]: mediary address
 */
static void devfree_global(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  maddr = data[1];
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);

	if (d->is_cpudev)
		return;

	/* maddr is an integer; cast it so that it matches the %p conversion */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: FREE GLOBAL cmd for %p\n", 
	                getpid(), device_id, (void*)maddr));
	rdev_alloctab_remove(table, maddr, NOT_MAPPED);
}


/* Receives a devdata or decldata struct from the primary node and stores it.
 *
 * If the size is non-zero, an entry of that size is added to the device's
 * allocation table under the given mediary address (with the given map
 * type) and the struct contents are received from the primary node into it.
 *
 * Returns the address returned by rdev_alloctab_add(), or NULL when the
 * size is zero (nothing is allocated or received in that case).
 */
static 
void *recv_datastruct(int device_id, rdev_datatype_t devdecldata_maddr, size_t devdecldata_size,
                      int map_type)
{
	void *devdecldata;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);

	if (devdecldata_size == 0)
		return NULL;

	devdecldata = rdev_alloctab_add(table, devdecldata_maddr, devdecldata_size, 
	                                map_type);

	/* Memory copy */
	_Comm_Recv_bytes(devdecldata, devdecldata_size, PRIMARY_NODE, device_id); 

	/* The mediary address and the size are wider than int; print them with
	 * conversions that match their types. */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: OFFLOAD cmd: alloced devdata/decldata "
	                "(i:%llu, u:%p, size:%zu)\n", 
	                getpid(), device_id, (unsigned long long) devdecldata_maddr, 
	                devdecldata, devdecldata_size));

	return devdecldata;
}


/* Receives the remaining offload arguments from the primary node and
 * launches the kernel on a non-CPU device.
 *
 * offload_args holds, in order: #teams, #threads, thread limit, teamdims,
 * thrdims. The following are then received from the primary node, in this
 * order: the argument counts (NUM_ARGS_SIZE ints), the length of the kernel
 * name, the kernel name itself and, if there are any, the kernel arguments
 * (one pointer-sized slot per argument).
 *
 * NOTE(review): num_args has room for 3 ints while NUM_ARGS_SIZE ints are
 * received -- confirm that NUM_ARGS_SIZE never exceeds 3.
 */
static 
void generic_offload(ort_device_t *d, rdev_datatype_t *offload_args, ttkfunc_t kernel_function, 
                     void *devdata, void *decldata)
{
	int    kernel_fname_len, num_args[3], args_size,
	       num_teams = offload_args[0], 
	       num_threads = offload_args[1], 
	       thread_limit = offload_args[2];
	u_long teamdims = offload_args[3], 
	       thrdims = offload_args[4];
	/* Zero-filled so that the name is terminated even if the received
	 * length does not cover a terminating NUL */
	char   kernel_fname[256] = { 0 };
	void **args;
	void  *no_args[] = { NULL };    /* used when the kernel has no arguments */

	/* (1) Receive remaining offload arguments */
	Comm_Recv(comminfo, PRIMARY_NODE, COMM_INT, num_args, NUM_ARGS_SIZE, d->id, NULL);
	_Comm_Recv_1int(&kernel_fname_len, PRIMARY_NODE, d->id);

	/* The length comes from the wire; receiving more than the buffer can
	 * hold would overwrite the stack, so refuse to continue. */
	if (kernel_fname_len < 0 || kernel_fname_len > (int) sizeof(kernel_fname))
	{
		fprintf(stderr, "[remotedev worker %d] Device %d: invalid kernel name length (%d); "
		                "exiting.\n", getpid(), d->id, kernel_fname_len);
		exit(EXIT_FAILURE);
	}
	Comm_Recv(comminfo, PRIMARY_NODE, COMM_CHAR, kernel_fname, kernel_fname_len, d->id, NULL);

	args_size = (num_args[0] + num_args[1] + num_args[2]) * sizeof(void*);

	if (args_size > 0)
	{
		args = smalloc(args_size);
		_Comm_Recv_bytes(args, args_size, PRIMARY_NODE, d->id);
	}
	else
		args = no_args;

	/* (2) Offload */
	DEVICE_OFFLOAD(d, kernel_function, devdata, decldata, kernel_fname, num_teams, 
	               num_threads, thread_limit, teamdims, thrdims, num_args, args); 

	if (args != no_args)
		free(args);
}


/* Primary node performs OFFLOAD
 *
 * Runs a kernel on the given device and reports back to the primary node.
 *   data[1]: kernel id (index into the kernel function table)
 *   data[2]: mediary address for devdata
 *   data[3]: size of devdata (0 if there is none)
 *   data[4]: mediary address for decldata
 *   data[5]: size of decldata (0 if there is none)
 *
 * The order of the receive/send operations below is part of the protocol
 * with the primary node and must not be changed.
 */
static void offload(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  offload_args[COMM_MAX_NUM_ARGS],
	                 devdata_maddr = data[2],
	                 decldata_maddr = data[4];
	int              kernel_id = data[1], 
	                 exec_result;
	size_t           devdata_size = data[3], 
	                 decldata_size = data[5];
	ort_device_t    *d = __DEVICE(device_id);
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	void            *devdata = NULL, *decldata = NULL;

	/* (1) Retrieve kernel function from ttk table */
	ttkfunc_t     kernel_function = ort_kernfunc_getptr(kernel_id);

	/* (2) Receive devdata from the primary node */
	devdata = recv_datastruct(device_id, devdata_maddr, devdata_size, MAPPED_DEVDATA);

	/* devdata_size is a size_t; print it with the matching conversion */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: OFFLOAD cmd: start kernel (id=%d) %p %zu\n",
	                 getpid(), device_id, kernel_id, devdata, devdata_size));

	/* (3) Receive the offload arguments from the primary node and perform an offload */
	if (d->is_cpudev) 
		(*kernel_function)(devdata);   /* CPU device: call the kernel directly */
	else
	{
		/* Receive decldata from the primary node */
		decldata = recv_datastruct(device_id, decldata_maddr, decldata_size, MAPPED_DECLDATA);

		/* #teams, #threads, thread limit, teamdims, thrdims */
		Comm_Recv(comminfo, PRIMARY_NODE, COMM_UNSIGNED_LONG_LONG, offload_args, COMM_MAX_NUM_ARGS, 
		          device_id, NULL);

		/* remaining args + offloading */
		generic_offload(d, offload_args, kernel_function, devdata, decldata);
	}

	DBGPRN((stderr, "[remotedev worker %d] Device %d: OFFLOAD cmd: finish kernel (id=%d)\n",
	                getpid(), device_id, kernel_id));

	/* (4) Remove devdata from memory */
	if (devdata)
		rdev_alloctab_remove(table, devdata_maddr, MAPPED_DEVDATA);

	/* (5) Remove decldata from memory */
	if (decldata)
		rdev_alloctab_remove(table, decldata_maddr, MAPPED_DECLDATA);

	/* (6) Notify the primary node */
	exec_result = OFFLOAD_OK; /* all OK */
	_Comm_Send_1int(&exec_result, PRIMARY_NODE, device_id);
}


/* Primary node performs I2UMEDADDR (imedaddress)
 *
 * Looks up the address that the allocation table holds for the given
 * internal mediary address, lets the device module convert it with
 * DEVICE_I2UMEDADDR and sends the result to the primary node.
 *   data[1]: internal mediary address
 */
static void i2umedaddr(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  imedaddr = data[1], umedaddr;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);

	void *medaddr = rdev_alloctab_get(table, imedaddr);

	/* Convert and send to primary node */
	umedaddr = (rdev_datatype_t) DEVICE_I2UMEDADDR(d, medaddr);

	/* imedaddr/umedaddr are integers; cast them to match the %p conversions */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: >>> I2UMEDADDR maddr=%p, imedaddr = %p <-> umedaddr = %p\n",
	                getpid(), device_id, medaddr, (void*)imedaddr, (void*)umedaddr));
	_Comm_Send_1ull(&umedaddr, PRIMARY_NODE, device_id);
}


/* Primary node performs U2IMEDADDR (umedaddress)
 *
 * Lets the device module convert the given address with DEVICE_U2IMEDADDR,
 * looks up the corresponding item in the allocation table with
 * rdev_alloctab_get_item() and sends the result to the primary node.
 *   data[1]: the address to convert
 */
static void u2imedaddr(int device_id, rdev_datatype_t *data)
{
	rdev_datatype_t  umedaddr = data[1], imedaddr;
	void            *medaddr;
	rdev_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *d = __DEVICE(device_id);

	/* Convert and send to primary node */
	medaddr = DEVICE_U2IMEDADDR(d, (void*) umedaddr);
	imedaddr = rdev_alloctab_get_item(table, medaddr);

	/* umedaddr/imedaddr are integers; cast them to match the %p conversions */
	DBGPRN((stderr, "[remotedev worker %d] Device %d: >>> U2IMEDADDR maddr=%p, umedaddr = %p <-> imedaddr = %p\n",
	                getpid(), device_id, medaddr, (void*)umedaddr, (void*)imedaddr));

	_Comm_Send_1ull(&imedaddr, PRIMARY_NODE, device_id);
}

#endif /* OMPI_REMOTE_OFFLOADING */
