/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* This is the MPI comm library for remotedev
 */

// #define DBGPRN_FORCE 
// #define DBGPRN_BLOCK
#include "config.h"

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "../sysdeps.h"
#include "../../../rt_common.h"
#include "comm.h"
#include <mpi.h>

#ifdef OMPI_REMOTE_OFFLOADING

#include "rdev_prive.h"
#include "rdev.h"
#include "rdev_config.h"

/* 
 * MPI communication layer 
 */

#ifdef USE_STATIC_MPI_PROCS

void *Comm_Init(int *argc, char ***argv)
{
	MPI_info_t *mpi_info = malloc(sizeof(MPI_info_t));
	int rank;

	MPI_Initialized(&mpi_info->initialized);
	if (mpi_info->initialized)
	{
		DBGPRN((stderr, "[comm_mpi] MPI was already initialized @ %s\n", 
			     Comm_Get_info(mpi_info)));
		return NULL;
	}

	mpi_info->argc = argc;
	mpi_info->argv = argv;

	MPI_Init(argc, argv);
	MPI_Comm_rank(MPI_COMM_WORLD, &rank);
	mpi_info->role = node_role = (rank != 0); /* 0 is primary, every other process is a worker */

	DBGPRN((stderr, "[comm_mpi] MPI initialized @ %s\n", Comm_Get_info(mpi_info)));

	return (void*) mpi_info;
}


/*
 * Comm_Finalize (static MPI processes):
 * Shuts MPI down. The handle passed in is not consulted in this variant.
 */
void Comm_Finalize(void *info)
{
	(void) info;   /* intentionally unused */

	MPI_Finalize();
}


/*
 * spawn_threads (static MPI processes):
 * Nothing is actually spawned in this variant; the primary node broadcasts
 * WAKE_UP_MESSAGE over MPI_COMM_WORLD and that communicator is returned.
 * Both parameters are unused here.
 */
static MPI_Comm spawn_threads(MPI_info_t *mpi_info, int num_threads)
{
	int wakeup_msg = WAKE_UP_MESSAGE;

	(void) mpi_info;
	(void) num_threads;

	/* Wake up message */
	MPI_Bcast(&wakeup_msg, 1, MPI_INT, PRIMARY_NODE, MPI_COMM_WORLD);

	return MPI_COMM_WORLD;
}


void Comm_Init_worker(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int buf;
	MPI_Comm parent_comm, merged_comm;
	MPI_Comm_get_parent(&parent_comm); 

	/* All other nodes (workers) come here; they will have to block waiting
	 * for commands from the primary node instead of executing user's code.
	 */
	 
	MPI_Bcast(&buf, 1, MPI_INT, PRIMARY_NODE, MPI_COMM_WORLD);
	if (buf != WAKE_UP_MESSAGE)
	{
		fprintf(stderr, "[remotedev worker (%d)]: illegal broadcasted message.\n",
		                getpid());
		exit(EXIT_FAILURE);
	}
	mpi_info->communicator = MPI_COMM_WORLD;
}


#else /* Dynamic processes */


void *Comm_Init(int *argc, char ***argv)
{
	MPI_info_t *mpi_info = malloc(sizeof(MPI_info_t));
	MPI_Comm parent_comm;
	int rank;

	MPI_Initialized(&mpi_info->initialized);
	if (mpi_info->initialized)
	{
		DBGPRN((stderr, "[comm_mpi] MPI was already initialized @ %s\n", 
			     Comm_Get_info(mpi_info)));
		return NULL;
	}

	mpi_info->argc = argc;
	mpi_info->argv = argv;

	MPI_Init_thread(argc, argv, MPI_THREAD_MULTIPLE, &mpi_info->provided);
	if (mpi_info->provided != MPI_THREAD_MULTIPLE)
	{
		fprintf(stderr, "Your MPI library was not configured with "
						"MPI_THREAD_MULTIPLE support. Aborting...\n");
		MPI_Abort(MPI_COMM_WORLD, 1);
	}

	MPI_Comm_get_parent(&parent_comm);
	mpi_info->role = node_role = (parent_comm != MPI_COMM_NULL); /* 0 for primary, 1 for worker */

	DBGPRN((stderr, "[comm_mpi] MPI initialized @ %s\n", Comm_Get_info(mpi_info)));
	
	return (void*) mpi_info;
}


void Comm_Finalize(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;

	MPI_Comm_free(&mpi_info->communicator);
	MPI_Finalize();
}


/*
 * spawn_threads (dynamic MPI processes):
 * Spawns num_threads copies of the running executable (argv[0], with the
 * remaining program arguments) on the hosts named by
 * rdev_config_get_node_names(), then merges the resulting intercommunicator
 * into a single intracommunicator.
 *
 * mpi_info:    handle holding the stored argv pointer.
 * num_threads: number of processes to spawn.
 *
 * Returns the merged communicator, or MPI_COMM_NULL if any step failed.
 * The host list string and the MPI_Info object are released on every path.
 */
static MPI_Comm spawn_threads(MPI_info_t *mpi_info, int num_threads)
{
	MPI_Info info = MPI_INFO_NULL;
	MPI_Comm spawned_comm, merged_comm = MPI_COMM_NULL;
	char *all_nodes_string = rdev_config_get_node_names();

	if (all_nodes_string == NULL)
	{
		fprintf(stderr, "[comm_mpi] spawn_threads: no node names available\n");
		return MPI_COMM_NULL;
	}

	if (MPI_Info_create(&info) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_threads: MPI_Info_create failure\n");
		info = MPI_INFO_NULL;   /* nothing to free */
		goto cleanup;
	}
	
	if (MPI_Info_set(info, "host", all_nodes_string) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_threads: MPI_Info_set failure\n");
		goto cleanup;
	}

	if (MPI_Comm_spawn(**(mpi_info->argv), *(mpi_info->argv) + 1, num_threads, info, 0,
						 MPI_COMM_WORLD, &spawned_comm, MPI_ERRCODES_IGNORE) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_threads: MPI_Comm_spawn failure\n");
		goto cleanup;
	}

	/* Primary node sets high = 0 (second arg) so the process @ primary node gets 
	 * rank = 0 in the merged communicator. 
	 */
	if (MPI_Intercomm_merge(spawned_comm, 0, &merged_comm) != MPI_SUCCESS)
	{
		fprintf(stderr, "[comm_mpi] spawn_threads: MPI_Intercomm_merge failure\n");
		merged_comm = MPI_COMM_NULL;
	}

cleanup:
	if (info != MPI_INFO_NULL)
		MPI_Info_free(&info);
	free(all_nodes_string);

	return merged_comm;
}


void Comm_Init_worker(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int buf;
	MPI_Comm parent_comm, merged_comm;
	MPI_Comm_get_parent(&parent_comm); 

	/* Workers set high = 1 (second arg) so the primary node gets rank = 0 in
	 * the merged communicator. 
	 */
	MPI_Intercomm_merge(parent_comm, 1, &merged_comm);
	mpi_info->communicator = merged_comm;
}

#endif /* USE_STATIC_MPI_PROCS */


void Comm_Spawn(void *info, int num_threads)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;

	mpi_info->communicator = spawn_threads(mpi_info, num_threads);
	if (mpi_info->communicator == MPI_COMM_NULL)
		MPI_Abort(MPI_COMM_WORLD, 1);
}


/*
 * Comm_Send:
 * Sends size elements of the given type from buf to process dst, using the
 * communicator stored in the handle and the given tag.
 *
 * If status is not NULL, its info_num is set to 1 and its info_str to NULL
 * after the send. The process exits if the type has no MPI equivalent.
 */
void Comm_Send(void *info, int dst, Comm_Datatype type, void *buf, int size, int tag, 
               Comm_Status *status)
{
	MPI_info_t *ctx = (MPI_info_t*) info;
	MPI_Datatype elem_type = get_MPI_type(type);

	if (elem_type != MPI_DATATYPE_NULL)
	{
		MPI_Send(buf, size, elem_type, dst, tag, ctx->communicator);

		if (status)
		{
			status->info_num = 1;
			status->info_str = NULL;
		}
		return;
	}

	/* No MPI type corresponds to the requested one */
	fprintf(stderr, "[comm_mpi] Comm_Send: unknown data type; exiting.");
	exit(1);
}


/*
 * Comm_Recv:
 * Receives size elements of the given type into buf from process src, using
 * the communicator stored in the handle and the given tag.
 *
 * If status is not NULL, its info_num is set to the tag of the received
 * message and its info_str to NULL. The process exits if the type has no
 * MPI equivalent.
 */
void Comm_Recv(void *info, int src, Comm_Datatype type, void *buf, int size, int tag, 
               Comm_Status *status)
{
	MPI_info_t *ctx = (MPI_info_t*) info;
	MPI_Datatype elem_type = get_MPI_type(type);
	MPI_Status recv_status;

	if (elem_type != MPI_DATATYPE_NULL)
	{
		MPI_Recv(buf, size, elem_type, src, tag, ctx->communicator, &recv_status);

		if (status)
		{
			status->info_num = recv_status.MPI_TAG;
			status->info_str = NULL;
		}
		return;
	}

	/* No MPI type corresponds to the requested one */
	fprintf(stderr, "[comm_mpi] Comm_Recv: unknown data type; exiting.");
	exit(1);
}


/*
 * Comm_Get_info:
 * Builds a short description of the calling process, consisting of the MPI
 * processor name and the process id.
 *
 * info: unused.
 *
 * Returns a pointer to a static buffer, which is overwritten by the next
 * call; the caller must not free it.
 */
char *Comm_Get_info(void *info)
{
	static char infostring[MPI_MAX_PROCESSOR_NAME+64];
	char name[MPI_MAX_PROCESSOR_NAME] = "";
	int name_len = 0;

	(void) info;   /* intentionally unused */

	MPI_Get_processor_name(name, &name_len);
	/* Bounded formatting so the static buffer can never be overrun */
	snprintf(infostring, sizeof infostring, "node %s (pid = %d)", name,
	         (int) getpid());

	return infostring;
}


int Comm_Get_id(void *info)
{
	MPI_info_t *mpi_info = (MPI_info_t*) info;
	int rank;

	MPI_Comm_rank(mpi_info->communicator, &rank);

	return rank;
}

#endif /* OMPI_REMOTE_OFFLOADING */