#include "globals.h"
#include <cuda.h>
#include <stdarg.h>
#include <stdlib.h>
#include <stdio.h>
#if __CUDA_ARCH__ >= 600
#include <cooperative_groups.h>
#endif

/*
 * Scratch cell that can hold one value of any arithmetic type handled by the
 * reduction routines in this file. Each member name is also used as the
 * SUFFIX argument of the reduction-function generator macros, which access
 * the cell as `partialRes[...].SUFFIX`.
 */
union DataTypes {
	int ___i;                 /* int */
	short __si;               /* short int */
	long __li;                /* long int */
	long long __Li;           /* long long int */
	unsigned _u_i;            /* unsigned int */
	unsigned short _usi;      /* unsigned short int */
	unsigned long _uli;       /* unsigned long int */
	unsigned long long _uLi;  /* unsigned long long int */
	char ___c;                /* char */
	double ___d;              /* double */
	float ___f;               /* float */
	unsigned char _u_c;       /* unsigned char */
};

#if __CUDA_ARCH__ >= 600
namespace cg = cooperative_groups;
#endif

/*
 * Scratch buffer for the tree reductions generated below. It is indexed by
 * thread id in the intra-block stage and by block id in the cross-block
 * stage, so neither the block size nor the block count may exceed 1024.
 * NOTE(review): __SHAREDQLFR presumably expands to __shared__ -- confirm in globals.h.
 */
__SHAREDQLFR union DataTypes partialRes[1024];
/*
 * One partial result per block: written by thread 0 of each block when more
 * than one block is running, then read back by block 0. It has only 30
 * entries, so a launch with more than 30 blocks indexes past the end.
 */
__DEVQLFR union DataTypes blockRes[30];

/*
 * Returns the largest power of two that does not exceed nthr
 * (for example 1000 -> 512 and 512 -> 512). Yields 1 when nthr is 0 or 1.
 */
__DEVQLFR int get_prevPow2(int nthr)
{
	int pow2 = 1;

	/* Each bit that can still be shifted out of nthr doubles the result. */
	for(nthr >>= 1; nthr != 0; nthr >>= 1)
		pow2 <<= 1;

	return pow2;
}

/*
 * Grid-wide synchronisation helpers used inside the generated reduction
 * functions.
 *
 * When __CUDA_ARCH__ >= 600 they map onto cooperative groups:
 * DECLARE_GRID_GROUP() declares a local `grid` group and GRIDSYNC(g)
 * synchronises on it. A grid-wide barrier is only valid for cooperative
 * launches. Otherwise both macros expand to nothing, which means there is
 * NO barrier between the blocks publishing their partial results and block 0
 * consuming them.
 *
 * Both non-empty definitions already end in ';', so an invocation followed by
 * ';' produces a harmless empty statement.
 */
#if __CUDA_ARCH__ >= 600
	#define DECLARE_GRID_GROUP() cg::grid_group grid = cg::this_grid();
	#define GRIDSYNC(g) cg::sync(g);
#else
	#define DECLARE_GRID_GROUP()
	#define GRIDSYNC(g) 
#endif

/*
 * Opening fragment of a generated reduction function
 *
 *     void reduce<COPERATOR><SUFFIX>(TYPE *local_d, TYPE *global_d, int nelems)
 *
 * local_d  : the calling thread's values, one per reduced element
 * global_d : receives the reduced value of every element
 * nelems   : number of elements to reduce
 *
 * For each element the thread stores its value in partialRes[tid] and the
 * block starts a tree reduction over its threads. The fragment deliberately
 * stops inside `if(tid < i) {` of the first tree step: the generator macro
 * supplies the combining statement and closes the open scopes before
 * continuing with reduction_function_after_block.
 *
 * Every thread of the block must call the function with the same nelems,
 * because the body contains block-wide barriers.
 */
#define reduction_function_prologue(COPERATOR, SUFFIX, TYPE)\
__DEVQLFR void reduce##COPERATOR##SUFFIX(TYPE *local_d, TYPE *global_d, int nelems)\
{\
	int i = 0, j = 0;\
	int tid = __THRID,\
	    block_id = __BLOCKID,\
	    nthr = __NTHR,\
	    nblocks = __NBLOCKS,\
	    prev_pow2;\
	\
	DECLARE_GRID_GROUP();\
	\
	for(j=0; j<nelems; j++)\
	{\
		/* Wait until every thread has finished reading the scratch buffer for */\
		/* the previous element before it is overwritten. */\
		__syncthreads();\
		/* The reduction of the previous element consumed nthr down to 0 or 1; */\
		/* restore the real block size so this element is reduced over all threads. */\
		nthr = __NTHR;\
		partialRes[tid].SUFFIX = local_d[j];\
		__syncthreads();\
		while(!(nthr==1 || nthr==0))\
		{\
			prev_pow2 = get_prevPow2(nthr);\
			for(i = prev_pow2 / 2; i > 0; i >>= 1)\
			{\
				if(tid < i)\
				{

/*
 * Middle fragment of a generated reduction function. It is entered after the
 * generator has folded the first prev_pow2 slots into partialRes[0] and slot
 * prev_pow2 into partialRes[0] as well.
 *
 *  1. Compacts the slots that are still unprocessed (prev_pow2+1 .. nthr-1)
 *     down to 1 .. nthr-prev_pow2-1 and repeats the tree reduction on the
 *     remaining nthr-prev_pow2 values until one is left.
 *  2. Thread 0 publishes the block result: straight to global_d[j] when only
 *     one block is running, otherwise to blockRes[block_id].
 *  3. After GRIDSYNC, block 0 loads all block results into partialRes and
 *     starts a second tree reduction over nblocks values. This needs at least
 *     nblocks threads in block 0.
 *
 * The fragment stops inside `if(tid < i) {` of the second tree step; the
 * generator supplies the combining statement again and then continues with
 * reduction_function_epilogue.
 */
#define reduction_function_after_block(COPERATOR, SUFFIX, TYPE)\
			/* Only threads that own a live slot may compact. Without the */\
			/* tid < nthr bound, threads past the live range would overwrite */\
			/* slots that lower threads are reading in this same step. */\
			if((tid > prev_pow2) && (tid < nthr))\
			{\
				partialRes[tid-prev_pow2].SUFFIX = partialRes[tid].SUFFIX;\
			}\
			__syncthreads();\
			nthr = nthr - prev_pow2;\
		}\
		\
		__syncthreads();\
		if(tid == 0)\
		{\
			if(nblocks == 1)\
			{\
				global_d[j] = partialRes[0].SUFFIX;\
			}\
			else\
			{\
				blockRes[block_id].SUFFIX = partialRes[0].SUFFIX;\
			}\
		}\
		\
		GRIDSYNC(grid);\
		\
		if((nblocks > 1) && (block_id == 0))\
		{\
			for(i = 0; i < nblocks; i++)\
				partialRes[i].SUFFIX = blockRes[i].SUFFIX;\
			__syncthreads();\
			nthr = nblocks;\
			while(!(nthr==1 || nthr==0))\
			{\
				prev_pow2 = get_prevPow2(nthr);\
				for(i = prev_pow2 / 2; i > 0; i >>= 1)\
				{\
					if(tid < i)\
					{
			

/*
 * Closing fragment of a generated reduction function. It finishes one pass of
 * the cross-block reduction executed by block 0: the still-unprocessed block
 * results are compacted towards the front of partialRes, the pass is repeated
 * until a single value remains, and that value is stored in global_d[j].
 * It then closes the per-element loop and the function body.
 */
#define reduction_function_epilogue(COPERATOR, SUFFIX, TYPE)\
					/* Only threads that own a live slot may compact. Block 0 */\
					/* usually has far more threads than there are block */\
					/* results, and the surplus threads would otherwise */\
					/* overwrite slots that are read in this same step. */\
					if((tid > prev_pow2) && (tid < nthr))\
					{\
						partialRes[tid-prev_pow2].SUFFIX = partialRes[tid].SUFFIX;\
					}\
					__syncthreads();\
					nthr = nthr - prev_pow2;\
				}\
			global_d[j] = partialRes[0].SUFFIX;\
		}\
	}\
	\
	return;\
}



/*
 * Generates reduce<COPERATOR><SUFFIX>() for a reduction whose combining step
 * can be written as `a = a OPERATOR b` (+, *, &, |, ^, &&, ||).
 *
 * The function body is stitched together from the three skeleton fragments
 * (prologue, after_block, epilogue). Between them this macro supplies, once
 * for the intra-block stage and once for the cross-block stage:
 *   - the tree step: partialRes[tid] absorbs partialRes[tid+i];
 *   - the closing braces of the `if(tid < i)` and `for(i ...)` scopes left
 *     open by the preceding fragment;
 *   - the fold of slot prev_pow2 (the first slot outside the power-of-two
 *     range) into partialRes[0], followed by a barrier.
 * The brace layout must stay exactly in step with the skeleton fragments.
 */
#define define_reduction_function_anyop(COPERATOR, SUFFIX, TYPE, OPERATOR)\
	reduction_function_prologue(COPERATOR, SUFFIX, TYPE)\
					partialRes[tid].SUFFIX = partialRes[tid].SUFFIX OPERATOR partialRes[tid+i].SUFFIX;\
				}\
				__syncthreads();\
                	}\
			\
			if(tid == prev_pow2 && nthr-prev_pow2!=0)\
			{\
				partialRes[0].SUFFIX = partialRes[0].SUFFIX OPERATOR partialRes[tid].SUFFIX;\
			}\
			__syncthreads();\
	reduction_function_after_block(COPERATOR, SUFFIX, TYPE)\
							partialRes[tid].SUFFIX = partialRes[tid].SUFFIX OPERATOR partialRes[tid+i].SUFFIX;\
					}\
					__syncthreads();\
                		}\
				\
				if(tid == prev_pow2 && nthr-prev_pow2!=0)\
				{\
					partialRes[0].SUFFIX = partialRes[0].SUFFIX OPERATOR partialRes[tid].SUFFIX;\
				}\
				__syncthreads();\
	reduction_function_epilogue(COPERATOR, SUFFIX, TYPE)

/*
 * Generates reduce<COPERATOR><SUFFIX>() for min/max reductions. OPERATOR is
 * the comparison that decides when a candidate replaces the current value:
 * `>` for max, `<` for min.
 *
 * The function body is stitched together from the three skeleton fragments
 * (prologue, after_block, epilogue). Between them this macro supplies, once
 * for the intra-block stage and once for the cross-block stage:
 *   - the tree step: partialRes[tid] takes partialRes[tid+i] when the latter
 *     wins the comparison, so the winner travels towards slot 0;
 *   - the closing braces of the `if(tid < i)` and `for(i ...)` scopes left
 *     open by the preceding fragment;
 *   - the fold of slot prev_pow2 (the first slot outside the power-of-two
 *     range) into partialRes[0], done by the thread that owns that slot and
 *     followed by a barrier, exactly as in define_reduction_function_anyop.
 * The brace layout must stay exactly in step with the skeleton fragments.
 */
#define define_reduction_function_minmax(COPERATOR, SUFFIX, TYPE, OPERATOR)\
	reduction_function_prologue(COPERATOR, SUFFIX, TYPE)\
					/* keep the winner in the lower slot */\
					if(partialRes[tid+i].SUFFIX OPERATOR partialRes[tid].SUFFIX)\
						partialRes[tid].SUFFIX = partialRes[tid+i].SUFFIX;\
				}\
				__syncthreads();\
			}\
			\
			/* the thread owning the leftover slot folds it into slot 0 */\
			if((tid == prev_pow2) && (nthr-prev_pow2 != 0))\
			{\
				if(partialRes[tid].SUFFIX OPERATOR partialRes[0].SUFFIX)\
					partialRes[0].SUFFIX = partialRes[tid].SUFFIX;\
			}\
			__syncthreads();\
	reduction_function_after_block(COPERATOR, SUFFIX, TYPE)\
							/* keep the winner in the lower slot */\
							if(partialRes[tid+i].SUFFIX OPERATOR partialRes[tid].SUFFIX)\
								partialRes[tid].SUFFIX = partialRes[tid+i].SUFFIX;\
						}\
						__syncthreads();\
					}\
					\
					/* the thread owning the leftover slot folds it into slot 0 */\
					if((tid == prev_pow2) && (nthr-prev_pow2!=0))\
					{\
						if(partialRes[tid].SUFFIX OPERATOR partialRes[0].SUFFIX)\
							partialRes[0].SUFFIX = partialRes[tid].SUFFIX;\
					}\
					__syncthreads();\
	reduction_function_epilogue(COPERATOR, SUFFIX, TYPE)

// add (+): generates reduce_add<suffix>() for every member type of DataTypes
define_reduction_function_anyop(_add, ___i, int, + )
define_reduction_function_anyop(_add, __si, short int, + )
define_reduction_function_anyop(_add, __li, long int, + )
define_reduction_function_anyop(_add, __Li, long long int, + )
define_reduction_function_anyop(_add, _u_i, unsigned int, + )
define_reduction_function_anyop(_add, _usi, unsigned short int, + )
define_reduction_function_anyop(_add, _uli, unsigned long int, + )
define_reduction_function_anyop(_add, _uLi, unsigned long long int, + )
define_reduction_function_anyop(_add, ___c, char, + )
define_reduction_function_anyop(_add, ___d, double, +  )
define_reduction_function_anyop(_add, ___f, float, +  )
define_reduction_function_anyop(_add, _u_c, unsigned char, + )

// subtract (-): generates reduce_subtract<suffix>(). The partial results are
// combined with '+', not '-'.
// NOTE(review): presumably intentional, since OpenMP defines the combiner of a
// '-' reduction as addition -- confirm against the callers.
define_reduction_function_anyop(_subtract, ___i, int, + )
define_reduction_function_anyop(_subtract, __si, short int, + )
define_reduction_function_anyop(_subtract, __li, long int, + )
define_reduction_function_anyop(_subtract, __Li, long long int, + )
define_reduction_function_anyop(_subtract, _u_i, unsigned int, + )
define_reduction_function_anyop(_subtract, _usi, unsigned short int, + )
define_reduction_function_anyop(_subtract, _uli, unsigned long int, + )
define_reduction_function_anyop(_subtract, _uLi, unsigned long long int, + )
define_reduction_function_anyop(_subtract, ___c, char, + )
define_reduction_function_anyop(_subtract, ___d, double, + )
define_reduction_function_anyop(_subtract, ___f, float, + )
define_reduction_function_anyop(_subtract, _u_c, unsigned char, + )

// multiply (*): generates reduce_multiply<suffix>() for every member type of DataTypes
define_reduction_function_anyop(_multiply, ___i, int, * )
define_reduction_function_anyop(_multiply, __si, short int, * )
define_reduction_function_anyop(_multiply, __li, long int, * )
define_reduction_function_anyop(_multiply, __Li, long long int, * )
define_reduction_function_anyop(_multiply, _u_i, unsigned int, * )
define_reduction_function_anyop(_multiply, _usi, unsigned short int, * )
define_reduction_function_anyop(_multiply, _uli, unsigned long int, * )
define_reduction_function_anyop(_multiply, _uLi, unsigned long long int, * )
define_reduction_function_anyop(_multiply, ___c, char, * )
define_reduction_function_anyop(_multiply, ___d, double, * )
define_reduction_function_anyop(_multiply, ___f, float, * )
define_reduction_function_anyop(_multiply, _u_c, unsigned char, * )

// bitwise AND (&): generates reduce_bitand<suffix>() for the integer and
// character types only; there are no double/float variants.
define_reduction_function_anyop(_bitand, ___i, int, & )
define_reduction_function_anyop(_bitand, __si, short int, & )
define_reduction_function_anyop(_bitand, __li, long int, & )
define_reduction_function_anyop(_bitand, __Li, long long int, & )
define_reduction_function_anyop(_bitand, _u_i, unsigned int, & )
define_reduction_function_anyop(_bitand, _usi, unsigned short int, & )
define_reduction_function_anyop(_bitand, _uli, unsigned long int, & )
define_reduction_function_anyop(_bitand, _uLi, unsigned long long int, & )
define_reduction_function_anyop(_bitand, ___c, char, & )
define_reduction_function_anyop(_bitand, _u_c, unsigned char, & )

// bitwise OR (|): generates reduce_bitor<suffix>() for the integer and
// character types only; there are no double/float variants.
define_reduction_function_anyop(_bitor, ___i, int, | )
define_reduction_function_anyop(_bitor, __si, short int, | )
define_reduction_function_anyop(_bitor, __li, long int, | )
define_reduction_function_anyop(_bitor, __Li, long long int, | )
define_reduction_function_anyop(_bitor, _u_i, unsigned int, | )
define_reduction_function_anyop(_bitor, _usi, unsigned short int, | )
define_reduction_function_anyop(_bitor, _uli, unsigned long int, | )
define_reduction_function_anyop(_bitor, _uLi, unsigned long long int, | )
define_reduction_function_anyop(_bitor, ___c, char, | )
define_reduction_function_anyop(_bitor, _u_c, unsigned char, | )

// bitwise XOR (^): generates reduce_bitxor<suffix>() for the integer and
// character types only; there are no double/float variants.
define_reduction_function_anyop(_bitxor, ___i, int, ^ )
define_reduction_function_anyop(_bitxor, __si, short int, ^ )
define_reduction_function_anyop(_bitxor, __li, long int, ^ )
define_reduction_function_anyop(_bitxor, __Li, long long int, ^ )
define_reduction_function_anyop(_bitxor, _u_i, unsigned int, ^ )
define_reduction_function_anyop(_bitxor, _usi, unsigned short int, ^ )
define_reduction_function_anyop(_bitxor, _uli, unsigned long int, ^ )
define_reduction_function_anyop(_bitxor, _uLi, unsigned long long int, ^ )
define_reduction_function_anyop(_bitxor, ___c, char, ^ )
define_reduction_function_anyop(_bitxor, _u_c, unsigned char, ^ )

// logical AND (&&): generates reduce_and<suffix>() for every member type of
// DataTypes. Each combining step stores the 0/1 result of `a && b` back into
// a slot of the element type.
define_reduction_function_anyop(_and, ___i, int, && )
define_reduction_function_anyop(_and, __si, short int, && )
define_reduction_function_anyop(_and, __li, long int, && )
define_reduction_function_anyop(_and, __Li, long long int, && )
define_reduction_function_anyop(_and, _u_i, unsigned int, && )
define_reduction_function_anyop(_and, _usi, unsigned short int, && )
define_reduction_function_anyop(_and, _uli, unsigned long int, && )
define_reduction_function_anyop(_and, _uLi, unsigned long long int, && )
define_reduction_function_anyop(_and, ___c, char, && )
define_reduction_function_anyop(_and, ___d, double, && )
define_reduction_function_anyop(_and, ___f, float, && )
define_reduction_function_anyop(_and, _u_c, unsigned char, && )

// logical OR (||): generates reduce_or<suffix>() for every member type of
// DataTypes. Each combining step stores the 0/1 result of `a || b` back into
// a slot of the element type.
define_reduction_function_anyop(_or, ___i, int, || )
define_reduction_function_anyop(_or, __si, short int, || )
define_reduction_function_anyop(_or, __li, long int, || )
define_reduction_function_anyop(_or, __Li, long long int, || )
define_reduction_function_anyop(_or, _u_i, unsigned int, || )
define_reduction_function_anyop(_or, _usi, unsigned short int, || )
define_reduction_function_anyop(_or, _uli, unsigned long int, || )
define_reduction_function_anyop(_or, _uLi, unsigned long long int, || )
define_reduction_function_anyop(_or, ___c, char, || )
define_reduction_function_anyop(_or, ___d, double, || )
define_reduction_function_anyop(_or, ___f, float, || )
define_reduction_function_anyop(_or, _u_c, unsigned char, || )

// max: generates reduce_max<suffix>() for every member type of DataTypes,
// using the min/max generator with '>' as the comparison.
define_reduction_function_minmax(_max, ___i, int, > )
define_reduction_function_minmax(_max, __si, short int, > )
define_reduction_function_minmax(_max, __li, long int, > )
define_reduction_function_minmax(_max, __Li, long long int, > )
define_reduction_function_minmax(_max, _u_i, unsigned int, > )
define_reduction_function_minmax(_max, _usi, unsigned short int, > )
define_reduction_function_minmax(_max, _uli, unsigned long int, > )
define_reduction_function_minmax(_max, _uLi, unsigned long long int, > )
define_reduction_function_minmax(_max, ___c, char, > )
define_reduction_function_minmax(_max, ___d, double, > )
define_reduction_function_minmax(_max, ___f, float, > )
define_reduction_function_minmax(_max, _u_c, unsigned char, > )

// min: generates reduce_min<suffix>() for every member type of DataTypes,
// using the min/max generator with '<' as the comparison.
define_reduction_function_minmax(_min, ___i, int, < )
define_reduction_function_minmax(_min, __si, short int, < )
define_reduction_function_minmax(_min, __li, long int, < )
define_reduction_function_minmax(_min, __Li, long long int, < )
define_reduction_function_minmax(_min, _u_i, unsigned int, < )
define_reduction_function_minmax(_min, _usi, unsigned short int, < )
define_reduction_function_minmax(_min, _uli, unsigned long int, < )
define_reduction_function_minmax(_min, _uLi, unsigned long long int, < )
define_reduction_function_minmax(_min, ___c, char, < )
define_reduction_function_minmax(_min, ___d, double, < )
define_reduction_function_minmax(_min, ___f, float, < )
define_reduction_function_minmax(_min, _u_c, unsigned char, < )

// JUMP TABLES
//
// One table per operation, indexed by a numeric type code, so that the
// exported entry points can dispatch on the element type at run time.

// Type-erased signature shared by all generated reduction functions:
// (per-thread values, result array, number of elements).
typedef void (*redfunc_t)(void *, void *, int);

// Builds <op>_jump[] for operations that exist for every element type.
// Index -> type: 0 int, 1 short, 2 long, 3 long long, 4 unsigned int,
// 5 unsigned short, 6 unsigned long, 7 unsigned long long, 8 char,
// 9 double, 10 float, 11 unsigned char.
#define REDFUNC_FULL_JUMP_TABLE(op) \
__device__ static redfunc_t op ## _jump[] = { \
	(redfunc_t) reduce_##op##___i, \
	(redfunc_t) reduce_##op##__si, \
	(redfunc_t) reduce_##op##__li, \
	(redfunc_t) reduce_##op##__Li, \
	(redfunc_t) reduce_##op##_u_i, \
	(redfunc_t) reduce_##op##_usi, \
	(redfunc_t) reduce_##op##_uli, \
	(redfunc_t) reduce_##op##_uLi, \
	(redfunc_t) reduce_##op##___c, \
	(redfunc_t) reduce_##op##___d, \
	(redfunc_t) reduce_##op##___f, \
	(redfunc_t) reduce_##op##_u_c  \
};

REDFUNC_FULL_JUMP_TABLE(add);
REDFUNC_FULL_JUMP_TABLE(subtract);
REDFUNC_FULL_JUMP_TABLE(multiply);
REDFUNC_FULL_JUMP_TABLE(and);
REDFUNC_FULL_JUMP_TABLE(or);
REDFUNC_FULL_JUMP_TABLE(max);
REDFUNC_FULL_JUMP_TABLE(min);

/*
 * Builds <op>_jump[] for operations that exist only for integer and character
 * types (the bitwise ones). The table uses the same index -> type mapping as
 * REDFUNC_FULL_JUMP_TABLE, so it must have exactly 12 slots: the two
 * floating-point slots (9 = double, 10 = float) hold NULL and unsigned char
 * stays at index 11.
 */
#define REDFUNC_INT_JUMP_TABLE(op)\
__device__ static redfunc_t op ## _jump[] = { \
	(redfunc_t) reduce_##op##___i, \
	(redfunc_t) reduce_##op##__si, \
	(redfunc_t) reduce_##op##__li, \
	(redfunc_t) reduce_##op##__Li, \
	(redfunc_t) reduce_##op##_u_i, \
	(redfunc_t) reduce_##op##_usi, \
	(redfunc_t) reduce_##op##_uli, \
	(redfunc_t) reduce_##op##_uLi, \
	(redfunc_t) reduce_##op##___c, \
	(redfunc_t) NULL, /* 9: double, not supported */ \
	(redfunc_t) NULL, /* 10: float, not supported */ \
	(redfunc_t) reduce_##op##_u_c \
}; 

REDFUNC_INT_JUMP_TABLE(bitand);
REDFUNC_INT_JUMP_TABLE(bitor);
REDFUNC_INT_JUMP_TABLE(bitxor);

// THE INTERFACE

/*
 * Generates the exported entry point
 *
 *     void _ort_reduce_<op>(int type, void *local, void *global, int nelems)
 *
 * which looks up the reduction routine for `type` in <op>_jump[] and forwards
 * the remaining arguments to it. `type` is used as a table index without any
 * range or NULL check.
 */
#define EXPORTED_REDUCTION_FUNCTION(op) \
__DEVQLFR void _ort_reduce_##op(int type, void *local, void *global, int nelems)\
{\
	redfunc_t handler = op##_jump[type];\
	\
	handler(local, global, nelems);\
}

EXPORTED_REDUCTION_FUNCTION(add)
EXPORTED_REDUCTION_FUNCTION(subtract)
EXPORTED_REDUCTION_FUNCTION(multiply)
EXPORTED_REDUCTION_FUNCTION(and)
EXPORTED_REDUCTION_FUNCTION(or)
EXPORTED_REDUCTION_FUNCTION(max)
EXPORTED_REDUCTION_FUNCTION(min)
EXPORTED_REDUCTION_FUNCTION(bitand)
EXPORTED_REDUCTION_FUNCTION(bitor)
EXPORTED_REDUCTION_FUNCTION(bitxor)

