/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* x_assorted_cu.c -- transform assorted CUDA constructs */

#include "x_assorted_cu.h"
#include "ast_xform.h"
#include "ast_free.h"
#include "ast_copy.h"
#include "ast_assorted.h"
#include "ast.h"
#include "ompi.h"
#include "builder.h"
#include "ast_assorted.h"
#include "ast_arith.h"
#include "cuda.h"
#include <string.h>

/*
 * Builds the lock-protected statement sequence emitted for an atomic
 * construct:
 *
 *   _ort_atomic_begin();
 *   _cuda_dev_set_lock(&<lockname>);
 *   <body>
 *   _cuda_dev_unset_lock(&<lockname>);
 *   _ort_atomic_end();
 *
 * `body' is linked into the returned tree (not copied).
 */
static aststmt cuda_atomic_guarded_body(aststmt body, char *lockname)
{
	return Block5(
	         Call0_stmt("_ort_atomic_begin"),
	         Expression(
	           FunctionCall(IdentName("_cuda_dev_set_lock"),
	           UnaryOperator(UOP_addr, IdentName(lockname)))
	         ),
	         body,
	         Expression(
	           FunctionCall(IdentName("_cuda_dev_unset_lock"),
	           UnaryOperator(UOP_addr, IdentName(lockname)))
	         ),
	         Call0_stmt("_ort_atomic_end")
	       );
}


/*
 * xform_atomic_cuda -- replaces an OpenMP atomic construct by its expression
 * statement, guarded by the "_cuda_atomic_lock" integer lock.
 *
 * t: in/out; on entry points to the OmpStmt node of the construct, on exit
 *    points to the replacement tree (whose parent is the original parent).
 *
 * The body must be an expression statement holding a postfix, prefix or
 * assignment expression; anything else (including a missing body or an
 * empty expression statement) is reported through exit_error().
 *
 * The lock declaration and its initialisation code (thread 0 zeroes the
 * lock, then __syncthreads()) are handed to codetargs_kernel_add_global().
 *
 * If the statement is an assignment whose right-hand side is neither an
 * identifier nor a constant, the right-hand side is first stored in a new
 * variable "__tmp" (declared before the guarded sequence), and the
 * assignment is rewritten to read from "__tmp".
 */
void xform_atomic_cuda(aststmt *t)
{
	aststmt s, parent, v, decl, inittree;
	astexpr ex;
	stentry e;
	bool    stlist;
	void ( *sharedadjust)(aststmt) = codetarg_get_adjuster(CODETARGID(cuda), ADJ_SHARED_STRUCT);
	char   *lockname = "_cuda_atomic_lock";

	/* First transform the body */
	ast_stmt_xform(&((*t)->u.omp->body));

	s = (*t)->u.omp->body;
	parent = (*t)->parent;

	/* Only look at the expression once we know s is an expression statement;
	 * a NULL body or NULL expression is rejected instead of dereferenced.
	 */
	ex = (s != NULL && s->type == EXPRESSION) ? s->u.expr : NULL;
	if (ex == NULL ||
		(ex->type != POSTOP && ex->type != PREOP && ex->type != ASS))
		exit_error(1, "(%s, line %d) openmp error:\n\t"
				   "non-compliant ATOMIC expression.\n",
				   (*t)->u.omp->directive->file->name, (*t)->u.omp->directive->l);

	v = ompdir_commented((*t)->u.omp->directive); /* Put directive in comments */
	stlist = (parent->type == STATEMENTLIST || parent->type == COMPOUND);

	(*t)->u.omp->body = NULL;     /* Make it NULL so as to free t easily */
	ast_free(*t);                 /* Get rid of the OmpStmt */

	/* From here on *t is dangling; every path below must assign a new tree
	 * to it before it is touched again.
	 */

	decl = Declaration(
	         Declspec(SPEC_int),
	         Declarator(NULL, IdentifierDecl(Symbol(lockname)))
	       );

	if (sharedadjust)
		sharedadjust(decl); /* Add __shared__ qualifier */

	/* if (omp_get_thread_num() == 0) <lock> = 0;  __syncthreads(); */
	inittree = BlockList(If(
	             BinaryOperator(BOP_eqeq,
	               FunctionCall(IdentName("omp_get_thread_num"), NULL),
	               numConstant(0)
	             ),
	             Expression(
	               Assignment(
	                 IdentName(lockname), ASS_eq, numConstant(0)
	               )
	             ),
	             NULL
	           ), FuncCallStmt("__syncthreads", NULL));

	codetargs_kernel_add_global(decl, inittree, CODETARGID(cuda));
	ast_stmt_free(decl);         /* They keep a copy */

	/* Remove __shared__ qualifier from every declaration (except for the global) */
	e = symtab_get(stab, Symbol(lockname), IDNAME);
	e->spec = Declspec(SPEC_int);

	if (ex->type == ASS &&
		(ex->right->type != IDENT && !xar_expr_is_constant(ex->right)))
	{
		aststmt tmp;

		/* <type> __tmp = <rhs>;  the type is that of the left-hand side
		 * variable when it is a plain identifier, otherwise long.
		 */
		tmp = Declaration(
				(ex->left->type != IDENT ?
				 Declspec(SPEC_long) :
				 ast_spec_copy_nosc(
				   symtab_get(stab, ex->left->u.sym, IDNAME)->spec)
				),
				InitDecl(
				  Declarator(NULL, IdentifierDecl(Symbol("__tmp"))),
				  ex->right
				)
			  );
		ex->right = IdentName("__tmp");

		/* Always open a new scope here: the sequence declares __tmp, so
		 * leaving it in the enclosing block would clash with the __tmp of
		 * any other atomic construct in that block, and a non-list parent
		 * needs a single statement anyway.
		 * The line directive is s->l (not s->l + 1) as in the wrapped case
		 * of the other branch.
		 * NOTE(review): this assumes the closing brace is printed on the
		 * line after the directive -- confirm against the AST printer.
		 */
		*t = Compound(
				Block4(
					v, tmp,
					cuda_atomic_guarded_body(s, lockname),
					linepragma(s->l, s->file)
				)
			 );
	}
	else
	{
		*t = Block3(
				v,
				cuda_atomic_guarded_body(s, lockname),
				linepragma(s->l + 1 - (!stlist), s->file)
			);
		if (!stlist)
			*t = Compound(*t);
	}
	(*t)->parent = parent;
}


/*
 * xform_critical_cuda -- replaces an OpenMP critical construct by its body,
 * bracketed by lock/unlock calls on an integer lock.
 *
 * t: in/out; on entry points to the OmpStmt node of the construct, on exit
 *    points to the replacement tree (whose parent is the original parent).
 *
 * Produced code:
 *   <directive, commented out>
 *   _cuda_dev_set_lock(&<lock>);
 *   <body>
 *   _cuda_dev_unset_lock(&<lock>);
 *   <line directive>
 *
 * The lock is called "_cuda_ompi_crity_<region>" for a named region and
 * "_cuda_ompi_crity" otherwise. Its declaration and initialisation code are
 * handed to codetargs_kernel_add_global().
 */
void xform_critical_cuda(aststmt *t)
{
	void  (*mark_shared)(aststmt) = codetarg_get_adjuster(CODETARGID(cuda), ADJ_SHARED_STRUCT);
	char    lockid[128];
	aststmt body, owner, dircomment;
	aststmt lockdecl, master_reset, lockinit;
	aststmt acquire, release, replacement;
	stentry lockentry;
	bool    in_list;

	/* The body gets transformed before anything else */
	ast_stmt_xform(&((*t)->u.omp->body));

	/* Grab everything we need from the construct while it is still alive */
	body = (*t)->u.omp->body;
	owner = (*t)->parent;
	dircomment = ompdir_commented((*t)->u.omp->directive);
	in_list = (owner->type == STATEMENTLIST) || (owner->type == COMPOUND);

	/* Unnamed regions use the plain lock name; named ones append the
	 * region name.
	 */
	if ((*t)->u.omp->directive->u.region == NULL)
		strcpy(lockid, "_cuda_ompi_crity");
	else
		snprintf(lockid, 127, "_cuda_ompi_crity_%s",
		         (*t)->u.omp->directive->u.region->name);

	/* int <lock>;  optionally run through the target's shared-struct
	 * adjuster (per the original notes, this adds the __shared__ qualifier)
	 */
	lockdecl = Declaration(
	             Declspec(SPEC_int),
	             Declarator(NULL, IdentifierDecl(Symbol(lockid)))
	           );
	if (mark_shared != NULL)
		mark_shared(lockdecl);

	/* if (omp_get_thread_num() == 0) <lock> = 0; */
	master_reset = If(
	                 BinaryOperator(BOP_eqeq,
	                   FunctionCall(IdentName("omp_get_thread_num"), NULL),
	                   numConstant(0)
	                 ),
	                 Expression(
	                   Assignment(
	                     IdentName(lockid), ASS_eq, numConstant(0)
	                   )
	                 ),
	                 NULL
	               );
	/* ... followed by __syncthreads(); */
	lockinit = Block2(master_reset, FuncCallStmt("__syncthreads", NULL));

	/* Register the global. Our declaration can be released right after,
	 * since (per the original notes) a copy is kept by the callee.
	 */
	codetargs_kernel_add_global(lockdecl, lockinit, CODETARGID(cuda));
	ast_stmt_free(lockdecl);

	/* The symbol table entry of the lock gets a plain int specifier */
	lockentry = symtab_get(stab, Symbol(lockid), IDNAME);
	lockentry->spec = Declspec(SPEC_int);

	/* Detach the body so that freeing the OmpStmt leaves it intact */
	(*t)->u.omp->body = NULL;
	ast_free(*t);

	/* _cuda_dev_set_lock(&<lock>); */
	acquire = Expression(
	            FunctionCall(IdentName("_cuda_dev_set_lock"),
	            UnaryOperator(UOP_addr, IdentName(lockid)))
	          );
	/* _cuda_dev_unset_lock(&<lock>); */
	release = Expression(
	            FunctionCall(IdentName("_cuda_dev_unset_lock"),
	            UnaryOperator(UOP_addr, IdentName(lockid)))
	          );

	replacement = Block3(
	                dircomment,
	                Block3(acquire, body, release),
	                linepragma(in_list ? body->l + 1 : body->l, body->file)
	              );

	/* A parent that is not a statement list/compound needs one statement */
	if (!in_list)
		replacement = Compound(replacement);

	replacement->parent = owner;
	*t = replacement;
}
