/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* OpenCL reductions */

/* There are some configuration options here:
 *
 * a) USE_WORKGROUP_COLLECTIVES: 
 *    Try to use workgroup collectives, if supported; enabled by default.
 * b) USE_32BIT_ATOMICS: 
 *    Use 32bit atomics for int/uint (only applicable when workgroup 
 *    collectives are not available); maybe someday we will also use 64bit
 *    atomics, if available. This is disabled by default because:
 *      - AMD R9 280 max/min reductions caused problems (?) 
 *      - no measurable improvement in performance was found 
 * c) USE_LOCAL_MEMORY:
 *    Use local memory for partial results, instead of global memory (only 
 *    applicable when workgroup collectives are not available). It is not 
 *    observably faster than global memory; it however has substantially less
 *    memory requirements. This is enabled by default.
 * d) USE_ONE_FUNC_PER_DATATYPE:  -- no longer available --
 *    We have one reduction function per data type (=> 12 funcs).
 *    We used to have one reduction function per data type/operator pair,
 *    (=> 120 funcs); this should be a bit faster but gives substantially larger 
 *    binary (2-3x) -- but in any case, for a single workgroup of up to 512 
 *    work items, there is no measurable difference in performance...
 *    Devpart sizes in 3 different GPUs for the 1st and the 2nd case:
 *      AMD Radeon R9 280: 137K | 347K
 *      NVIDIA GT730:       54K | 143K
 *      Intel Iris Xe:     132K | 295K 
 */
 
/* Active configuration; the options are described in the comment above.
 * NOTE(review): USE_WORKGROUP_COLLECTIVES is commented out here, i.e. it is
 * currently disabled, although its description says it is enabled by
 * default -- confirm which of the two is intended.
 */
//#define USE_WORKGROUP_COLLECTIVES
// #define USE_32BIT_ATOMICS
#define USE_LOCAL_MEMORY


/* Handle doubles: enable the fp64 extension when OCLC_HAS_DOUBLE is defined;
 * otherwise every later occurrence of the keyword `double` in this file is
 * textually replaced by `float`.
 */
#ifdef OCLC_HAS_DOUBLE
	#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#else
	#define double float
#endif

/* For OCLC_VERSION >= 2 make sure the work-group collectives feature macro
 * exists, because the conditional sections further down test it with
 * defined().
 */
#if OCLC_VERSION >= 2
	/* For unknown reasons this is not always defined, although it should 
	 * according to the specs, since it is about built-in functions.
	 */
  #ifndef __opencl_c_work_group_collective_functions
    #define __opencl_c_work_group_collective_functions 1
  #endif
#endif

/* Shorthands for the dimension-0 queries used by the reduction functions:
 * number of workgroups, id of this workgroup, number of work-items in the
 * workgroup and id of this work-item within its workgroup.
 */
#define __NWORKGROUPS get_num_groups(0)
#define __WORKGROUPID get_group_id(0)
#define __NWORKITEMS  get_local_size(0)
#define __WORKITEMID  get_local_id(0)


/* This is a pointer in global memory that points to local memory; its value 
 * is set in the kernel wrapper.
 * When USE_LOCAL_MEMORY is defined, the reduction functions below cast it to
 * an array of datatype_u and index that array by work-item id.
 * NOTE(review): this presumes the wrapper reserves at least one datatype_u
 * per work-item of the workgroup -- confirm against the wrapper.
 */
extern __local void * __global _localmem;


/* Hook invoked when a reduction is requested on a data type that cannot be
 * handled. Intentionally a no-op: it should printf() something, if printf
 * worked...
 */
static void _unsupported_reduction_type()
{
}


/* Locking for reducing among workgroups. 
 * TODO: For specific operations and data types, this code should be
 *       replaced with atomics.
 * The following atomics work for int/uint on (global or local mem):
 *   atomic_add(&mem,val), atomic_sub, atomic_min, atomic_max,
 *   atomic_and, atomic_or, atomic_xor (bitwise)
 * 
 * If the cl_khr_int64_base_atomics extension is supported, they also
 * operate on long/ulong.
 */
/* Besides the version check (v3), the __opencl_c_atomic_scope_device
 * feature test macro should also be checked...
 */
/* DOCRITICAL(code): execute `code` while holding the _teamredlock spin lock.
 * The lock is taken by atomically swapping 1 into _teamredlock (it is ours
 * when the previous value was 0) and released by writing 0 back; fences on
 * global memory surround `code`. The for-loop on _stop keeps retrying the
 * exchange until this work-item has executed `code` once.
 * In this file it is only expanded inside `if (tid == 0)` sections, i.e. by
 * one work-item per workgroup.
 * Two variants: explicit-order atomics for OCLC_VERSION >= 3, and
 * atomic_xchg()/mem_fence() otherwise.
 */
#if OCLC_VERSION >= 3
	__global atomic_int _teamredlock;
	#define DOCRITICAL(code)\
		for (volatile int _stop = 0; !_stop; )\
			if (atomic_exchange_explicit(&_teamredlock, 1, memory_order_acquire,\
			                             memory_scope_device) == 0) {\
				/* Ensure any previous non-atomic reads/writes are ordered properly.\
				 * This is optional after atomic_exchange_explicit with \
				 * memory_order_acquire but included for clarity. \
				 */ \
				atomic_work_item_fence(CLK_GLOBAL_MEM_FENCE, memory_order_acquire, memory_scope_device);\
				code\
				atomic_work_item_fence(CLK_GLOBAL_MEM_FENCE, memory_order_release, memory_scope_device);\
	      /* Release the lock */\
	      atomic_store_explicit(&_teamredlock, 0, memory_order_release, memory_scope_device);\
				_stop = 1;\
			};
#else
	__global int _teamredlock;
	#define DOCRITICAL(code)\
		for (volatile int _stop = 0; !_stop; )\
			if (atomic_xchg(&_teamredlock,1) == 0) {\
				mem_fence(CLK_GLOBAL_MEM_FENCE); /* ensure previous writes finish */\
				code\
				mem_fence(CLK_GLOBAL_MEM_FENCE); /* ensure all my writes are visible */\
				atomic_xchg(&_teamredlock, 0);\
				_stop = 1;\
			};
#endif


#if defined(USE_WORKGROUP_COLLECTIVES) &&\
    defined(__opencl_c_work_group_collective_functions) &&\
    defined(cl_khr_work_group_uniform_arithmetic)

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                                         *
 * CASE 1: Use workgroup collectives for all operators.                    *
 *                                                                         *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
 
/* define_reduction_function(SUFFIX, TYPE) generates reduce<SUFFIX>(), which
 * reduces nelems elements of type TYPE, one element at a time:
 *   - every work-item passes priv_d[j] to a workgroup collective selected by
 *     OPid (1, 2: add; 3: mul; 4: logical and; 5: logical or; 6: bitwise and;
 *     7: bitwise or; 8: xor; 9: max; any other value: min);
 *   - work-item 0 then folds the workgroup result `res` into global_d[j]
 *     through global_update(): directly when there is a single workgroup,
 *     inside DOCRITICAL() otherwise.
 * global_update() is expanded when this macro is instantiated, so it must be
 * defined before each instantiation.
 * NOTE(review): all ten collectives are named for every TYPE, including the
 * bitwise ones for float/double and all of them for char/short -- confirm
 * that every instantiation compiles when this section is enabled.
 */
#define define_reduction_function(SUFFIX, TYPE) \
void reduce##SUFFIX(int OPid, TYPE *priv_d, __global TYPE *global_d,int nelems)\
{\
	int j = 0, tid = __WORKITEMID, numteams = __NWORKGROUPS;\
	TYPE res;\
	\
	for (j=0; j < nelems; j++) /* Reduce elements one by one */ {\
		res = (OPid == 1) ? work_group_reduce_add(priv_d[j]) :\
		      (OPid == 2) ? work_group_reduce_add(priv_d[j]) :\
		      (OPid == 3) ? work_group_reduce_mul(priv_d[j]) :\
		      (OPid == 4) ? work_group_reduce_logical_and(priv_d[j]) :\
		      (OPid == 5) ? work_group_reduce_logical_or(priv_d[j]) :\
		      (OPid == 6) ? work_group_reduce_and(priv_d[j]) :\
		      (OPid == 7) ? work_group_reduce_or(priv_d[j]) :\
		      (OPid == 8) ? work_group_reduce_xor(priv_d[j]) :\
		      (OPid == 9) ? work_group_reduce_max(priv_d[j]) :\
		                    work_group_reduce_min(priv_d[j]);\
		if (tid == 0) {\
			if (numteams == 1) {\
				global_update(OPid)\
			} else {\
				DOCRITICAL( global_update(OPid) );\
			}\
		}\
	}\
}

/* global_update(OPid) for integer types: merges the workgroup result `res`
 * into global_d[j]. OPid 1..8 select +, +, *, &&, ||, &, | and ^ in that
 * order; OPid 9 stores res only when it is larger than global_d[j], OPid 10
 * only when it is smaller. Expands to an unbraced if/else statement.
 */
#define global_update(OPid)\
	if (OPid < 9)\
		global_d[j] = (OPid == 1) ? global_d[j] +  res :\
		              (OPid == 2) ? global_d[j] +  res :\
		              (OPid == 3) ? global_d[j] *  res :\
		              (OPid == 4) ? global_d[j] && res :\
		              (OPid == 5) ? global_d[j] || res :\
		              (OPid == 6) ? global_d[j] &  res :\
		              (OPid == 7) ? global_d[j] |  res :\
		                            global_d[j] ^  res;\
	else\
		if ((OPid ==  9 && res > global_d[j]) ||\
		    (OPid == 10 && res < global_d[j]))\
			global_d[j] = res;

/* One reduction function per integer type, all using the macro above */
define_reduction_function(___i, int)
define_reduction_function(_u_i, unsigned int)
define_reduction_function(__si, short int)
define_reduction_function(__li, long int)
define_reduction_function(_usi, unsigned short int)
define_reduction_function(_uli, unsigned long int)
define_reduction_function(___c, char)
define_reduction_function(_u_c, unsigned char)

/* global_update(OPid) for real types: like the integer version but without
 * the bitwise operators. OPid 1..4 select +, +, * and &&; every other OPid
 * below 9 is treated as ||; OPid 9/10 keep the larger/smaller value.
 */
#undef global_update
#define global_update(OPid)\
	if (OPid < 9)\
		global_d[j] = (OPid == 1) ? global_d[j] +  res :\
		              (OPid == 2) ? global_d[j] +  res :\
		              (OPid == 3) ? global_d[j] *  res :\
		              (OPid == 4) ? global_d[j] && res :\
		                            global_d[j] || res ;\
	else\
		if ((OPid ==  9 && res > global_d[j]) ||\
		    (OPid == 10 && res < global_d[j]))\
			global_d[j] = res;

/* float always gets a real implementation; without OCLC_HAS_DOUBLE the
 * double entry is a stub that only calls the unsupported-type hook and
 * leaves its arguments untouched.
 */
define_reduction_function(___f, float)
#ifdef OCLC_HAS_DOUBLE
  define_reduction_function(___d, double)
#else
  void reduce___d(int OPid, void *priv_d, __global void *global_d, int nelems) {
		_unsupported_reduction_type();
	} 
#endif


#else /* defined(USE_WORKGROUP_COLLECTIVES) && 
         defined(__opencl_c_work_group_collective_functions) && 
         defined(cl_khr_work_group_uniform_arithmetic) */

#if defined(USE_WORKGROUP_COLLECTIVES) &&\
    defined(__opencl_c_work_group_collective_functions)

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                                         *
 * CASE 2: Use workgroup collectives for add, min and max.                 *
 *                                                                         *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
 
/* define_reduction_function_wgc(SUFFIX, TYPE) generates reduce<SUFFIX>_wgc(),
 * the collective-based reduction restricted to add/max/min. For each of the
 * nelems elements every work-item passes priv_d[j] to a workgroup collective
 * (OPid 1, 2: add; 9: max; any other value: min) and work-item 0 merges the
 * result `res` into global_d[j] via global_update(), directly when there is
 * a single workgroup and inside DOCRITICAL() otherwise.
 * global_update() is expanded at instantiation time, so it must be defined
 * before the instantiations.
 * NOTE(review): the collectives are also applied to char/short operands --
 * confirm that those instantiations compile when this section is enabled.
 */
#define define_reduction_function_wgc(SUFFIX, TYPE) \
void reduce##SUFFIX##_wgc(int OPid, TYPE *priv_d, __global TYPE *global_d,int nelems)\
{\
	int j = 0, tid = __WORKITEMID, numteams = __NWORKGROUPS;\
	TYPE res;\
	\
	for (j=0; j < nelems; j++) /* Reduce elements one by one */ {\
		res = (OPid == 1) ? work_group_reduce_add(priv_d[j]) :\
		      (OPid == 2) ? work_group_reduce_add(priv_d[j]) :\
		      (OPid == 9) ? work_group_reduce_max(priv_d[j]) :\
		                    work_group_reduce_min(priv_d[j]);\
		if (tid == 0) {\
			if (numteams == 1) {\
				global_update(OPid)\
			} else {\
				DOCRITICAL( global_update(OPid) );\
			}\
		}\
	}\
}

/* global_update(OPid) for the _wgc functions: OPid 1 and 2 add `res` to
 * global_d[j]; OPid 9 stores res only when it is larger than global_d[j],
 * OPid 10 only when it is smaller; any other OPid leaves global_d[j] as is.
 */
#define global_update(OPid)\
	if (OPid == 1 || OPid == 2)\
		global_d[j] = global_d[j] + res;\
	else\
		if ((OPid == 9 && res > global_d[j]) || (OPid == 10 && res < global_d[j]))\
			global_d[j] = res;

/* One _wgc function per data type; without OCLC_HAS_DOUBLE the double entry
 * is a stub that only calls the unsupported-type hook.
 */
define_reduction_function_wgc(___i, int)
define_reduction_function_wgc(_u_i, unsigned int)
define_reduction_function_wgc(__si, short int)
define_reduction_function_wgc(__li, long int)
define_reduction_function_wgc(_usi, unsigned short int)
define_reduction_function_wgc(_uli, unsigned long int)
define_reduction_function_wgc(___c, char)
define_reduction_function_wgc(_u_c, unsigned char)
define_reduction_function_wgc(___f, float)
#ifdef OCLC_HAS_DOUBLE
  define_reduction_function_wgc(___d, double)
#else
  void reduce___d_wgc(int OPid, void *priv_d, __global void *global_d, int nelems) {
		_unsupported_reduction_type();
	} 
#endif

/* The macro is redefined with a different signature further down */
#undef global_update

#endif /* defined() && defined(__opencl_c_work_group_collective_functions) */

/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                                         *
 * CASE 3: Our own implementation with no workgroup collectives            *
 *                                                                         *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
 
/* One scratch slot that can hold a value of any supported data type
 * (size = 8 bytes). The member names are the SUFFIX arguments handed to the
 * reduction macros. There is no member for [unsigned] long long int (__Li)
 * since OpenCL does not support it.
 */
typedef union {
	/* signed integer types */
	int ___i;
	short int __si;
	long int __li;
	char ___c;

	/* unsigned integer types */
	unsigned int _u_i;
	unsigned short int _usi;
	unsigned long int _uli;
	unsigned char _u_c;

	/* real types */
	float ___f;
#ifdef OCLC_HAS_DOUBLE
	double ___d;
#endif
} datatype_u;


#ifndef USE_LOCAL_MEMORY

	/* Scratch space for the partial results when local memory is not used.
	 *
	 * IMPORTANT: The size of _localRes *must* be >= workgroup size.
	 * So keep in mind that sooner or later a kernel will come along 
	 * that will break due to the hardcoded values here.
	 * Also, because _localRes is not in local memories yet, we need to 
	 * have a _localRes array for each workgroup => the first dimension 
	 * of _localRes is the workgroup id.
	 * NOTE(review): nothing in this file checks the actual number of
	 * workgroups or work-items against the two limits below.
	 */
	#define REDUCTION_MAX_THRPERTEAM 512
	#define REDUCTION_MAX_TEAMS 65536

	/* Size: 512*65536*8 = 256 MB
	 * If the device does not cleanup properly after a kernel, there might be 
	 * a buildup which will quickly consume VRAM
	 */
	__global datatype_u _localRes[REDUCTION_MAX_TEAMS][REDUCTION_MAX_THRPERTEAM];
	
	/* In P40, in a reduction test we achieve ~40GB/sec (0.024 sec/reduction)
	 * using the global memory table above.
	 * The version using local memories achieves ~48GB/sec (0.021 sec/reduction)
	 */

 #endif /* !USE_LOCAL_MEMORY */
 
 
/* Returns the largest power of two that is <= nthr (e.g. 7 -> 4, 8 -> 8).
 * Any nthr <= 1, including 0 and negative values, yields 1; negative values
 * are rejected up front because right-shifting them would never reach 0.
 */
int get_prevPow2(int nthr)
{
	int pow2 = 1;

	if (nthr <= 0)
		return 1;

	/* Double pow2 once for every halving of nthr that stays >= 1 */
	while (nthr > 1) {
		nthr >>= 1;
		pow2 <<= 1;
	}
	return pow2;
}

/* Workgroup barrier used between the steps of the reduction below. The fence
 * flag has to match the memory that holds teamRes: local memory when
 * USE_LOCAL_MEMORY is defined, the global _localRes table otherwise.
 */
#ifdef USE_LOCAL_MEMORY
	#define sync_local()  barrier(CLK_LOCAL_MEM_FENCE)
#else
	#define sync_local()  barrier(CLK_GLOBAL_MEM_FENCE)
#endif

/* reduction_function_body(SUFFIX, TYPE, OP): body shared by all reduce<SUFFIX>
 * functions (originally identical to the CUDA one). It expects i, j, nthr,
 * tid, teamsize, numteams, prev_pow2, teamRes, priv_d, global_d and nelems
 * to be in scope; TYPE is not used by the body itself.
 *
 * For every element j:
 *   1. each work-item stores priv_d[j] in its own slot teamRes[tid];
 *   2. while more than one partial result is left (nthr), with prev_pow2 the
 *      largest power of two <= nthr:
 *        - slots 0..prev_pow2-1 are tree-reduced into slot 0,
 *        - slot prev_pow2, if active, is folded into slot 0,
 *        - the active slots above prev_pow2 are moved down to 1, 2, ...
 *        - nthr drops to nthr - prev_pow2;
 *   3. work-item 0 merges teamRes[0] into global_d[j]: global_update() for
 *      a single workgroup, teams_reduce() otherwise.
 *
 * The move in step 2 is limited to tid < nthr. Work-items beyond the active
 * range hold no live data, and letting them copy too would make work-item
 * tid+prev_pow2 overwrite teamRes[tid] while work-item tid is still reading
 * it (possible from the second round on, when nthr < teamsize).
 */
#define reduction_function_body(SUFFIX, TYPE, OP)\
	for (j=0; j < nelems; j++) /* Reduce elements one by one */\
	{\
		nthr = teamsize;\
		teamRes[tid].SUFFIX = priv_d[j];\
		sync_local();\
		while (!(nthr==1 || nthr==0))\
		{\
			prev_pow2 = get_prevPow2(nthr);\
			/* Tree-reduce slots 0..prev_pow2-1 into slot 0 */\
			for (i = prev_pow2 / 2; i > 0; i >>= 1)\
			{\
				if (tid < i) {\
					local_update(teamRes[tid], teamRes[tid+i], OP, SUFFIX)\
				} /* if */\
				sync_local(); \
			} /* for */\
			\
			/* Fold the first slot past the power of two into slot 0 */\
			if (tid == prev_pow2 && nthr != prev_pow2) {\
				local_update(teamRes[0], teamRes[tid], OP, SUFFIX)\
			}\
			sync_local();\
			/* Move the remaining active slots down; sources are above */\
			/* prev_pow2 and targets below it, so they never overlap   */\
			if (tid > prev_pow2 && tid < nthr)\
				teamRes[tid-prev_pow2].SUFFIX = teamRes[tid].SUFFIX;\
			sync_local();\
			nthr = nthr - prev_pow2;\
		} /* while */\
		\
		if (tid == 0)\
		{\
			if (numteams == 1) {\
				global_update(OP, SUFFIX)\
			} else {\
				teams_reduce(OP, SUFFIX)\
			}\
		}\
	} /* for */


/* define_reduction_function(SUFFIX, TYPE) generates reduce<SUFFIX>(), which
 * reduces nelems elements of type TYPE from priv_d into global_d using
 * reduction_function_body(). OPid selects the operator and is interpreted
 * by the local_update()/global_update()/teams_reduce() macros that are in
 * effect where the function is instantiated.
 * The two variants differ only in where the per-work-item scratch slots
 * (teamRes) live: the local memory block behind _localmem, or this
 * workgroup's row of the global _localRes table.
 */
#ifdef USE_LOCAL_MEMORY

	#define define_reduction_function(SUFFIX, TYPE) \
	void reduce##SUFFIX(int OPid,TYPE *priv_d, __global TYPE *global_d,int nelems)\
	{\
		int i = 0, j = 0, nthr;\
		int tid = __WORKITEMID, teamid = __WORKGROUPID, teamsize = __NWORKITEMS,\
		    numteams = __NWORKGROUPS;\
		volatile int prev_pow2; /* otherwise AMD R9 280 optimizes wrongly */\
		__local datatype_u *teamRes = (__local datatype_u *) _localmem;\
		\
		reduction_function_body(SUFFIX, TYPE, OPid)\
	}

#else 

	#define define_reduction_function(SUFFIX, TYPE) \
	void reduce##SUFFIX(int OPid,TYPE *priv_d, __global TYPE *global_d,int nelems)\
	{\
		int i = 0, j = 0, nthr;\
		int tid = __WORKITEMID, teamid = __WORKGROUPID, teamsize = __NWORKITEMS,\
		    numteams = __NWORKGROUPS;\
		volatile int prev_pow2; /* otherwise AMD R9 280 optimizes wrongly */\
		/* _localRes is a __global table: name the address space explicitly */\
		/* instead of depending on the (optional) generic address space     */\
		__global datatype_u *teamRes = _localRes[teamid];\
		\
		reduction_function_body(SUFFIX, TYPE, OPid)\
	}

#endif /* USE_LOCAL_MEMORY */

/* Update macros for integer types.
 *
 * local_update(to, from, OPid, SUF) combines member SUF of the slot `from`
 * into member SUF of the slot `to`; global_update(OPid, SUF) combines
 * teamRes[0].SUF into global_d[j]. In both, OPid 1..8 select +, +, *, &&,
 * ||, &, | and ^ in that order; OPid 9 keeps the larger and OPid 10 the
 * smaller of the two values. Both expand to an unbraced if/else statement.
 */
#define local_update(to, from, OPid, SUF)\
	if (OPid < 9)\
		to.SUF = (OPid == 1) ? to.SUF +  from.SUF :\
		         (OPid == 2) ? to.SUF +  from.SUF :\
		         (OPid == 3) ? to.SUF *  from.SUF :\
		         (OPid == 4) ? to.SUF && from.SUF :\
		         (OPid == 5) ? to.SUF || from.SUF :\
		         (OPid == 6) ? to.SUF &  from.SUF :\
		         (OPid == 7) ? to.SUF |  from.SUF :\
		                       to.SUF ^  from.SUF;\
	else\
		if ((OPid ==  9 && from.SUF > to.SUF) ||\
		    (OPid == 10 && from.SUF < to.SUF))\
			to.SUF = from.SUF;

#define global_update(OPid, SUF)\
	if (OPid < 9)\
		global_d[j] = (OPid == 1) ? global_d[j] +  teamRes[0].SUF :\
		              (OPid == 2) ? global_d[j] +  teamRes[0].SUF :\
		              (OPid == 3) ? global_d[j] *  teamRes[0].SUF :\
		              (OPid == 4) ? global_d[j] && teamRes[0].SUF :\
		              (OPid == 5) ? global_d[j] || teamRes[0].SUF :\
		              (OPid == 6) ? global_d[j] &  teamRes[0].SUF :\
		              (OPid == 7) ? global_d[j] |  teamRes[0].SUF :\
		                            global_d[j] ^  teamRes[0].SUF;\
	else\
		if ((OPid ==  9 && teamRes[0].SUF > global_d[j]) ||\
		    (OPid == 10 && teamRes[0].SUF < global_d[j]))\
			global_d[j] = teamRes[0].SUF;


/* Special functions for int/uint, employing atomic operations 
 */
#ifdef USE_32BIT_ATOMICS

	/* For types with atomics (int/uint and maybe long/ulong) */

	/* Fallback for the operators handled without an atomic built-in below:
	 * OPid 3 multiplies, OPid 4 applies &&, anything else applies ||.
	 * It is only expanded inside DOCRITICAL().
	 */
	#define teams_global_update_(OPid, SUF)\
		global_d[j] = (OPid == 3) ? global_d[j] *  teamRes[0].SUF :\
		              (OPid == 4) ? global_d[j] && teamRes[0].SUF :\
		                            global_d[j] || teamRes[0].SUF;
	
	/* Inter-workgroup step: merge teamRes[0].SUF into global_d[j] with one
	 * atomic built-in for add/subtract (1, 2), bitwise and/or/xor (6, 7, 8),
	 * max (9) and min (10); every other OPid goes through the lock-protected
	 * fallback above.
	 */
	#define teams_reduce(OPid, SUF)\
		switch (OPid) {\
			case 1: case 2:\
				atomic_add(&global_d[j], teamRes[0].SUF); break;\
			case 6:\
				atomic_and(&global_d[j], teamRes[0].SUF); break;\
			case 7:\
				atomic_or(&global_d[j], teamRes[0].SUF); break;\
			case 8:\
				atomic_xor(&global_d[j], teamRes[0].SUF); break;\
			case 9:\
				atomic_max(&global_d[j], teamRes[0].SUF); break;\
			case 10:\
				atomic_min(&global_d[j], teamRes[0].SUF); break;\
			default:\
				DOCRITICAL( teams_global_update_(OPid, SUF) )\
		}
	
	/* Only int and unsigned int are generated with the atomic version */
	define_reduction_function(___i, int)
	define_reduction_function(_u_i, unsigned int)
	
	#undef teams_reduce

#endif /* USE_32BIT_ATOMICS */

/* Default inter-workgroup step: apply global_update() while holding the
 * _teamredlock lock.
 */
#define teams_reduce(OPid, SUF) DOCRITICAL( global_update(OPid, SUF) )

/* int/unsigned int are generated here only when USE_32BIT_ATOMICS is not
 * defined; otherwise their atomic versions already exist.
 */
#ifndef USE_32BIT_ATOMICS
	define_reduction_function(___i, int)
	define_reduction_function(_u_i, unsigned int)
#endif
define_reduction_function(__si, short int)
define_reduction_function(__li, long int)
define_reduction_function(_usi, unsigned short int)
define_reduction_function(_uli, unsigned long int)
define_reduction_function(___c, char)
define_reduction_function(_u_c, unsigned char)

/* The integer update macros are replaced by the real-type ones next */
#undef local_update
#undef global_update

/* Update macros for real types.
 *
 * Same roles as the integer versions, but without the bitwise operators:
 * OPid 1..4 select +, +, * and &&; every other OPid below 9 is treated as
 * ||; OPid 9 keeps the larger and OPid 10 the smaller of the two values.
 */
#define local_update(to, from, OPid, SUF)\
	if (OPid < 9)\
		to.SUF = (OPid == 1) ? to.SUF +  from.SUF :\
		         (OPid == 2) ? to.SUF +  from.SUF :\
		         (OPid == 3) ? to.SUF *  from.SUF :\
		         (OPid == 4) ? to.SUF && from.SUF :\
		                       to.SUF || from.SUF;\
	else\
		if ((OPid ==  9 && from.SUF > to.SUF) ||\
		    (OPid == 10 && from.SUF < to.SUF))\
			to.SUF = from.SUF;

#define global_update(OPid, SUF)\
	if (OPid < 9)\
		global_d[j] = (OPid == 1) ? global_d[j] +  teamRes[0].SUF :\
		              (OPid == 2) ? global_d[j] +  teamRes[0].SUF :\
		              (OPid == 3) ? global_d[j] *  teamRes[0].SUF :\
		              (OPid == 4) ? global_d[j] && teamRes[0].SUF :\
		                            global_d[j] || teamRes[0].SUF;\
	else\
		if ((OPid ==  9 && teamRes[0].SUF > global_d[j]) ||\
		    (OPid == 10 && teamRes[0].SUF < global_d[j]))\
			global_d[j] = teamRes[0].SUF;

/* float always gets a real implementation; without OCLC_HAS_DOUBLE the
 * double entry is a stub that only calls the unsupported-type hook and
 * leaves its arguments untouched.
 */
define_reduction_function(___f, float)
#ifdef OCLC_HAS_DOUBLE
  define_reduction_function(___d, double)
#else
  void reduce___d(int OPid, void *priv_d, __global void *global_d, int nelems) {
		_unsupported_reduction_type();
	} 
#endif

#endif /* defined(USE_WORKGROUP_COLLECTIVES) && 
          defined(__opencl_c_work_group_collective_functions) && 
          defined(cl_khr_work_group_uniform_arithmetic) */


/* EXPORTED_REDUCTION_FUNCTION(op, opid, SUF) generates the entry point
 * _ort_reduce_<op>(type, priv_d, global_d, nelems). It forwards its
 * arguments, together with the fixed operator id `opid`, to the typed
 * function reduce<suffix><SUF>() selected by `type`:
 *    0: int             1: short int             2: long int
 *    4: unsigned int    5: unsigned short int    6: unsigned long int
 *    8: char            9: double               10: float
 *   12: unsigned char
 * Types 3, 7 and 11 (long long int, unsigned long long int, long double) are
 * unsupported; they, and any unknown type code, end up in the
 * unsupported-type hook and leave the data untouched.
 */
#define EXPORTED_REDUCTION_FUNCTION(op, opid, SUF) \
void _ort_reduce_##op(int type, void *priv_d, __global void *global_d, int nelems)\
{\
	switch (type) {\
		case 0:\
			reduce___i##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 1:\
			reduce__si##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 2:\
			reduce__li##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 4:\
			reduce_u_i##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 5:\
			reduce_usi##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 6:\
			reduce_uli##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 8:\
			reduce___c##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 9:\
			reduce___d##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 10:\
			reduce___f##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 12:\
			reduce_u_c##SUF(opid, priv_d, global_d, nelems);\
			break;\
		case 3:  /* long long int */\
		case 7:  /* unsigned long long int */\
		case 11: /* long double */\
		default:\
			_unsupported_reduction_type();\
			break;\
	}\
}

/* Entry points for add, subtract, max and min: they use the _wgc functions
 * when workgroup collectives are requested and available but
 * cl_khr_work_group_uniform_arithmetic is not, and the plain reduce<suffix>()
 * functions in every other configuration. Subtract uses its own operator id
 * (2), which the update macros treat exactly like add.
 */
#if defined(USE_WORKGROUP_COLLECTIVES) && \
    defined(__opencl_c_work_group_collective_functions) &&\
   !defined(cl_khr_work_group_uniform_arithmetic)
	EXPORTED_REDUCTION_FUNCTION(add, 1, _wgc)
	EXPORTED_REDUCTION_FUNCTION(subtract, 2, _wgc)
	EXPORTED_REDUCTION_FUNCTION(max, 9, _wgc)
	EXPORTED_REDUCTION_FUNCTION(min, 10, _wgc)
#else
	EXPORTED_REDUCTION_FUNCTION(add, 1, )
	EXPORTED_REDUCTION_FUNCTION(subtract, 2, )
	EXPORTED_REDUCTION_FUNCTION(max, 9, )
	EXPORTED_REDUCTION_FUNCTION(min, 10, )
#endif /* ... && ... && ... */ 

/* The remaining operators always use the plain reduce<suffix>() functions */
EXPORTED_REDUCTION_FUNCTION(multiply, 3, )
EXPORTED_REDUCTION_FUNCTION(and, 4, )
EXPORTED_REDUCTION_FUNCTION(or, 5, )
EXPORTED_REDUCTION_FUNCTION(bitand, 6, )
EXPORTED_REDUCTION_FUNCTION(bitor, 7, )
EXPORTED_REDUCTION_FUNCTION(bitxor, 8, )
