/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* remote/workers.c
 * 
 * All worker (remote; with rank != 0) MPI processes execute
 * these functions. Their purpose is to respond to the
 * primary node's requests. Note that the primary and the worker nodes do
 * **NOT** run in a shared memory environment. Communication is
 * achieved with MPI.
 */

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_ROFF_WORKERS

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <unistd.h>
#include "ort_prive.h"
#include "rt_common.h"
#include "assorted.h"
#include <pthread.h>
#include <assert.h>


#ifdef OMPI_REMOTE_OFFLOADING
#include "remote/workers.h"
#include "remote/workercmds.h"
#include "remote/memory.h"
#include "remote/node_manager.h"
#include "remote/roff_prive.h"
#include "roff_config.h"

/* Worker command dispatch helpers.
 * `message` is a command_t passed by value. args[0] contains the command ID,
 * which indexes the worker_commands[] table; the selected handler is called
 * with the message tag and the complete argument array.
 * The macro parameter is parenthesized everywhere so that any expression
 * yielding a command_t (e.g. `*cmdptr`) expands correctly.
 */
#define _workercmd_exec(message) \
	(worker_commands[(message).args[0]])((message).tag, (message).args)
/* Non-zero iff the command ID lies in [CMD(initialize), CMD(last)) */
#define _workercmd_is_valid(message) \
	( ((message).args[0] >= CMD(initialize)) && ((message).args[0] < CMD(last)) )

/* Lock taken around ort_init_device() in _worker_init_local_devices() */
static volatile ee_lock_t roff_lock;
/* Used by workers to store their local devinfos; allocated, sent to the
 * primary node and released again inside _worker_init_local_devices().
 */
static void **local_devinfos; /* Used by workers to store their local devinfos */

/* Forward declaration; one of two definitions further below is compiled,
 * depending on ROFF_WORKER_MULTITHREADING. Returns true only when a command
 * was actually received into command->args.
 */
static bool recv_command(command_t *command);

/* Device consulted by roff_omp_get_device_num() (its `idx` field is read).
 * It is thread-local when TLS_KEYWORD is defined and a plain global otherwise.
 * It starts as NULL and is not assigned anywhere in this file.
 */
#ifdef TLS_KEYWORD
	__thread ort_device_t *current_execdev = NULL;
#else
	ort_device_t *current_execdev = NULL;
#endif

#ifdef ROFF_WORKER_MULTITHREADING
	/* Bookkeeping for the threads spawned by create_worker_threads() */
	typedef struct 
	{
		pthread_t threads[256];  /* handles filled in by pthread_create() */
		int nthreads_used;       /* how many entries of threads[] are in use */
		int thrid;               /* not referenced in this file */
	} worker_threads_t;

	/* Per-thread descriptor; each thread keeps its own heap copy as
	 * thread-specific data under `key` (see _worker_thread_exec()).
	 */
	typedef struct threadinfo_
	{
		int thread_id;        /* index of the thread, 0 .. nthreads_used-1 */
		int first_device_id;  /* first device served (inclusive) */
		int last_device_id;   /* end of the served range (exclusive, as tested in recv_command()) */
	} threadinfo_t;

	/* Worker threads */
	worker_threads_t worker_threads;

	/* All three are initialized and destroyed in create_worker_threads() */
	pthread_key_t     key;      /* thread-specific threadinfo_t; destructor is _free_thread_info() */
	pthread_mutex_t   mutex;    /* never locked in this file */
	pthread_barrier_t barrier;  /* threads meet here before entering the receive loop */

	/* Overwritten with the actual thread count in create_worker_threads() */
	static int num_worker_threads = 1;
#endif

#ifdef ROFF_MULTIPLE_WORKERS
	/* Device chunk of this worker process, filled in by
	 * roff_man_get_device_chunk() in roff_worker_get_modules() and used as
	 * the half-open range [first_device_id, last_device_id) in
	 * _worker_init_local_devices().
	 */
	static int first_device_id;
	static int last_device_id;
	/* Set from Comm_Get_node_id() in roff_worker_init(); used as the index of
	 * this node's entry in roff_config.nodes[].
	 */
	static int parent_worker_node_id;
	/* Either ROFF_NUM_NODE_WORKER_PROCS or, when that is -1, the node's
	 * total_num_devices (presumably the number of worker processes sharing
	 * the node -- TODO confirm).
	 */
	static int num_siblings;
#endif

/* Made weak so that if found at load time they get overridden.
 * We need (weak) definitions to cover the case where the module is only
 * loaded for queries e.g. by ompicc --devinfo; in such cases ORT is not
 * present and those symbols would otherwise be undefined.
 *
 * Each pointer is redirected to the corresponding roff_* implementation of
 * this file by the matching override_*() function below.
 */
#pragma weak actual__dev_med2dev_addr
char *(*actual__dev_med2dev_addr)(void *, unsigned long);  /* set by override__dev_med2dev_addr() */
#pragma weak actual_omp_is_initial_device
int (*actual_omp_is_initial_device)(void);                 /* set by override_omp_is_initial_device() */
#pragma weak actual_omp_get_device_num
int (*actual_omp_get_device_num)(void);                    /* set by override_omp_get_device_num() */


/* Worker-side replacement for _dev_med2dev_addr.
 * The address that is passed in is returned as-is (only viewed as char *);
 * `size` is accepted for signature compatibility but not used.
 * A debug line with the process id and the address is emitted through DBGPRN.
 */
char *roff__dev_med2dev_addr(void *medaddr, unsigned long size)
{
	char *devaddr = medaddr;  /* no translation takes place */

	DBGPRN((stderr, "[remote worker %d] _dev_med2dev_addr %p\n", 
	             getpid(), medaddr));

	return devaddr;
}
/* Point the actual__dev_med2dev_addr hook at the worker-side
 * implementation, roff__dev_med2dev_addr().
 */
void override__dev_med2dev_addr(void)
{
	extern char *(*actual__dev_med2dev_addr)(void *, unsigned long);

	actual__dev_med2dev_addr = &roff__dev_med2dev_addr;
}


/* Worker-side replacement for omp_get_device_num().
 * The result is the index of the currently executing device
 * (current_execdev->idx) shifted by num_primary_devs - 1.
 * current_execdev must already be set; this is enforced with assert().
 */
int roff_omp_get_device_num(void)
{
	assert(current_execdev != NULL);

	return (num_primary_devs - 1) + current_execdev->idx;
}
/* Point the actual_omp_get_device_num hook at the worker-side
 * implementation, roff_omp_get_device_num().
 */
void override_omp_get_device_num(void)
{
	extern int (*actual_omp_get_device_num)(void);

	actual_omp_get_device_num = &roff_omp_get_device_num;
}


/* Worker-side replacement for omp_is_initial_device().
 * Always reports 0 (false).
 */
int roff_omp_is_initial_device(void)
{
	const int on_initial_device = 0;

	return on_initial_device;
}
/* Point the actual_omp_is_initial_device hook at the worker-side
 * implementation, roff_omp_is_initial_device().
 */
void override_omp_is_initial_device(void)
{
	extern int (*actual_omp_is_initial_device)(void);

	actual_omp_is_initial_device = &roff_omp_is_initial_device;
}


#ifdef ROFF_MULTIPLE_WORKERS

/* Initialize all local devices and construct an array containing the
 * devinfos of the devices in this process' chunk
 * [first_device_id, last_device_id); the array is then sent to the primary.
 *
 * Steps:
 *  1. set up roff_lock and allocate one slot per device of the chunk;
 *  2. receive the number of primary-node devices into num_primary_devs;
 *  3. initialize (under roff_lock) every non-remote device that is not yet
 *     initialized, recording the devinfo of those that belong to the chunk;
 *  4. send the slot array, as raw bytes, to the primary node and release it.
 *
 * Every slot is pre-set to NULL, so that slots which step 3 does not fill
 * (devices skipped because their module is remote) carry a defined value
 * instead of uninitialized heap contents.
 */
static void _worker_init_local_devices(void)
{
	int did, i = 0;
	int ndevinfos = last_device_id - first_device_id;

	ee_init_lock((ee_lock_t *) &roff_lock, ORT_LOCK_NORMAL);

	local_devinfos = (void **) smalloc(ndevinfos * sizeof(void *));

	/* Never transmit indeterminate bytes for slots left unfilled below */
	for (did = 0; did < ndevinfos; did++)
		local_devinfos[did] = NULL;

	/* Learn how many devices the primary node has */
	_Comm_Recv_1int(&num_primary_devs, PRIMARY_NODE, COMM_ANY_TAG);
	
	/* Initialize all non-remote devices (excluding hostdev which is already initialized) */
	for (did = 0; did < ort->num_devices; did++)
	{
		if (ort->ort_devices[did].module->remote)
			continue;

		DBGPRN((stderr, "[remote worker %s] >>> Initializing device %d of module \"%s\"\n",
		                Comm_Get_info(comminfo), did, ort->ort_devices[did].module->name));
						
		ee_set_lock((ee_lock_t *) &roff_lock);
		if (!ort->ort_devices[did].initialized)
			ort_init_device(did);
		ee_unset_lock((ee_lock_t *) &roff_lock);
		
		/* Only devices of my own chunk are reported.
		 * NOTE(review): 0x999 looks like a non-NULL placeholder for devices
		 * that have no device_info -- confirm against the receiving side.
		 */
		if (did >= first_device_id && did < last_device_id)
			local_devinfos[i++] = (ort->ort_devices[did].device_info == NULL) ? 
				(void *) 0x999 : ort->ort_devices[did].device_info;
	}

	/* Inform the primary node (waiting @ _primary_receive_all_devinfos) */
	Comm_Send(comminfo, PRIMARY_NODE, COMM_BYTE, local_devinfos, 
			 ndevinfos * sizeof(void*), 0, NULL);
	

	free(local_devinfos); /* not needed anymore */
	local_devinfos = NULL;
}

#else

/* Initialize all local devices and construct an array containing all
 * devinfos, which is then sent to the primary node.
 *
 * Steps:
 *  1. set up roff_lock and allocate ort->num_local_devices slots;
 *  2. receive the number of primary-node devices into num_primary_devs;
 *  3. initialize (under roff_lock) every non-remote local device that is
 *     not yet initialized and record its device_info in the next slot;
 *  4. send the slot array, as raw bytes, to the primary node and release it.
 *
 * Every slot is pre-set to NULL, so that slots which step 3 does not fill
 * (devices skipped because their module is remote) carry a defined value
 * instead of uninitialized heap contents.
 */
static void _worker_init_local_devices(void)
{
	int did, i = 0;
	ee_init_lock((ee_lock_t *) &roff_lock, ORT_LOCK_NORMAL);

	local_devinfos = (void**) smalloc(ort->num_local_devices * sizeof(void *));

	/* Never transmit indeterminate bytes for slots left unfilled below */
	for (did = 0; did < ort->num_local_devices; did++)
		local_devinfos[did] = NULL;

	/* Learn how many devices the primary node has */
	_Comm_Recv_1int(&num_primary_devs, PRIMARY_NODE, COMM_ANY_TAG);

	/* Initialize all non-remote devices (excluding hostdev which is already initialized) */
	for (did = 0; did < ort->num_local_devices; did++)
	{
		if (ort->ort_devices[did].module->remote)
			continue;

		DBGPRN((stderr, "[remote worker %s] >>> Initializing device %d of module \"%s\"\n",
		                Comm_Get_info(comminfo), did, ort->ort_devices[did].module->name));
						
		ee_set_lock((ee_lock_t *) &roff_lock);
		if (!ort->ort_devices[did].initialized)
			ort_init_device(did);
		ee_unset_lock((ee_lock_t *) &roff_lock);
		
		/* Slots are filled densely, in device order */
		local_devinfos[i++] = ort->ort_devices[did].device_info;
	}

	/* Inform the primary node (waiting @ _primary_receive_all_devinfos) */
	Comm_Send(comminfo, PRIMARY_NODE, COMM_BYTE, local_devinfos, 
	          ort->num_local_devices * sizeof(void*), 0, NULL);
	
	free(local_devinfos); /* not needed anymore */
	local_devinfos = NULL;
}

#endif


/* Loads the remote offloading configuration, locates the entry of this
 * node and hands back the names of its modules.
 *
 *   nModules  (out) the node's num_modules, minus one if the node entry
 *             has its has_cpu_module flag set
 *   returns   the node entry's module_names array (owned by roff_config)
 *
 * With ROFF_MULTIPLE_WORKERS the node index is parent_worker_node_id and, as
 * a side effect, num_siblings and the [first_device_id, last_device_id)
 * chunk are computed; otherwise the index is roff_man_get_my_id() - 1.
 */
char **roff_worker_get_modules(int *nModules)
{
	int index, count;

#ifdef ROFF_IGNORE_CONFIGURATION_SNAPSHOT
	/* Build the configuration without using the embedded snapshot */
	roff_config_initialize(DONT_IGNORE_DISABLED_MODULES, is_userprog_portable);
#else
	/* Initialize remote configuration from the embedded hex */
	roff_config_initialize_from_hex(ompi_remote_devices, ompi_remote_devices_size,
	                                DONT_IGNORE_DISABLED_MODULES, is_userprog_portable);
#endif /* ROFF_IGNORE_CONFIGURATION_SNAPSHOT */

#ifdef ROFF_MULTIPLE_WORKERS
	index = parent_worker_node_id;

	/* -1 selects the node's total_num_devices as the sibling count */
	if (ROFF_NUM_NODE_WORKER_PROCS == -1)
		num_siblings = roff_config.nodes[index].total_num_devices;
	else
		num_siblings = ROFF_NUM_NODE_WORKER_PROCS;

	roff_man_get_device_chunk(roff_config.nodes[index].total_num_devices,
	                          num_siblings, index, &first_device_id, &last_device_id);
#else
	/* Learn my ID; IDs 1, 2, 3 map to entries 0, 1, 2 */
	index = roff_man_get_my_id() - 1;
#endif /* ROFF_MULTIPLE_WORKERS */

	/* TODO check if this is correct */
	count = roff_config.nodes[index].num_modules;
	if (roff_config.nodes[index].has_cpu_module)
		count--;
	*nModules = count;

	return roff_config.nodes[index].module_names;
}


/* First-time initialization of a worker process.
 * (Original note: called by the host and the other nodes upon first
 * initialization; no caller is visible in this file.)
 *
 * Brings up the communication layer through Comm_Init_worker(), redirects
 * the three actual_* hooks to their roff_* implementations and, with
 * ROFF_MULTIPLE_WORKERS, records the node id reported by Comm_Get_node_id().
 */
void roff_worker_init(void)
{
	Comm_Init_worker(comminfo);
	DBGPRN((stderr, "[remote worker %s] >>> roff_worker_init \n", Comm_Get_info(comminfo)));

	/* Three independent pointer redirections */
	override_omp_get_device_num();
	override_omp_is_initial_device();
	override__dev_med2dev_addr();

	DBGPRN((stderr, "[remote worker %s] >>> Blocking...\n", 
	                Comm_Get_info(comminfo)));

#ifdef ROFF_MULTIPLE_WORKERS
	parent_worker_node_id = Comm_Get_node_id(comminfo);
#endif
}


/* The receive-loop of a worker (process or thread).
 *
 * Repeatedly asks recv_command() for the next command and dispatches it
 * through the worker_commands[] table. The loop is left only when a
 * *received* command carries an ID outside [CMD(initialize), CMD(last));
 * the process is then terminated with exit status 2.
 *
 * A command is validated only after recv_command() reported that it really
 * stored one. When recv_command() returns false (e.g. the pending message is
 * meant for another thread) the argument buffer holds stale or, the first
 * time around, uninitialized contents which must not be inspected.
 */
static void _worker_loop(void)
{
	command_t cmd;

	cmd.args = smalloc(CMD_MAX_NUM_ARGS * sizeof(roff_datatype_t)); /* recycled */
	if (cmd.args == NULL)
	{
		/* Without a buffer recv_command() can never succeed; do not spin */
		fprintf(stderr, "[remote worker %d]: cannot allocate the command buffer\n",
		                (int) getpid());
		exit(EXIT_FAILURE);
	}

	/* Loop waiting for commands */
	for (;;)
	{
		if (!recv_command(&cmd))
			continue;             /* nothing was received; try again */
		if (!_workercmd_is_valid(cmd))
			break;                /* illegal command ID */
		_workercmd_exec(cmd);
	}
		
	DBGPRN((stderr, "[remote worker %d]: >>> Illegal command (tag=%d, cmd=%llu) received\n", 
	                getpid(), cmd.tag, cmd.args[0]));

	free(cmd.args);
	exit(2);
}


#ifdef ROFF_WORKER_MULTITHREADING

/* Multithreaded variant: receive the next command only if it is meant for
 * one of the calling thread's devices.
 *
 * The pending message from the primary node is probed first; the value left
 * in st.Int (used as the message tag below) is compared against the thread's
 * [first_device_id, last_device_id) range taken from its thread-specific
 * data. Only on a match is the message actually received into command->args
 * and its tag stored in command->tag.
 *
 * Returns true if a command was received, false if command->args is NULL or
 * the probed message lies outside the thread's range (the message is then
 * left pending).
 */
static bool recv_command(command_t *command)
{
	Comm_Status st;
	threadinfo_t *me = (threadinfo_t *) pthread_getspecific(key);

	if (command->args == NULL)
		return false;

	Comm_Probe(comminfo, PRIMARY_NODE, COMM_ANY_TAG, &st);

	/* Not one of my devices: leave the message alone */
	if ((st.Int < me->first_device_id) || (st.Int >= me->last_device_id))
		return false;

	/* Type of command to execute + args */
	Comm_Recv(comminfo, PRIMARY_NODE, COMM_UNSIGNED_LONG_LONG,
	          command->args, CMD_MAX_NUM_ARGS, st.Int, &st);
	command->tag = st.Int;

	return true;
}

/* Destructor registered with pthread_key_create() for `key`:
 * releases the heap block that holds a thread's threadinfo_t.
 */
static void _free_thread_info(void *thrinfo)
{
	void *block = thrinfo;

	free(block);
}


static void *_worker_thread_exec(void *arg)
{
	threadinfo_t *my_thread_info = (threadinfo_t *) pthread_getspecific(key);

	if (my_thread_info == NULL)
	{
		my_thread_info = smalloc(sizeof(threadinfo_t));
		my_thread_info->thread_id = ((threadinfo_t *) arg)->thread_id;
		my_thread_info->first_device_id = ((threadinfo_t *) arg)->first_device_id;
		my_thread_info->last_device_id = ((threadinfo_t *) arg)->last_device_id;
		pthread_setspecific(key, my_thread_info);
	}

	/* Wait until all workers reach this point */
	pthread_barrier_wait(&barrier);

	/* Enter the receive-loop */
	_worker_loop();

	return NULL;
}


/* Spawn the worker threads, wait for them to finish and tear down.
 *
 * The thread count is ROFF_NUM_WORKER_THREADS or, when that is -1, the
 * number of local devices. Each thread is handed its id and the device
 * chunk computed by roff_man_get_device_chunk().
 *
 * Safety measures:
 *  - the thread count is capped at the capacity of the fixed-size arrays
 *    (thrargs[] and worker_threads.threads[]) to rule out overflowing them;
 *  - the barrier is only initialized and destroyed when there is at least
 *    one participant (pthread_barrier_init() rejects a zero count);
 *  - a failing pthread_barrier_init() or pthread_create() ends the process,
 *    since the already running threads would otherwise wait on the barrier
 *    forever.
 *
 * thrargs[] lives on this function's stack. That is fine because every
 * thread copies its entry before the barrier and this function does not
 * return before all threads have been joined.
 */
static void create_worker_threads(void)
{
	int i, rc, nthr;
	threadinfo_t thrargs[256];
	const int max_args = (int) (sizeof(thrargs) / sizeof(thrargs[0]));
	const int max_handles = (int) (sizeof(worker_threads.threads) /
	                               sizeof(worker_threads.threads[0]));
	const int max_threads = (max_args < max_handles) ? max_args : max_handles;
	void (*initialize)() = worker_commands[CMD(initialize)];

	/* (1) Set # of threads = # of worker devices (unless configured otherwise) */
	nthr = (ROFF_NUM_WORKER_THREADS == -1) ? ort->num_local_devices 
	                                       : ROFF_NUM_WORKER_THREADS;
	if (nthr > max_threads)
	{
		fprintf(stderr, "[remote worker %d] %d worker threads requested; "
		                "limiting to %d\n", (int) getpid(), nthr, max_threads);
		nthr = max_threads;
	}
	num_worker_threads = worker_threads.nthreads_used = nthr;

	/* (2) Initialize the barrier, mutex and key */
	if (nthr > 0)
	{
		rc = pthread_barrier_init(&barrier, NULL, nthr);
		if (rc != 0)
		{
			fprintf(stderr, "[remote worker %d] pthread_barrier_init failed (%d)\n",
			                (int) getpid(), rc);
			exit(EXIT_FAILURE);
		}
	}
	pthread_mutex_init(&mutex, NULL);
	pthread_key_create(&key, _free_thread_info);

	/* (3) Call the worker initialization function (tag 0, no arguments) */
	initialize(0, NULL);
	
	/* (4) Create the threads */
	for (i = 0; i < nthr; i++)
	{
		thrargs[i].thread_id = i;
		roff_man_get_device_chunk(ort->num_local_devices, nthr, i, 
		                          &(thrargs[i].first_device_id), &(thrargs[i].last_device_id));
		rc = pthread_create(&(worker_threads.threads[i]), NULL, _worker_thread_exec,
		                    (void*) &thrargs[i]);
		if (rc != 0)
		{
			/* The barrier could never be completed; give up */
			fprintf(stderr, "[remote worker %d] cannot create worker thread %d (%d)\n",
			                (int) getpid(), i, rc);
			exit(EXIT_FAILURE);
		}
	}

	/* (5) Join threads (original note: threads exit when a SHUTDOWN command
	 * is sent; the corresponding handler is not part of this file)
	 */
	for (i = 0; i < nthr; i++)
		pthread_join(worker_threads.threads[i], NULL);

	/* (6) Destroy everything */	
	if (nthr > 0)
		pthread_barrier_destroy(&barrier);
	pthread_mutex_destroy(&mutex);
	pthread_key_delete(key);

	Comm_Finalize(comminfo);
}

#else

/* Single-threaded variant: receive the next command sent by the primary
 * node, whatever its tag.
 *
 * Up to CMD_MAX_NUM_ARGS values are received into command->args and the
 * value reported in st.Int is stored as command->tag.
 *
 * Returns false, without receiving anything, if command->args is NULL;
 * true otherwise.
 */
static bool recv_command(command_t *command)
{
	Comm_Status st;

	if (command->args != NULL)
	{
		/* Type of command to execute + args */
		Comm_Recv(comminfo, PRIMARY_NODE, COMM_UNSIGNED_LONG_LONG,
		          command->args, CMD_MAX_NUM_ARGS, COMM_ANY_TAG, &st);
		command->tag = st.Int;
		return true;
	}

	return false;
}

#endif /* ROFF_WORKER_MULTITHREADING */

/* Main entry of a worker after initialization.
 * The local devices are initialized first and their devinfos sent to the
 * primary node; then commands are served, either by the worker threads
 * (ROFF_WORKER_MULTITHREADING) or directly by the calling thread.
 */
void roff_worker_loop(void)
{
	_worker_init_local_devices();

#ifndef ROFF_WORKER_MULTITHREADING
	/* Enter the receive-loop directly */
	_worker_loop();
#else
	/* Create all worker threads; they run the receive-loop */
	create_worker_threads();
#endif
}

#endif /* OMPI_REMOTE_OFFLOADING */
