/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* This is the MPI comm library for remote offloading
 */

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <mpi.h>
#include "comm.h"
#include "../sysdeps.h"
#include "rt_common.h"
#include "assorted.h"

#ifdef OMPI_REMOTE_OFFLOADING

#include "roff_prive.h"
#include "roff.h"
#include "roff_config.h"

/* 
 * MPI communication layer 
 */

#ifdef ROFF_USE_STATIC_MPI_PROCS

void *Comm_Init(int *argc, char ***argv)
{
	MPI_info_t *mpi_info = smalloc(sizeof(MPI_info_t));
	int rank;

	MPI_Initialized(&mpi_info->initialized);
	if (mpi_info->initialized)
	{
		DBGPRN((stderr, "[comm_mpi] MPI was already initialized at process %s\n", 
		                 Comm_Get_info(mpi_info)));
		return NULL;
	}

	mpi_info->argc = argc;
	mpi_info->argv = argv;

	MPI_Init(argc, argv);
	MPI_Initialized(&(mpi_info->initialized));

	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	mpi_info->role = node_role = (rank != 0); /* 0 is primary, every other process is a worker */

	DBGPRN((stderr, "[comm_mpi] MPI initialized at process %s\n", Comm_Get_info(mpi_info)));

	return (void*) mpi_info;
}


/* Shuts down MPI (static-process model).
 * The communicator in use here is MPI_COMM_WORLD, a predefined communicator
 * that must not be freed, so there is nothing to release before
 * MPI_Finalize(). The info descriptor itself is not freed.
 */
void Comm_Finalize(void *info)
{
	(void) info;   /* unused: MPI_COMM_WORLD needs no explicit release */
	MPI_Finalize();
}


/* Static-process "spawn".
 * The workers already exist as the non-zero ranks of MPI_COMM_WORLD and sit
 * in the broadcast of Comm_Init_worker(); waking them only takes a broadcast
 * of WAKE_UP_MESSAGE from the primary. Both arguments are unused and exist
 * only to match the dynamic-process variant's signature.
 * Always returns MPI_COMM_WORLD.
 */
static MPI_Comm spawn_1proc_per_node(MPI_info_t *mpi_info, int num_nodes)
{
	int wakeup_msg = WAKE_UP_MESSAGE;

	(void) mpi_info;
	(void) num_nodes;

	MPI_Bcast(&wakeup_msg, 1, MPI_INT, PRIMARY_NODE, MPI_COMM_WORLD);
	return MPI_COMM_WORLD;
}


/* Worker-side initialization for the static-process model.
 * Blocks in the primary's broadcast (see spawn_1proc_per_node) and, if the
 * value received is WAKE_UP_MESSAGE, adopts MPI_COMM_WORLD as the
 * communicator. Any other value is a protocol error and terminates the
 * process with EXIT_FAILURE.
 */
void Comm_Init_worker(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int buf;

	/* All other nodes (workers) come here; they will have to block waiting
	 * for commands from the primary node instead of executing user's code.
	 * Processes are started statically, so there is no parent communicator
	 * to query or merge with.
	 */
	MPI_Bcast(&buf, 1, MPI_INT, PRIMARY_NODE, MPI_COMM_WORLD);
	if (buf != WAKE_UP_MESSAGE)
	{
		fprintf(stderr, "[remote worker (%d)]: illegal broadcasted message.\n",
		                getpid());
		exit(EXIT_FAILURE);
	}
	mpi_info->communicator = MPI_COMM_WORLD;
}


#else /* Dynamic processes */


void *Comm_Init(int *argc, char ***argv)
{
	MPI_info_t *mpi_info = smalloc(sizeof(MPI_info_t));
	MPI_Comm parent_comm;

	MPI_Initialized(&(mpi_info->initialized));
	if (mpi_info->initialized)
	{
		DBGPRN((stderr, "[comm_mpi] MPI was already initialized at process %s\n", 
		                 Comm_Get_info(mpi_info)));
		return NULL;
	}

	mpi_info->argc = argc;
	mpi_info->argv = argv;

	MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &mpi_info->provided);
	MPI_Initialized(&(mpi_info->initialized));

	if (mpi_info->provided != MPI_THREAD_MULTIPLE)
	{
		fprintf(stderr, "Your MPI library was not configured with "
						"MPI_THREAD_MULTIPLE support. Aborting...\n");
		MPI_Abort(MPI_COMM_WORLD, 1);
	}

	MPI_Comm_get_parent(&parent_comm);
	mpi_info->role = node_role = (parent_comm != MPI_COMM_NULL); /* 0 for primary, 1 for worker */

	DBGPRN((stderr, "[comm_mpi] MPI initialized at process %s\n", Comm_Get_info(mpi_info)));
	
	return (void*) mpi_info;
}


void Comm_Finalize(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;

	MPI_Comm_free(&mpi_info->communicator);
	MPI_Finalize();
}

#ifdef ROFF_MULTIPLE_WORKERS

/* Spawns the remote worker processes when ROFF_MULTIPLE_WORKERS is enabled.
 *
 * For each configured node i, num_procs processes are requested on host
 * roff_config.nodes[i].name; when num_procs == -1 the count is
 * roff_config.nodes[i].total_num_devices instead. Nodes that end up with no
 * processes are skipped.
 *
 * All workers are created by a single MPI_Comm_spawn_multiple() call so that
 * they share ONE intercommunicator. Spawning node by node would give every
 * node its own separate communicator, and the primary could keep only one.
 * MPI numbers the children in command order, so after merging (primary
 * first, high = 0) the workers of each spawned node occupy a contiguous rank
 * range directly after those of the previous node. Every worker is then sent
 * (tag 0) the index of its node, which Comm_Init_worker() receives.
 *
 * NOTE(review): assumes the primary is the only process in its
 * MPI_COMM_WORLD, so the first worker gets rank 1 — confirm.
 *
 * Returns the merged intracommunicator, or MPI_COMM_NULL if there is nothing
 * to spawn or an allocation / MPI call fails.
 */
static MPI_Comm spawn_nprocs_per_node(MPI_info_t *mpi_info, int num_nodes, int num_procs)
{
	MPI_Comm spawned_comm, merged_comm = MPI_COMM_NULL;
	char **commands = NULL, ***cmd_argvs = NULL;
	int *maxprocs = NULL, *nodetab = NULL;
	MPI_Info *infos = NULL;
	int i, j, rank, num_cmds = 0, num_infos = 0;

	if ((num_nodes <= 0) || (num_procs == 0))
		return MPI_COMM_NULL;

	/* Per spawn command: executable, its argv, #processes, node id, info */
	commands  = malloc(num_nodes * sizeof(*commands));
	cmd_argvs = malloc(num_nodes * sizeof(*cmd_argvs));
	maxprocs  = malloc(num_nodes * sizeof(*maxprocs));
	nodetab   = malloc(num_nodes * sizeof(*nodetab));
	infos     = malloc(num_nodes * sizeof(*infos));
	if (!commands || !cmd_argvs || !maxprocs || !nodetab || !infos)
	{
		fprintf(stderr, "[comm_mpi] spawn_nprocs_per_node: out of memory\n");
		goto cleanup;
	}

	/* One spawn command per node that gets at least one process */
	for (i = 0; i < num_nodes; i++)
	{
		int nprocs = (num_procs == -1) ? roff_config.nodes[i].total_num_devices : num_procs;

		if (nprocs <= 0)
			continue;

		if (MPI_Info_create(&infos[num_cmds]) != MPI_SUCCESS)
		{
			fprintf(stderr, "[comm_mpi] spawn_nprocs_per_node: MPI_Info_create failure\n");
			goto cleanup;
		}
		num_infos++;   /* infos[0..num_infos) must be freed at cleanup */

		if (MPI_Info_set(infos[num_cmds], "host", roff_config.nodes[i].name) != MPI_SUCCESS)
		{
			fprintf(stderr, "[comm_mpi] spawn_nprocs_per_node: MPI_Info_set failure\n");
			goto cleanup;
		}

		commands[num_cmds]  = **(mpi_info->argv);
		cmd_argvs[num_cmds] = *(mpi_info->argv) + 1;  /* our argv minus program name */
		maxprocs[num_cmds]  = nprocs;
		nodetab[num_cmds]   = i;
		num_cmds++;
	}

	if (num_cmds == 0)
		goto cleanup;   /* no node needs any process */

	if (MPI_Comm_spawn_multiple(num_cmds, commands, cmd_argvs, maxprocs, infos, 0,
	                            MPI_COMM_WORLD, &spawned_comm, MPI_ERRCODES_IGNORE) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_nprocs_per_node: MPI_Comm_spawn_multiple failure\n");
		goto cleanup;
	}

	/* Primary node sets high = 0 (second arg) so the process @ primary node gets 
	 * rank = 0 in the merged communicator. 
	 */
	if (MPI_Intercomm_merge(spawned_comm, 0, &merged_comm) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_nprocs_per_node: MPI_Intercomm_merge failure\n");
		merged_comm = MPI_COMM_NULL;
		goto cleanup;
	}

	/* Tell each worker its node id; command i's workers take the next
	 * maxprocs[i] ranks after the primary (rank 0). */
	for (rank = 1, i = 0; i < num_cmds; i++)
		for (j = 0; j < maxprocs[i]; j++, rank++)
			MPI_Send(&nodetab[i], 1, MPI_INT, rank, 0, merged_comm);

cleanup:
	for (i = 0; i < num_infos; i++)
		MPI_Info_free(&infos[i]);
	free(infos);
	free(nodetab);
	free(maxprocs);
	free(cmd_argvs);
	free(commands);

	return merged_comm;
}

#else

/* Spawns num_nodes worker processes (dynamic-process model, one process per
 * node). The "host" info key is set to the string returned by
 * roff_config_get_node_names(); the spawned processes run this program's
 * executable with our argv (minus argv[0]).
 * The resulting intercommunicator is merged so that this (primary) process
 * gets rank 0.
 * Returns the merged intracommunicator, or MPI_COMM_NULL if num_nodes <= 0 or
 * an MPI call fails. The MPI_Info object is released on every path.
 */
static MPI_Comm spawn_1proc_per_node(MPI_info_t *mpi_info, int num_nodes)
{
	MPI_Info info;
	MPI_Comm spawned_comm, merged_comm = MPI_COMM_NULL;
	char *all_nodes_string = roff_config_get_node_names();

	if (num_nodes <= 0)
		return MPI_COMM_NULL;

	if (MPI_Info_create(&info) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_1proc_per_node: MPI_Info_create failure\n");
		return MPI_COMM_NULL;
	}
	
	if (MPI_Info_set(info, "host", all_nodes_string) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_1proc_per_node: MPI_Info_set failure\n");
		MPI_Info_free(&info);
		return MPI_COMM_NULL;
	}

	if (MPI_Comm_spawn(**(mpi_info->argv), *(mpi_info->argv) + 1, num_nodes, info, 0,
	                  MPI_COMM_WORLD, &spawned_comm, MPI_ERRCODES_IGNORE) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_1proc_per_node: MPI_Comm_spawn failure\n");
		MPI_Info_free(&info);
		return MPI_COMM_NULL;
	}

	MPI_Info_free(&info);

	/* Primary node sets high = 0 (second arg) so the process @ primary node gets 
	 * rank = 0 in the merged communicator. 
	 */
	if (MPI_Intercomm_merge(spawned_comm, 0, &merged_comm) != MPI_SUCCESS
	    || merged_comm == MPI_COMM_NULL)
	{
		fprintf(stderr, "[comm_mpi] spawn_1proc_per_node: MPI_Intercomm_merge failure\n");
		return MPI_COMM_NULL;
	}

	DBGPRN((stderr, "[comm_mpi] spawn_1proc_per_node: spawned %d processes on %d nodes\n", 
	                num_nodes, num_nodes));

	return merged_comm;
}

#endif

/* Worker-side initialization for dynamically spawned processes.
 * Merges the intercommunicator to the parent into an intracommunicator,
 * passing high = 1 so that the primary node ends up with rank 0, and stores
 * it in mpi_info->communicator; exits the process if the merge yields
 * MPI_COMM_NULL. In ROFF_MULTIPLE_WORKERS builds it then receives, from
 * rank 0, the id of the node this worker runs on.
 */
void Comm_Init_worker(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	MPI_Comm parent, merged;

	MPI_Comm_get_parent(&parent);
	MPI_Intercomm_merge(parent, 1, &merged);
	if (merged == MPI_COMM_NULL)
	{
		fprintf(stderr, "[comm_mpi] Comm_Init_worker: MPI_Intercomm_merge failure\n");
		exit(EXIT_FAILURE);
	}
	mpi_info->communicator = merged;

#ifdef ROFF_MULTIPLE_WORKERS
	/* The primary sends every worker its node index right after merging */
	MPI_Recv(&(mpi_info->parent_worker_node_id), 1, MPI_INT, 0, MPI_ANY_TAG,
	         merged, MPI_STATUS_IGNORE);
#endif

	DBGPRN((stderr, "[comm_mpi] Comm_Init_worker: worker %d initialized\n", 
	                mpi_info->parent_worker_node_id));
}

#endif /* ROFF_USE_STATIC_MPI_PROCS */


void Comm_Spawn(void *info, int num_nodes, int num_procs)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;

#if defined(ROFF_MULTIPLE_WORKERS) && !defined(ROFF_USE_STATIC_MPI_PROCS)
	mpi_info->communicator = spawn_nprocs_per_node(mpi_info, num_nodes, num_procs);
#else
	mpi_info->communicator = spawn_1proc_per_node(mpi_info, num_nodes);
#endif
	if (mpi_info->communicator == MPI_COMM_NULL)
		MPI_Abort(MPI_COMM_WORLD, 1);

	DBGPRN((stderr, "[comm_mpi] Comm_Spawn: spawned %d processes on %d nodes\n", 
	                num_procs, num_nodes));
}


/* Blocking send of `size` elements of `type` from buf to rank `dst` of the
 * current communicator, using message tag `tag`.
 * If status is non-NULL, status->Int is set to 1 and status->Str to NULL.
 * Exits the process if get_MPI_type() has no mapping for `type`
 * (returns MPI_DATATYPE_NULL).
 */
void Comm_Send(void *info, int dst, Comm_Datatype type, void *buf, int size, int tag, 
               Comm_Status *status)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	MPI_Datatype mpitype = get_MPI_type(type);

	if (mpitype == MPI_DATATYPE_NULL)
	{
		/* Newline-terminated so the message is not glued to later output */
		fprintf(stderr, "[comm_mpi] Comm_Send: unknown data type; exiting.\n");
		exit(EXIT_FAILURE);
	}

	MPI_Send(buf, size, mpitype, dst, tag, mpi_info->communicator);

	if (status != NULL)
	{
		status->Int = 1;
		status->Str = NULL;
	}

	DBGPRN((stderr, "[comm_mpi] Comm_Send: sent %d bytes to %d\n", size, dst));
}


/* Blocking receive of up to `size` elements of `type` into buf from rank
 * `src` of the current communicator. COMM_ANY_TAG is translated to
 * MPI_ANY_TAG. If status is non-NULL, status->Int receives the tag of the
 * message actually matched and status->Str is set to NULL.
 * Exits the process if get_MPI_type() has no mapping for `type`
 * (returns MPI_DATATYPE_NULL).
 */
void Comm_Recv(void *info, int src, Comm_Datatype type, void *buf, int size, int tag, 
               Comm_Status *status)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	MPI_Status st;
	MPI_Datatype mpitype = get_MPI_type(type);

	if (tag == COMM_ANY_TAG)
		tag = MPI_ANY_TAG;
	
	if (mpitype == MPI_DATATYPE_NULL)
	{
		/* Newline-terminated so the message is not glued to later output */
		fprintf(stderr, "[comm_mpi] Comm_Recv: unknown data type; exiting.\n");
		exit(EXIT_FAILURE);
	}

	if (status == NULL)
		MPI_Recv(buf, size, mpitype, src, tag, mpi_info->communicator, MPI_STATUS_IGNORE);
	else
	{
		MPI_Recv(buf, size, mpitype, src, tag, mpi_info->communicator, &st);
		status->Int = st.MPI_TAG;
		status->Str = NULL;
	}

	DBGPRN((stderr, "[comm_mpi] Comm_Recv: received %d bytes from %d\n", size, src));
}


/* Blocks until a message from rank `src` with tag `tag` (COMM_ANY_TAG maps
 * to MPI_ANY_TAG) is available on the current communicator, without
 * receiving it. If status is non-NULL, status->Int gets the matched tag and
 * status->Str is set to NULL.
 * Returns the MPI_Probe() return code.
 */
int Comm_Probe(void *info, int src, int tag, Comm_Status *status)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int mpi_tag = (tag == COMM_ANY_TAG) ? MPI_ANY_TAG : tag;
	MPI_Status probe_status;
	int rc;

	if (status == NULL)
		return MPI_Probe(src, mpi_tag, mpi_info->communicator, MPI_STATUS_IGNORE);

	rc = MPI_Probe(src, mpi_tag, mpi_info->communicator, &probe_status);
	status->Int = probe_status.MPI_TAG;
	status->Str = NULL;
	return rc;
}


int Comm_Get_node_id(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	return mpi_info->parent_worker_node_id;
}


/* Returns a human-readable "<pid> @ <processor name>" string for debug
 * messages. The argument is unused (so NULL is acceptable).
 * The result lives in a static buffer: it is overwritten by the next call
 * and the function is not safe to call concurrently from several threads.
 * Must be called only while MPI is initialized (uses MPI_Get_processor_name).
 */
char *Comm_Get_info(void *info)
{
	static char name[MPI_MAX_PROCESSOR_NAME], infostring[MPI_MAX_PROCESSOR_NAME+64];
	int name_len;

	(void) info;
	MPI_Get_processor_name(name, &name_len);
	/* pid_t need not be int: cast explicitly to match %d; bound the write */
	snprintf(infostring, sizeof(infostring), "%d @ %s", (int) getpid(), name);

	return infostring;
}


int Comm_Get_id(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int rank;

	MPI_Comm_rank(mpi_info->communicator, &rank);
	return rank;
}

void Comm_Barrier(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	MPI_Barrier(mpi_info->communicator);
}

#endif /* OMPI_REMOTE_OFFLOADING */