/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* remote/workercmds.c
 * 
 * All worker (remote; with rank != 0) MPI processes execute
 * these functions. Their purpose is to respond to the
 * primary node requests. Note that the primary and the worker nodes do
 * **NOT** run in a shared memory environment. Communication is
 * achieved with MPI.
 */

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_ROFF_WORKERS

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include "ort_prive.h"
#include "rt_common.h"
#include "assorted.h"

#ifdef OMPI_REMOTE_OFFLOADING
#include "remote/workercmds.h"
#include "remote/workers.h"
#include "remote/memory.h"
#include "remote/node_manager.h"
#include "remote/roff_prive.h"
#include "roff_config.h"
#include "targdenv.h"

/* Declare the prototypes of all worker command handlers.
 * WORKER_PROTO_LIST is a macro defined outside this file (presumably in
 * remote/workercmds.h -- TODO confirm); the handlers themselves are the 
 * static functions defined below.
 */
WORKER_PROTO_LIST

/* Table of worker command handlers.
 * Every entry is a pointer to a function that returns nothing and takes 
 * an int (named `tag' or `device_id' by the handlers in this file) plus 
 * a pointer to the roff_datatype_t array carrying the command arguments.
 * The entries come from the WORKER_COMMAND_ELEMS macro (defined elsewhere)
 * and the table is terminated by a NULL pointer.
 */
void (*worker_commands[])(int, roff_datatype_t*) = {
	WORKER_COMMAND_ELEMS
	NULL
};

/* INITIALIZE command handler (requested by the primary node).
 * Calls roff_alloctab_init() with the number of devices known to ort and
 * then roff_alloctab_init_global_vars() once for every device index in
 * [0, ort->num_devices). Neither `tag' nor `data' is used.
 */
static void initialize(int tag, roff_datatype_t *data)
{
	int devid = 0;

	DBGPRN((stderr, "[remote worker %d] >>> INITIALIZE cmd\n", getpid()));

	roff_alloctab_init(ort->num_devices);

	while (devid < ort->num_devices)
	{
		roff_alloctab_init_global_vars(devid);
		devid++;
	}
}


/* SHUTDOWN command handler (requested by the primary node).
 * Calls roff_alloctab_free_all() and ort_kernfunc_cleanup(), finalizes
 * the communication layer (only when ROFF_WORKER_MULTITHREADING is not
 * defined) and finally terminates the process with exit status 0.
 * This function does not return; neither `tag' nor `data' is used.
 */
static void shutdown(int tag, roff_datatype_t *data)
{
	DBGPRN((stderr, "[remote worker %d] >>> SHUTDOWN cmd; exiting...\n", getpid()));

	/* Cleanup, in this order: allocation tables, then kernel functions */
	roff_alloctab_free_all();
	ort_kernfunc_cleanup();

#if !defined(ROFF_WORKER_MULTITHREADING)
	Comm_Finalize(comminfo);
#endif

	exit(0);
}


/* FINALIZE command handler (requested by the primary node).
 * device_id: index of the device to finalize; also used as the message tag
 *            of the reply.
 * data:      unused.
 * Non-CPU devices are finalized through DEVICE_FINALIZE(); in every case
 * FINALIZE_OK is sent back to the primary node. When
 * ROFF_WORKER_MULTITHREADING is defined, the process then exits with
 * status 0.
 */
static void finalize(int device_id, roff_datatype_t *data)
{
	ort_device_t *dev = __DEVICE(device_id);
	int           finalize_result = FINALIZE_OK;

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: FINALIZE cmd\n",
	                getpid(), device_id));

	/* (1) CPU devices need no finalization; all others do */
	if (!dev->is_cpudev)
	{
		DEVICE_FINALIZE(dev);
	}

	/* (2) The primary node gets a reply regardless of the device type */
	_Comm_Send_1int(&finalize_result, PRIMARY_NODE, device_id);

#if defined(ROFF_WORKER_MULTITHREADING)
	exit(0);
#endif
}


/* Primary node performs FROMDEV (medaddr, offset, size)
 *
 * device_id: index of the source device; also used as the message tag.
 * data:      command arguments; data[1] is the mediary address, data[2]
 *            the byte offset and data[3] the number of bytes to send.
 *
 * Looks up the address stored in the device's allocation table for medaddr
 * and sends nbytes bytes, starting at the given offset, to the primary node.
 */
static void fromdev(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  medaddr = data[1];
	int              offset = data[2], nbytes = data[3];
	void            *devaddr, *buf;
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);
	
	DBGPRN((stderr, "[remote worker %d] >>> Device %d: FROMDEV cmd (medaddr: %p, offset: %d, size: %d)\n", 
	                getpid(), device_id, (void*)medaddr, offset, nbytes));
	                 
	devaddr = roff_alloctab_get(table, medaddr);

	if (dev->is_cpudev)
	{
		/* CPU device; send the contents of devaddr directly. 
		 * No staging buffer is allocated on this path, so there is
		 * nothing to free. The cast gives byte-wise pointer arithmetic
		 * without relying on arithmetic on void pointers.
		 */
		_Comm_Send_bytes((char *) devaddr + offset, nbytes, PRIMARY_NODE, device_id);
		return;
	}

	/* Not a CPU device; read the device data into a temporary buffer
	 * and send that buffer to the primary node.
	 */
	buf = smalloc(nbytes);
	DEVICE_READ(dev, buf, 0, devaddr, offset, nbytes);
	_Comm_Send_bytes(buf, nbytes, PRIMARY_NODE, device_id);
	free(buf);
}


/* Primary node performs TODEV (medaddr, offset, size)
 *
 * device_id: index of the target device; also used as the message tag.
 * data:      command arguments; data[1] is the mediary address, data[2]
 *            the byte offset and data[3] the number of bytes to receive.
 *
 * Looks up the address stored in the device's allocation table for medaddr
 * and stores there, starting at the given offset, the nbytes bytes received
 * from the primary node.
 */
static void todev(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  medaddr = data[1];
	int              offset = data[2], nbytes = data[3];
	void            *devaddr, *buf;
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);
	
	devaddr = roff_alloctab_get(table, medaddr);

	/* medaddr is an integer; cast it for %p, exactly as fromdev() does */
	DBGPRN((stderr, "[remote worker %d] >>> Device %d: TODEV cmd"
	                "(medaddr: %p, devaddr: %p, offset: %d, size: %d)\n", 
	                getpid(), device_id, (void*)medaddr, devaddr, offset, nbytes));
	
	if (dev->is_cpudev)
	{
		/* CPU device; receive straight into devaddr. 
		 * No staging buffer is allocated on this path, so there is
		 * nothing to free. The cast gives byte-wise pointer arithmetic
		 * without relying on arithmetic on void pointers.
		 */
		_Comm_Recv_bytes((char *) devaddr + offset, nbytes, PRIMARY_NODE, device_id);
		return;
	}

	/* Not a CPU device; store the received data to a temporary buffer 
	 * and then pass it to the device.
	 */
	buf = smalloc(nbytes);
	_Comm_Recv_bytes(buf, nbytes, PRIMARY_NODE, device_id);
	DEVICE_WRITE(dev, buf, 0, devaddr, offset, nbytes);
	free(buf);
}


/* DEV_ALLOC command handler (medaddr, size, map_memory).
 * device_id: index of the device whose allocation table is used.
 * data:      command arguments; data[1] is the mediary address, data[2]
 *            the size in bytes and data[3] the map_memory value.
 * The three arguments are handed, unchanged, to roff_alloctab_add().
 */
static void dev_alloc(int device_id, roff_datatype_t *data)
{
	roff_alloctab_t *alloctab;
	roff_datatype_t  medaddr = data[1];
	int              nbytes = data[2];
	int              map_memory = data[3];

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: DEV_ALLOC cmd (medaddr: %p, size: %d)\n", 
	                 getpid(), device_id, medaddr, nbytes));	

	alloctab = __ALLOCTABLE(device_id);
	roff_alloctab_add(alloctab, medaddr, nbytes, map_memory);
}


/* DEV_INIT_ALLOC_GLOBAL command handler (medaddr, initfrom, size, glid).
 * device_id: index of the device involved.
 * data:      command arguments; data[1] is the mediary address, data[2]
 *            the initfrom value, data[3] the size in bytes and data[4] 
 *            the id of the global item.
 * The global item is looked up through tdenv_global_get_by_id(). For a CPU
 * device its host address is registered under medaddr as is; for any other
 * device the address returned by DEVICE_INIT_ALLOC_GLOBAL() is registered
 * instead.
 */
static void dev_init_alloc_global(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  medaddr = data[1];
	roff_datatype_t  initfrom = data[2];
	int              nbytes = data[3];
	int              glid = data[4];
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);
	tditem_t         item = tdenv_global_get_by_id(device_id, glid);
	void            *globaddr;

	if (!dev->is_cpudev)
	{
		/* Let the device allocate and initialize its own copy */
		globaddr = DEVICE_INIT_ALLOC_GLOBAL(dev, (void*)initfrom, nbytes, glid, item->hostaddr);
		DBGPRN((stderr, "[remote worker %d] >>> Device %d: DEV_INIT_ALLOC_GLOBAL cmd"
		                "(medaddr: %p, initfrom: %p, size: %d, glid: %d)\n", 
		                getpid(), device_id, item->hostaddr, initfrom, nbytes, glid));
	}
	else
		globaddr = item->hostaddr;   /* CPU device: use the host address */

	roff_alloctab_add_global(table, medaddr, globaddr);
}


/* DEV_FREE command handler (medaddr, unmap_memory).
 * device_id: index of the device whose allocation table is used.
 * data:      command arguments; data[1] is the mediary address and 
 *            data[2] the unmap_memory value.
 * Both arguments are forwarded to roff_alloctab_remove().
 */
static void dev_free(int device_id, roff_datatype_t *data)
{
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	roff_datatype_t  medaddr = data[1];
	int              unmap_memory = data[2];

	/* The table lookup below takes place only as part of the debug print */
	DBGPRN((stderr, "[remote worker %d] >>> Device %d: DEV_FREE cmd (medaddr: %ld, addr: %p)\n", 
	                getpid(), device_id, medaddr, roff_alloctab_get(table, medaddr)));

	roff_alloctab_remove(table, medaddr, unmap_memory);
}


/* DEV_FREE_GLOBAL command handler (medaddr).
 * device_id: index of the device whose allocation table is used.
 * data:      command arguments; data[1] is the mediary address.
 * Nothing is done for CPU devices. For any other device the entry of
 * medaddr is removed from the allocation table, with NOT_MAPPED as the
 * last argument of roff_alloctab_remove().
 */
static void dev_free_global(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  medaddr = data[1];
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);

	if (!dev->is_cpudev)
	{
		DBGPRN((stderr, "[remote worker %d] >>> Device %d: DEV_FREE_GLOBAL cmd (medaddr: %p)\n", 
		                getpid(), device_id, medaddr));
		roff_alloctab_remove(table, medaddr, NOT_MAPPED);
	}
}


/* Receives a devdata/decldata struct from the primary node and stores it.
 * device_id:           index of the device; also used as the message tag.
 * devdecldata_medaddr: mediary address under which the struct is added to 
 *                      the device's allocation table.
 * devdecldata_size:    size of the struct in bytes.
 * map_type:            passed as the last argument of roff_alloctab_add().
 * Returns the address returned by roff_alloctab_add(), after it has been 
 * filled with the received bytes, or NULL when devdecldata_size is 0 (in
 * which case nothing is added and nothing is received).
 */
static 
void *recv_datastruct(int device_id, roff_datatype_t devdecldata_medaddr, size_t devdecldata_size,
                      int map_type)
{
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	void            *storage;

	/* Empty struct: nothing to allocate or receive */
	if (devdecldata_size == 0)
		return NULL;

	storage = roff_alloctab_add(table, devdecldata_medaddr, devdecldata_size, map_type);

	/* Fill the new space with the bytes sent by the primary node */
	_Comm_Recv_bytes(storage, devdecldata_size, PRIMARY_NODE, device_id); 

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: OFFLOAD cmd: alloced devdata/decldata "
	                "(i: %d, u: %p, size: %d)\n", 
	                getpid(), device_id, devdecldata_medaddr, 
	                storage, devdecldata_size));

	return storage;
}


/* Receives the remaining offload arguments from the primary node and 
 * launches the kernel on a non-CPU device through DEVICE_OFFLOAD().
 *
 * dev:             the device that will execute the kernel; dev->idx is used
 *                  as the tag of every message received here.
 * offload_args:    already received values, in this order: #teams, #threads, 
 *                  thread limit, teamdims, thrdims.
 * kernel_function: the kernel function (passed on to DEVICE_OFFLOAD).
 * devdata:         the devdata struct (may be NULL).
 * decldata:        the decldata struct (may be NULL).
 *
 * Messages received, in order: the array with the number of arguments
 * (#decl, #fip, #mapped), the length of the kernel filename prefix, the
 * prefix itself and, only if there is at least one argument, the 
 * argument array.
 *
 * The prefix length comes from another process and is used as the element
 * count of a receive into a fixed-size stack buffer, so it is validated 
 * first; an out-of-range value is treated as fatal and the process exits.
 */
static 
void generic_offload(ort_device_t *dev, roff_datatype_t *offload_args, ttkfunc_t kernel_function, 
                     void *devdata, void *decldata)
{
	int    kernel_filename_prefix_len, num_args[3], args_size,
	       num_teams = offload_args[0], 
	       num_threads = offload_args[1], 
	       thread_limit = offload_args[2];
	u_long teamdims = offload_args[3], 
	       thrdims = offload_args[4];
	char   kernel_filename_prefix[256];
	void  *no_args[] = { NULL };   /* what the kernel gets when it has no arguments */
	void **args = no_args;
		
	/* (1) Receive remaining offload arguments */
	/* (1.1) Array holding the number of args (#decl, #fip, #mapped) */
	Comm_Recv(comminfo, PRIMARY_NODE, COMM_INT, num_args, NUM_ARGS_SIZE, dev->idx, NULL);

	/* (1.2) Length of kernel filename prefix */
	_Comm_Recv_1int(&kernel_filename_prefix_len, PRIMARY_NODE, dev->idx);

	/* Refuse lengths that do not fit in the buffer; receiving that many
	 * characters would write past the end of kernel_filename_prefix.
	 */
	if (kernel_filename_prefix_len < 0 
	    || (size_t) kernel_filename_prefix_len > sizeof(kernel_filename_prefix))
	{
		fprintf(stderr, "[remote worker %d] OFFLOAD cmd: invalid kernel filename prefix "
		                "length %d (max: %d); exiting...\n", 
		                (int) getpid(), kernel_filename_prefix_len, 
		                (int) sizeof(kernel_filename_prefix));
		exit(EXIT_FAILURE);
	}

	/* (1.3) The kernel filename prefix */
	Comm_Recv(comminfo, PRIMARY_NODE, COMM_CHAR, kernel_filename_prefix, kernel_filename_prefix_len, dev->idx, NULL);

	/* Make sure the prefix is a terminated string no matter what was sent.
	 * A properly terminated prefix is left intact: the terminator is placed
	 * right after the received characters or, if they fill the whole
	 * buffer, on the last one (which must already be the terminator).
	 */
	if ((size_t) kernel_filename_prefix_len < sizeof(kernel_filename_prefix))
		kernel_filename_prefix[kernel_filename_prefix_len] = '\0';
	else
		kernel_filename_prefix[sizeof(kernel_filename_prefix) - 1] = '\0';

	args_size = (num_args[ARGS_NUMDECL] + num_args[ARGS_NUMFIP] + num_args[ARGS_NUMMAPPED]) * sizeof(void*);

	/* (1.4) The arguments, if any; otherwise the one-element NULL array is used */
	if (args_size > 0)
	{
		args = smalloc(args_size);
		_Comm_Recv_bytes(args, args_size, PRIMARY_NODE, dev->idx);
	}

	/* (2) Offload */
	DEVICE_OFFLOAD(dev, kernel_function, devdata, decldata, kernel_filename_prefix, num_teams, 
	               num_threads, thread_limit, teamdims, thrdims, num_args, args); 

	/* (3) Release the argument array, if one was allocated */
	if (args != no_args)
		free(args);
}


/* OFFLOAD command handler (requested by the primary node).
 * device_id: index of the device that executes the kernel; also used as
 *            the tag of all messages exchanged here.
 * data:      command arguments; data[1] is the kernel id, data[2]/data[3]
 *            the mediary address/size of devdata and data[4]/data[5] the 
 *            mediary address/size of decldata.
 * On a CPU device the kernel function is called directly with devdata.
 * On any other device decldata and the offload arguments are received 
 * first and the kernel is launched through generic_offload(). Afterwards
 * the received structs are removed from the allocation table and 
 * OFFLOAD_OK is sent to the primary node.
 */
static void offload(int device_id, roff_datatype_t *data)
{
	int              kernel_id = data[1];
	roff_datatype_t  devdata_medaddr = data[2];
	size_t           devdata_size = data[3];
	roff_datatype_t  decldata_medaddr = data[4];
	size_t           decldata_size = data[5];
	roff_datatype_t  offload_args[CMD_MAX_NUM_ARGS];
	int              exec_result = OFFLOAD_OK;   /* all OK */
	ort_device_t    *dev = __DEVICE(device_id);
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	void            *devdata, *decldata = NULL;
	ttkfunc_t        kernel_function;

	/* (1) Look the kernel function up by its id */
	kernel_function = ort_kernfunc_getptr(kernel_id);

	/* (2) Get devdata from the primary node (NULL if its size is 0) */
	devdata = recv_datastruct(device_id, devdata_medaddr, devdata_size, MAPPED_DEVDATA);

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: OFFLOAD cmd: Start kernel "
	                "(id: %d, devdata: %p, devdata_size: %d, kernel_func: %p, is_cpudev: %d)\n",
	                getpid(), device_id, kernel_id, devdata, devdata_size, *kernel_function, dev->is_cpudev));

	/* The kernel runs with this device as the current one */
	current_execdev = dev;

	/* (3) Execute the kernel */
	if (!dev->is_cpudev)
	{
		/* Get decldata from the primary node (NULL if its size is 0) */
		decldata = recv_datastruct(device_id, decldata_medaddr, decldata_size, MAPPED_DECLDATA);

		/* #teams, #threads, thread limit, teamdims, thrdims */
		Comm_Recv(comminfo, PRIMARY_NODE, COMM_UNSIGNED_LONG_LONG, offload_args, CMD_MAX_NUM_ARGS, 
		          device_id, NULL);

		/* Whatever arguments remain are received by generic_offload() */
		generic_offload(dev, offload_args, kernel_function, devdata, decldata);
	}
	else
		(*kernel_function)(devdata);   /* CPU device: plain function call */

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: OFFLOAD cmd: Finish kernel (id: %d)\n",
	                getpid(), device_id, kernel_id));

	/* (4) Drop the structs that were received for this kernel */
	if (devdata != NULL)
		roff_alloctab_remove(table, devdata_medaddr, MAPPED_DEVDATA);
	if (decldata != NULL)
		roff_alloctab_remove(table, decldata_medaddr, MAPPED_DECLDATA);

	/* (5) Let the primary node know that the kernel has finished */
	_Comm_Send_1int(&exec_result, PRIMARY_NODE, device_id);
}


/* IMED2UMED_ADDR command handler (requested by the primary node).
 * device_id: index of the device; also used as the tag of the reply.
 * data:      command arguments; data[1] is the imed address.
 * The address stored in the allocation table for the imed address is 
 * converted with DEVICE_IMED2UMED_ADDR() and the result is sent to the 
 * primary node.
 */
static void imed2umed_addr(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  imedaddr = data[1];
	roff_datatype_t  umedaddr;
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);
	void            *medaddr;

	/* Table lookup first, conversion second */
	medaddr = roff_alloctab_get(table, imedaddr);
	umedaddr = (roff_datatype_t) DEVICE_IMED2UMED_ADDR(dev, medaddr);

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: IMED2UMED_ADDR (medaddr: %p, imedaddr: %p) -> umedaddr: %p\n",
	                getpid(), device_id, medaddr, imedaddr, umedaddr));

	/* Reply to the primary node */
	_Comm_Send_1ull(&umedaddr, PRIMARY_NODE, device_id);
}


/* UMED2IMED_ADDR command handler (requested by the primary node).
 * device_id: index of the device; also used as the tag of the reply.
 * data:      command arguments; data[1] is the umed address.
 * The umed address is converted with DEVICE_UMED2IMED_ADDR(), the result
 * is looked up with roff_alloctab_get_item() and whatever that returns is 
 * sent to the primary node.
 */
static void umed2imed_addr(int device_id, roff_datatype_t *data)
{
	roff_datatype_t  umedaddr = data[1];
	roff_datatype_t  imedaddr;
	roff_alloctab_t *table = __ALLOCTABLE(device_id);
	ort_device_t    *dev = __DEVICE(device_id);
	void            *medaddr = DEVICE_UMED2IMED_ADDR(dev, (void*) umedaddr);

	/* Conversion is done; now the table lookup */
	imedaddr = roff_alloctab_get_item(table, medaddr);

	DBGPRN((stderr, "[remote worker %d] >>> Device %d: UMED2IMED_ADDR (medaddr: %p, umedaddr: %p) -> imedaddr: %p\n",
	                getpid(), device_id, medaddr, umedaddr, imedaddr));

	/* Reply to the primary node */
	_Comm_Send_1ull(&imedaddr, PRIMARY_NODE, device_id);
}

/* new-hostpart-func.sh:workerfunc */

#endif /* OMPI_REMOTE_OFFLOADING */
