/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* reduction.cu
 * This file implements reduction
 */


#include "globals.h"
#include <cuda.h>

/* On a P40 GPU, and a particular test with 64 blocks, launching the kernel 
 * with cooperative groups and the grid sync implementation below, 
 * it takes 375usec per reduction.
 * If the kernel is launched normally, and blocks synchronize through critical
 * regions below, we need 425usec per reduction.
 * However, the first approach has an upper limit on the number of blocks
 * (128 blocks would not work).
 * We keep the second approach.
 */
#if 0 /* __CUDA_ARCH__ >= 600 && CUDART_VERSION >= 9000 */
	/* Disabled variant: grid-wide barrier through cooperative groups.
	 * DECLARE_GRID_GROUP() declares a local named `grid`; GRIDSYNC(g) syncs it.
	 */
	#include <cooperative_groups.h>
	namespace cg = cooperative_groups;
	#define DECLARE_GRID_GROUP() cg::grid_group grid = cg::this_grid();
	#define GRIDSYNC(g) cg::sync(g);
#else
	/* Active variant: DECLARE_GRID_GROUP() expands to nothing. */
	#define DECLARE_GRID_GROUP()
	/* Device-global lock word used by DOCRITICAL: 0 = free, 1 = held. */
	__device__ int _teamredlock = 0;
	/* DOCRITICAL(code): run `code` while holding _teamredlock.
	 * Spins until atomicCAS flips the lock from 0 to 1, executes `code`
	 * bracketed by __threadfence() calls, then releases the lock with
	 * atomicExch and leaves the loop through the __cuda_spin flag.
	 * In this file it is only invoked by the thread with tid == 0 of a team.
	 * NOTE(review): the flag-controlled loop shape looks deliberate (spin locks
	 * are sensitive to control-flow shape on SIMT hardware) -- keep it as is
	 * unless verified on all target architectures.
	 */
	#define DOCRITICAL(code) \
	{\
		int __cuda_spin = 1;\
		while (__cuda_spin)\
			if (!atomicCAS(&_teamredlock, 0, 1)) {\
				__threadfence();\
				code\
				__threadfence();\
				atomicExch(&_teamredlock, 0);\
				__cuda_spin = 0;\
			}\
	}
#endif


/* One scratch slot able to hold a value of any type supported by the
 * reduction functions. The member names double as the SUFFIX argument of the
 * reduction macros, so they must not be renamed.
 */
union DataTypes {
	/* signed integers */
	int ___i;
	short __si;
	long __li;
	long long __Li;
	/* unsigned integers */
	unsigned _u_i;
	unsigned short _usi;
	unsigned long _uli;
	unsigned long long _uLi;
	/* characters and floating point */
	char ___c;
	double ___d;
	float ___f;
	unsigned char _u_c;
};

/* Number of slots in the per-team scratch array below.
 * NOTE(review): the reduction code indexes teamRes by thread id without a
 * bounds check, so this presumably must be >= the largest team size used --
 * confirm against the launch configuration.
 */
#define REDUCTION_MAX_THRPERTEAM 1024
/* Shared-memory scratch: slot tid receives thread tid's private value and the
 * team's combined result is accumulated in slot 0.
 */
__shared__ union DataTypes teamRes[REDUCTION_MAX_THRPERTEAM];


/* Returns the largest power of two that does not exceed nthr (for nthr >= 1).
 * An input of 0 yields 1. Callers in this file only pass values >= 2.
 */
__device__ int get_prevPow2(int nthr)
{
	int pow2;

	/* Double pow2 once for every bit position above the lowest one */
	for (pow2 = 1; (nthr >>= 1) != 0; pow2 <<= 1)
		;
	return pow2;
}


/* Team-local (thread block) barrier used between reduction steps */
#define sync_local()  __syncthreads()

#if 0 /* __CUDA_ARCH__ >= 600  && && CUDART_VERSION >= 9000 */

/* Grid-wide barrier; relies on the `grid` local from DECLARE_GRID_GROUP() */
#define sync_global() GRIDSYNC(grid)

/* Number of slots in blockRes, which is indexed by team id */
#define REDUCTION_MAX_TEAMS 65536
/* Per-team partial results, stored in device memory for the second stage */
__device__ union DataTypes blockRes[REDUCTION_MAX_TEAMS];

/* This is identical to the OpenCL one */
/* Disabled (#if 0) grid-sync variant of reduction_function_body.
 * It expands inside define_reduction_function and uses that function's
 * locals (i, j, nthr, tid, teamid, teamsize, numteams, prev_pow2) and
 * parameters (priv_d, global_d, nelems).
 * For every element j:
 *   1. each team combines its threads' values in teamRes into teamRes[0];
 *   2. thread 0 of each team stores that value in blockRes[teamid], or
 *      updates global_d[j] directly when there is only one team;
 *   3. after sync_global(), team 0 combines the blockRes values and
 *      updates global_d[j].
 *
 * NOTE(review) -- issues to resolve before enabling this variant:
 *   - In the "combine excess elements" step nthr still holds what the first
 *     while loop left behind (0 or 1), not the team size, so the strided
 *     loop does not advance when nthr is 0; teamsize looks intended.
 *   - The final global_update is executed by every thread of team 0, not
 *     only by tid 0.
 *   - The shift `tid > prev_pow2` is not limited to tid < nthr, so threads
 *     beyond the live range write into slots that live threads read in the
 *     same step.
 *   - blockRes is written again for element j+1 with no second grid-wide
 *     barrier after team 0 has read it for element j.
 */
#define reduction_function_body(SUFFIX, TYPE, OP)\
	for (j=0; j < nelems; j++) /* Reduce elements one by one */\
	{\
		nthr = teamsize;\
		teamRes[tid].SUFFIX = priv_d[j];\
		sync_local();\
		while (!(nthr==1 || nthr==0))\
		{\
			prev_pow2 = get_prevPow2(nthr);\
			for (i = prev_pow2 / 2; i > 0; i >>= 1)\
			{\
				if (tid < i) {\
					local_update(teamRes[tid], teamRes[tid+i], OP, SUFFIX)\
				} /* if */\
				sync_local(); \
			} /* for */\
			\
			if (tid == prev_pow2 && nthr != prev_pow2) {\
				local_update(teamRes[0], teamRes[tid], OP, SUFFIX)\
			}\
			sync_local();\
			if (tid > prev_pow2 && nthr != prev_pow2)\
				teamRes[tid-prev_pow2].SUFFIX = teamRes[tid].SUFFIX;\
			sync_local();\
			nthr = nthr - prev_pow2;\
		} /* while */\
		\
		sync_local();\
		if (tid == 0)\
		{\
			if (numteams == 1) {\
				global_update(OP, SUFFIX)\
			} else\
				blockRes[teamid].SUFFIX = teamRes[0].SUFFIX;/* store workgroup result */\
		}\
		\
		sync_global(); /* Sync all teams */\
		if ((numteams > 1) && (teamid == 0))\
		{\
			/* Now transfer blockRes to teamRes[0] and do reductions within team 0 */\
			/* If #thr < #blocks, then combine excess elements before final red */\
			if (tid < numteams)\
				teamRes[tid].SUFFIX = blockRes[tid].SUFFIX;\
			if (numteams > nthr) /* combine excess elements */\
				for (i = nthr; tid + i < numteams; i += nthr)\
					local_update(teamRes[tid], blockRes[tid+i], OP, SUFFIX)\
			sync_local();\
			nthr = numteams;\
			while (!(nthr==1 || nthr==0))\
			{\
				prev_pow2 = get_prevPow2(nthr);\
				for (i = prev_pow2 / 2; i > 0; i >>= 1)\
				{\
					if (tid < i) {\
						local_update(teamRes[tid], teamRes[tid+i], OP, SUFFIX)\
					} /* if */\
					sync_local();\
				} /* for */\
				\
				if (tid == prev_pow2 && nthr != prev_pow2) {\
					local_update(teamRes[0], teamRes[tid], OP, SUFFIX)\
				}\
				sync_local();\
				if (tid > prev_pow2 && nthr != prev_pow2)\
					teamRes[tid-prev_pow2].SUFFIX = teamRes[tid].SUFFIX;\
				sync_local();\
				nthr = nthr - prev_pow2;\
			} /* while */\
			global_update(OP, SUFFIX)\
		} /* if */\
	} /* for */

#else  /* __CUDA_ARCH__ < 600*/

/* reduction_function_body(SUFFIX, TYPE, OP)
 * Body of every reduce<op><suffix>() function. It expands inside
 * define_reduction_function and uses that function's locals (i, j, nthr, tid,
 * teamsize, numteams, prev_pow2) and parameters (priv_d, global_d, nelems).
 *   SUFFIX  member of union DataTypes that holds the value
 *   TYPE    element type (unused here, kept for a uniform macro signature)
 *   OP      operator token handed to local_update / global_update
 *
 * For every element j:
 *   1. Each thread stores its private value in teamRes[tid].
 *   2. The team's values are combined into teamRes[0]. Team sizes that are
 *      not a power of two are handled in passes: the first prev_pow2 live
 *      slots are tree-reduced into slot 0, slot prev_pow2 is folded into
 *      slot 0, the remaining live slots are shifted down to 1.., and the
 *      pass repeats on the nthr - prev_pow2 slots that are left.
 *   3. Thread 0 folds teamRes[0] into global_d[j]; with more than one team
 *      this happens inside DOCRITICAL so that teams update one at a time.
 * All threads of the team must reach this code together, since it contains
 * sync_local() barriers.
 */
#define reduction_function_body(SUFFIX, TYPE, OP)\
	for (j=0; j < nelems; j++) /* Reduce elements one by one */\
	{\
		nthr = teamsize;\
		teamRes[tid].SUFFIX = priv_d[j];\
		sync_local();\
		while (!(nthr==1 || nthr==0))\
		{\
			prev_pow2 = get_prevPow2(nthr);\
			/* Tree-reduce live slots [0, prev_pow2) into slot 0 */\
			for (i = prev_pow2 / 2; i > 0; i >>= 1)\
			{\
				if (tid < i) {\
					local_update(teamRes[tid], teamRes[tid+i], OP, SUFFIX)\
				} /* if */\
				sync_local(); \
			} /* for */\
			\
			/* Fold the first leftover slot (if any) straight into slot 0 */\
			if (tid == prev_pow2 && nthr != prev_pow2) {\
				local_update(teamRes[0], teamRes[tid], OP, SUFFIX)\
			}\
			sync_local();\
			/* Shift the other leftover slots down to 1..nthr-prev_pow2-1. */\
			/* Only live slots (tid < nthr) may move: from the second pass on, */\
			/* threads with tid >= nthr would otherwise overwrite slots that */\
			/* live threads are reading in this very step (a data race). */\
			if (tid > prev_pow2 && tid < nthr)\
				teamRes[tid-prev_pow2].SUFFIX = teamRes[tid].SUFFIX;\
			sync_local();\
			nthr = nthr - prev_pow2;\
		} /* while */\
		\
		if (tid == 0)\
		{\
			if (numteams == 1) {\
				global_update(OP, SUFFIX) \
			} else\
				DOCRITICAL( global_update(OP, SUFFIX) )\
		}\
	} /* for */

#endif


/* define_reduction_function(OPNAME, SUFFIX, TYPE, COP)
 * Emits the device function reduce<OPNAME><SUFFIX>(priv_d, global_d, nelems),
 * which folds the nelems private values in priv_d into global_d using the
 * operator token COP. The local names below are referenced by
 * reduction_function_body and therefore must keep their spelling.
 * __THRID, __BLOCKID, __NTHR and __NBLOCKS are not defined in this file
 * (presumably they come from globals.h).
 */
#define define_reduction_function(OPNAME, SUFFIX, TYPE, COP)\
__device__ void reduce##OPNAME##SUFFIX(TYPE *priv_d, TYPE *global_d, int nelems)\
{\
	int i = 0;\
	int j = 0;\
	int nthr;\
	int tid = __THRID;\
	int teamid = __BLOCKID;\
	int teamsize = __NTHR;\
	int numteams = __NBLOCKS;\
	volatile int prev_pow2; /* to avoid optimizing out */\
	DECLARE_GRID_GROUP();\
	\
	reduction_function_body(SUFFIX, TYPE, COP)\
}


/* Updaters for every operator except min/max.
 * local_update combines two scratch slots:   to = to OP from
 * global_update folds the team result (teamRes[0]) into global_d[j].
 */
#define local_update(to, from, OP, SUF)\
	(to).SUF = (to).SUF OP (from).SUF;

#define global_update(OP, SUFFIX)\
			global_d[j] = global_d[j] OP (teamRes[0]).SUFFIX;


/* Helper: emit reduce<OPNAME><suffix>() for all twelve supported types. */
#define define_all_type_reductions(OPNAME, COP)\
	define_reduction_function(OPNAME, ___i, int, COP)\
	define_reduction_function(OPNAME, __si, short int, COP)\
	define_reduction_function(OPNAME, __li, long int, COP)\
	define_reduction_function(OPNAME, __Li, long long int, COP)\
	define_reduction_function(OPNAME, _u_i, unsigned int, COP)\
	define_reduction_function(OPNAME, _usi, unsigned short int, COP)\
	define_reduction_function(OPNAME, _uli, unsigned long int, COP)\
	define_reduction_function(OPNAME, _uLi, unsigned long long int, COP)\
	define_reduction_function(OPNAME, ___c, char, COP)\
	define_reduction_function(OPNAME, ___d, double, COP)\
	define_reduction_function(OPNAME, ___f, float, COP)\
	define_reduction_function(OPNAME, _u_c, unsigned char, COP)

/* Helper: integer and character types only (for the bitwise operators). */
#define define_integral_type_reductions(OPNAME, COP)\
	define_reduction_function(OPNAME, ___i, int, COP)\
	define_reduction_function(OPNAME, __si, short int, COP)\
	define_reduction_function(OPNAME, __li, long int, COP)\
	define_reduction_function(OPNAME, __Li, long long int, COP)\
	define_reduction_function(OPNAME, _u_i, unsigned int, COP)\
	define_reduction_function(OPNAME, _usi, unsigned short int, COP)\
	define_reduction_function(OPNAME, _uli, unsigned long int, COP)\
	define_reduction_function(OPNAME, _uLi, unsigned long long int, COP)\
	define_reduction_function(OPNAME, ___c, char, COP)\
	define_reduction_function(OPNAME, _u_c, unsigned char, COP)

// add (+)
define_all_type_reductions(_add, + )

// subtract (-): partial results are combined with +, exactly like add
define_all_type_reductions(_subtract, + )

// multiply (*)
define_all_type_reductions(_multiply, * )

// bitwise AND (&)
define_integral_type_reductions(_bitand, & )

// bitwise OR (|)
define_integral_type_reductions(_bitor, | )

// bitwise XOR (^)
define_integral_type_reductions(_bitxor, ^ )

// logical AND (&&)
define_all_type_reductions(_and, && )

// logical OR (||)
define_all_type_reductions(_or, || )



#undef local_update
#undef global_update

/* Updaters for min/max: OP is the comparison (> for max, < for min).
 * The destination is overwritten only when the incoming value wins the
 * comparison.
 */
#define local_update(to, from, OP, SUF)\
	if ((from).SUF OP (to).SUF) {\
		(to).SUF = (from).SUF;\
	}

#define global_update(OP, SUFFIX)\
			if ((teamRes[0]).SUFFIX OP global_d[j]) {\
				global_d[j] = (teamRes[0]).SUFFIX;\
			}

/* Helper: emit the min/max reduction function for all twelve supported types;
 * COP is the comparison handed to the min/max updaters.
 */
#define define_minmax_type_reductions(OPNAME, COP)\
	define_reduction_function(OPNAME, ___i, int, COP)\
	define_reduction_function(OPNAME, __si, short int, COP)\
	define_reduction_function(OPNAME, __li, long int, COP)\
	define_reduction_function(OPNAME, __Li, long long int, COP)\
	define_reduction_function(OPNAME, _u_i, unsigned int, COP)\
	define_reduction_function(OPNAME, _usi, unsigned short int, COP)\
	define_reduction_function(OPNAME, _uli, unsigned long int, COP)\
	define_reduction_function(OPNAME, _uLi, unsigned long long int, COP)\
	define_reduction_function(OPNAME, ___c, char, COP)\
	define_reduction_function(OPNAME, ___d, double, COP)\
	define_reduction_function(OPNAME, ___f, float, COP)\
	define_reduction_function(OPNAME, _u_c, unsigned char, COP)

// max
define_minmax_type_reductions(_max, > )

// min
define_minmax_type_reductions(_min, < )

//JUMP TABLES

/* Generic signature through which every reduce<op><suffix>() function is
 * called: (private values, global values, number of elements). The functions
 * themselves are defined with typed pointers and are cast to this type when
 * stored in the tables.
 */
typedef void (*redfunc_t)(void *, void *, int);

/* REDFUNC_FULL_JUMP_TABLE(op): defines the device array <op>_jump, indexed by
 * a type code, for operators available on all types. Slot order: int, short,
 * long, long long, unsigned int, unsigned short, unsigned long,
 * unsigned long long, char, double, float, long double, unsigned char.
 * The NULL below is for long double which is unsupported.
 * NOTE(review): several operator names used with this macro (and, or) are
 * C++ alternative tokens; the macro relies on them being pasted by spelling.
 */
#define REDFUNC_FULL_JUMP_TABLE(op) \
__device__ static redfunc_t op ## _jump[] = { \
	(redfunc_t) reduce_##op##___i, \
	(redfunc_t) reduce_##op##__si, \
	(redfunc_t) reduce_##op##__li, \
	(redfunc_t) reduce_##op##__Li, \
	(redfunc_t) reduce_##op##_u_i, \
	(redfunc_t) reduce_##op##_usi, \
	(redfunc_t) reduce_##op##_uli, \
	(redfunc_t) reduce_##op##_uLi, \
	(redfunc_t) reduce_##op##___c, \
	(redfunc_t) reduce_##op##___d, \
	(redfunc_t) reduce_##op##___f, \
	(redfunc_t) NULL, \
	(redfunc_t) reduce_##op##_u_c  \
};

/* Tables for the operators that have functions for every type */
REDFUNC_FULL_JUMP_TABLE(add);
REDFUNC_FULL_JUMP_TABLE(subtract);
REDFUNC_FULL_JUMP_TABLE(multiply);
REDFUNC_FULL_JUMP_TABLE(and);
REDFUNC_FULL_JUMP_TABLE(or);
REDFUNC_FULL_JUMP_TABLE(max);
REDFUNC_FULL_JUMP_TABLE(min);

/* REDFUNC_INT_JUMP_TABLE(op): same layout as the full table, but for the
 * operators that only have integer/character functions; the double, float
 * and long double slots are NULL.
 */
#define REDFUNC_INT_JUMP_TABLE(op)\
__device__ static redfunc_t op ## _jump[] = { \
	(redfunc_t) reduce_##op##___i, \
	(redfunc_t) reduce_##op##__si, \
	(redfunc_t) reduce_##op##__li, \
	(redfunc_t) reduce_##op##__Li, \
	(redfunc_t) reduce_##op##_u_i, \
	(redfunc_t) reduce_##op##_usi, \
	(redfunc_t) reduce_##op##_uli, \
	(redfunc_t) reduce_##op##_uLi, \
	(redfunc_t) reduce_##op##___c, \
	(redfunc_t) NULL, \
	(redfunc_t) NULL, \
	(redfunc_t) NULL, \
	(redfunc_t) reduce_##op##_u_c \
}; 

/* Tables for the bitwise operators */
REDFUNC_INT_JUMP_TABLE(bitand);
REDFUNC_INT_JUMP_TABLE(bitor);
REDFUNC_INT_JUMP_TABLE(bitxor);

// THE INTERFACE

/* EXPORTED_REDUCTION_FUNCTION(op): defines the device entry point
 *   _ort_reduce_<op>(type, local, global, nelems)
 * which looks up slot `type` of <op>_jump and calls it with the remaining
 * arguments. `type` is used as an unchecked index, and NULL slots
 * (unsupported types) are not tested before the call.
 */
#define EXPORTED_REDUCTION_FUNCTION(op) \
__device__ void _ort_reduce_##op(int type, void *local, void *global, int nelems)\
{\
	redfunc_t redfn = op##_jump[type];\
	\
	redfn(local, global, nelems);\
}

/* Operators available for all types */
EXPORTED_REDUCTION_FUNCTION(add)
EXPORTED_REDUCTION_FUNCTION(subtract)
EXPORTED_REDUCTION_FUNCTION(multiply)
EXPORTED_REDUCTION_FUNCTION(and)
EXPORTED_REDUCTION_FUNCTION(or)
EXPORTED_REDUCTION_FUNCTION(max)
EXPORTED_REDUCTION_FUNCTION(min)
/* Bitwise operators (integer/character types only) */
EXPORTED_REDUCTION_FUNCTION(bitand)
EXPORTED_REDUCTION_FUNCTION(bitor)
EXPORTED_REDUCTION_FUNCTION(bitxor)
