/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* tac.c - Three-address code (TAC) support: variable access lists built from
 *         expression ASTs, plus (disabled) TAC-generation drafts */

#include <stdlib.h>
#include <stdarg.h>
#include <stack.h>
#include "ast_traverse.h"
#include "ompi.h"
#include "tac.h"


/* Instantiates the set(accvar) container: one element per variable symbol,
 * whose value accumulates how that variable is accessed (see acclist2vars).
 */
SET_TYPE_IMPLEMENT(accvar)


/**
 * Summarizes an access list into a set with one element per distinct
 * variable; each element's value combines all the ways it is accessed.
 * The first access of a variable is stored as-is. Later ones are merged:
 * REFER/DEREF/NAMED OR-in the xREF/xDRF/xNAM bits, while READ/WRITE update
 * the read/write part (xNA becomes xR or xW; a write after xR gives xRW; a
 * read after xW gives xWR).
 * NOTE(review): the first access stores the raw accway_e value, which
 * presumably coincides with the x* encoding -- confirm in tac.h.
 * @param l the access list
 * @return a newly created set, or NULL if l has no elements
 */
set(accvar) acclist2vars(acclist_t *l)
{
	set(accvar)      vars;
	setelem(accvar)  found;
	acclistelem_t   *acc;
	int              k;

	if (l->nelems <= 0)
		return NULL;

	vars = set_new(accvar);
	for (k = 0; k < l->nelems; k++)
	{
		acc = &l->list[k];
		found = set_get(vars, acc->var);

		if (!found)              /* first access to this variable */
		{
			set_put(vars, acc->var)->value = acc->way;
			continue;
		}

		switch (acc->way)
		{
			case ACC_REFER:
				found->value |= xREF;
				break;
			case ACC_DEREF:
				found->value |= xDRF;
				break;
			case ACC_NAMED:
				found->value |= xNAM;
				break;
			case ACC_CANT:   /* should not happen */
				fprintf(stderr, "[?? cannot determine access ??]\n");
				break;
			case ACC_READ:
				if (ACCRW(found->value) == xNA)
					found->value += xR;
				else if (ACCRW(found->value) == xW)
					found->value += (xWR - xW);
				break;
			case ACC_WRITE:
				if (ACCRW(found->value) == xNA)
					found->value += xW;
				else if (ACCRW(found->value) == xR)
					found->value += (xRW - xR);
				break;
			default:
				break;
		}
	}
	return vars;
}


/**
 * Prints a per-variable access summary (as built by acclist2vars()) to
 * stderr. Each line shows the variable name and its read/write pattern
 * (--, Rd, Wr, RW or WR). When other bits are set, it adds " + " followed by
 * the markers & (referenced), * (dereferenced) and N (named).
 * @param s the summary set; NULL prints nothing. acclist2vars() returns
 *          NULL for an empty access list, so callers may pass its result
 *          directly.
 */
void accvar_show(set(accvar) s)
{
	setelem(accvar) e;
	
	if (s == NULL)       /* empty access list: nothing to show */
		return;
	
	for (e = s->first; e; e = e->next)
		fprintf(stderr, "%s (%s%s%s%s%s)\n", 
		        e->key->name,
		        (ACCRW(e->value) == xNA ? "--" :
		         ACCRW(e->value) == xR  ? "Rd" :
		         ACCRW(e->value) == xW  ? "Wr" :
		         ACCRW(e->value) == xRW ? "RW" : "WR"),
		        (ACCRW(e->value) != e->value ? " + " : ""),
		        (ACCREF(e->value) ? "&" : ""),
		        (ACCDEREF(e->value) ? "*" : ""),
		        (ACCNAMED(e->value) ? "N" : "")
		       );
}


/* Dumps an access list to stderr: one "WAY: variable" line per recorded
 * access, followed by the per-variable summary from acclist2vars().
 * NOTE(review): the summary set is never freed here.
 */
void acclist_show(acclist_t *l)
{
	acclistelem_t *acc;
	const char    *how;
	int            k;

	for (k = 0; k < l->nelems; k++)
	{
		acc = &l->list[k];
		switch (acc->way)
		{
			case ACC_READ:  how = "READ";  break;
			case ACC_WRITE: how = "WRITE"; break;
			case ACC_REFER: how = "REFED"; break;
			case ACC_DEREF: how = "DEREF"; break;
			case ACC_NAMED: how = "NAME";  break;
			case ACC_CANT:  how = "CANT";  break;
			default:        how = "NONE";  break;
		}
		fprintf(stderr, "%5.5s: %s\n", how, acc->var->name);
	}

	fprintf(stderr, "\nSummary:\n");
	accvar_show(acclist2vars(l));
}


/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 *                                                                   *
 * ACCESS LIST FROM AST EXPRESSION                                   *
 *                                                                   *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */


/* Access-mode flags kept on a stack while walking an expression:
 * RVAL / LVAL  - the current subexpression is read / assigned to,
 * ADDR / STAR  - it is under a unary '&' / '*',
 * INDX         - it is the array part of a subscript,
 * SZOF         - it is inside the operand of sizeof.
 */
typedef enum {
	RVAL = 1 << 0,
	LVAL = 1 << 1,
	ADDR = 1 << 2,
	STAR = 1 << 3,
	INDX = 1 << 4,
	SZOF = 1 << 5
} accstate_e;

/* To walk the expression AST we need a stack: stack(accmode) holds the
 * accstate_e flags in effect for the subexpression being visited. The
 * callbacks push a mode when entering an operator and pop it on leaving.
 */
STACK_TYPE_DEFINE(accmode, accstate_e)
STACK_TYPE_IMPLEMENT(accmode)

/* State shared by all access-list traversal callbacks (passed as starg). */
typedef struct {
	acclist_t      *accl;   /* the access list being filled */
	travopts_t     *trops;  /* options of the traversal in progress */
	stack(accmode)  stack;  /* stack of accstate_e access modes */
} trstate_t;



/* Appends one (way, sym) access record to l. The array grows in chunks of
 * 8 elements; on allocation failure exit_error() is called.
 */
static void acclist_add(accway_e way, symbol sym, acclist_t *l)
{
	int            cnt = l->nelems;
	acclistelem_t *slot;

	if (cnt % 8 == 0)      /* array full (or not allocated yet when cnt == 0) */
	{
		acclistelem_t *grown;

		if (cnt == 0)
			grown = malloc(8 * sizeof *grown);
		else
			grown = realloc(l->list, (cnt + 8) * sizeof *grown);
		if (!grown)
			exit_error(1, "%s() cannot allocate memory.\n", __func__);
		l->list = grown;
	}

	slot = &l->list[cnt];
	slot->var = sym;
	slot->way = way;
	l->nelems = cnt + 1;
}


/* Identifier callback (acts at previsit only). Appends one access record
 * for the identifier, chosen from the mode on top of the stack:
 *  - SZOF: ACC_NAMED;
 *  - RVAL: ACC_READ, or ACC_REFER under '&', or ACC_DEREF under '*';
 *          '&' and '*' together cancel out and give ACC_READ;
 *  - LVAL: ACC_WRITE, or ACC_REFER under '&'; '&' and '*' together give
 *          ACC_WRITE; '*' and subscripts are handled case by case below;
 *  - none of these: ACC_REFER under '&', otherwise nothing is recorded.
 */
static void ident_c(astexpr e, void *state, int vistime)
{
	accstate_e mode;
	trstate_t *st = (trstate_t *) state;
	acclist_t *l = st->accl;
	
	if (vistime != PREVISIT) return;
	mode = stack_peep( st->stack )->data;
	
	if (mode & SZOF)
	{
		acclist_add(ACC_NAMED, e->u.sym, l);
		return;
	}
	
	if (mode & RVAL)
	{
		if ((mode & ADDR) && (mode & STAR))  /* cancel each other */
			acclist_add(ACC_READ, e->u.sym, l);
		else
			if (mode & ADDR)
				acclist_add(ACC_REFER, e->u.sym, l);
			else
				if (mode & STAR)
					acclist_add(ACC_DEREF, e->u.sym, l);
				else
					acclist_add(ACC_READ, e->u.sym, l);
		return;
	}
	
	if (mode & LVAL)
	{
		if ((mode & ADDR) && (mode & STAR))  /* cancel each other */
			acclist_add(ACC_WRITE, e->u.sym, l);
		else
		  if (mode & ADDR)
		    acclist_add(ACC_REFER, e->u.sym, l);
		  else
				if (mode & STAR)
				{
					if (mode & INDX)    /* *(x[...]) */
						acclist_add(ACC_READ, e->u.sym, l);
					else
					{
						/* NOTE(review): the condition below is a constant placeholder,
						 * so a write through *x is always recorded as ACC_DEREF */
						if (1 /* TODO: is not array */)
							acclist_add(ACC_DEREF, e->u.sym, l);
						else
							acclist_add(ACC_WRITE, e->u.sym, l);
					}
				}
				else
				{
					/* x[...] = ...: x is recorded as READ, not WRITE.
					 * NOTE(review): the "is ptr" test is a constant placeholder, so
					 * this also happens when x is a true array */
					if ((mode & INDX) && (1 /* TODO: is ptr */))
						acclist_add(ACC_READ, e->u.sym, l);
					else
						acclist_add(ACC_WRITE, e->u.sym, l);
				}
		return;
	}
	
	/* Neither a read nor a write context: only an '&' is recorded */
	if (mode & ADDR)
		acclist_add(ACC_REFER, e->u.sym, l);
}


/* Assignment callback (right-to-left traversal assumed). The rhs is visited
 * first, in RVAL mode; after the midvisit the lhs is visited in LVAL mode.
 * A compound assignment (anything other than plain '=') also reads its lhs,
 * so the lhs is traversed one extra time in RVAL mode at previsit.
 */
static void ass_c(astexpr e, void *state, int vistime)
{
	trstate_t *ts = (trstate_t *) state;

	if (vistime == PREVISIT)
	{
		if (e->opid != ASS_eq)             /* extra read of the lhs */
		{
			stack_push(ts->stack, RVAL);
			ast_expr_traverse(e->left, ts->trops);
			stack_pop(ts->stack);
		}
		stack_push(ts->stack, RVAL);       /* the rhs is read */
	}
	else if (vistime == MIDVISIT)
	{
		stack_pop(ts->stack);              /* done with the rhs */
		stack_push(ts->stack, LVAL);       /* the lhs is written */
	}
	else if (vistime == POSTVISIT)
		stack_pop(ts->stack);              /* done with the lhs */
}


/* Array-subscript callback (right-to-left traversal assumed). The index is
 * visited first and is always read (RVAL). The array part is then visited
 * with the enclosing mode plus INDX.
 */
static void arrayidx_c(astexpr e, void *state, int vistime)
{
	trstate_t *ts = (trstate_t *) state;

	if (vistime == PREVISIT)
		stack_push(ts->stack, RVAL);               /* the index expression */
	else if (vistime == MIDVISIT)
	{
		accstate_e outer;

		stack_pop(ts->stack);
		outer = stack_peep(ts->stack)->data;
		stack_push(ts->stack, outer | INDX);       /* the array expression */
	}
	else if (vistime == POSTVISIT)
		stack_pop(ts->stack);
}


/* Prefix/postfix ++ and -- callback. The operand is both read and written:
 * it is first traversed once in RVAL mode, then LVAL is pushed for the
 * regular visit of the operand and popped at postvisit.
 */
static void prepostop_c(astexpr e, void *state, int vistime)
{
	trstate_t *ts = (trstate_t *) state;

	if (vistime == PREVISIT)
	{
		stack_push(ts->stack, RVAL);          /* extra visit: the read */
		ast_expr_traverse(e->left, ts->trops);
		stack_pop(ts->stack);
		stack_push(ts->stack, LVAL);          /* regular visit: the write */
	}
	else if (vistime == POSTVISIT)
		stack_pop(ts->stack);                 /* end LVAL */
}


/* Identifier callback installed (by uop_c) while inside a sizeof operand.
 * Every identifier there is only named (ACC_NAMED), regardless of mode.
 */
static void identnaming_c(astexpr e, void *state, int vistime)
{
	trstate_t *ts = (trstate_t *) state;

	if (vistime != PREVISIT)
		return;
	acclist_add(ACC_NAMED, e->u.sym, ts->accl);
}


/* Unary-operator callback.
 *  - '*' and '&': the operand keeps the enclosing mode plus STAR / ADDR.
 *  - sizeof: the operand is not evaluated, so ident_c is swapped for
 *    identnaming_c (every identifier is just NAMED). Only the OUTERMOST
 *    sizeof installs and restores the callback. Otherwise an inner sizeof
 *    would switch back to ident_c too early: in sizeof(b + sizeof(a)), b is
 *    visited after sizeof(a) (right-to-left order) and must still be NAMED.
 *  - parentheses are transparent.
 *  - any other unary operator reads its operand (RVAL).
 */
static void uop_c(astexpr e, void *state, int vistime)
{
	trstate_t *st = (trstate_t *) state;
	accstate_e mode;
	
	switch (e->opid)
	{
		case UOP_star:
			if (vistime == PREVISIT)
			{
				mode = stack_peep(st->stack)->data;
				stack_push(st->stack, mode | STAR);
			}
			if (vistime == POSTVISIT)
				stack_pop(st->stack);
			break;
			
		case UOP_addr:    /* we only read the address (not even the value)... */
			if (vistime == PREVISIT)
			{
				mode = stack_peep(st->stack)->data;
				stack_push(st->stack, mode | ADDR);
			}
			if (vistime == POSTVISIT)
				stack_pop(st->stack);
			break;
			
		case UOP_sizeof:
		{
			travopts_t *opts = st->trops;
			
			/* Change traversal options temporarily so that all identifiers
			 * are visited in naming mode 
			 */
			if (vistime == PREVISIT)
			{
				/* identnaming_c already installed => nested inside an outer
				 * sizeof. Mark this level with an extra RVAL bit; the mode is
				 * otherwise irrelevant here since ident_c is not called. */
				if (opts->exprc.ident_c == identnaming_c)
					stack_push(st->stack, SZOF | RVAL);
				else
				{
					opts->exprc.ident_c = identnaming_c;
					stack_push(st->stack, SZOF);
				}
			}
			if (vistime == POSTVISIT)
			{
				mode = stack_peep(st->stack)->data;   /* what PREVISIT pushed */
				stack_pop(st->stack);
				if (mode == SZOF)                     /* outermost sizeof */
					opts->exprc.ident_c = ident_c;
			}
			break;
		}
		
		case UOP_paren:
			break;
			
		default:
			if (vistime == PREVISIT)
				stack_push(st->stack, RVAL);
			if (vistime == POSTVISIT)
				stack_pop(st->stack);
			break;
	}
}


/* Generic expression callback. expr2acclist() installs it through
 * travopts_init_batch(), presumably for every expression kind that has no
 * dedicated callback. Operands of such expressions are read, so they are
 * visited in RVAL mode.
 */
static void restexpr_c(astexpr e, void *state, int vistime)
{
	trstate_t *ts = (trstate_t *) state;

	if (vistime == PREVISIT)
		stack_push(ts->stack, RVAL);
	else if (vistime == POSTVISIT)
		stack_pop(ts->stack);
}


/**
 * @brief Fills a variable access list out of an expression
 *
 * Walks e right-to-left while keeping a stack of access modes. Every
 * identifier occurrence appends one entry to l (read, write, reference,
 * dereference, or mere naming inside sizeof). Entries are appended after
 * any already in l; the list is not reset.
 *
 * @param e the expression (NULL does nothing)
 * @param l the access list to fill
 */
void expr2acclist(astexpr e, acclist_t *l)
{
	travopts_t opts;
	trstate_t  ts;

	if (e == NULL)
		return;

	ts.accl  = l;
	ts.trops = &opts;
	ts.stack = stack_new(accmode);

	/* restexpr_c as the generic expression callback; specific ones below */
	travopts_init_batch(&opts, NULL, restexpr_c, NULL, NULL, NULL, NULL, 
	                           NULL, NULL, NULL, NULL);
	opts.when    = PREPOSTVISIT | MIDVISIT;
	opts.lrorder = RLORDER;          /* visit the right-hand side first */
	opts.starg   = &ts;

	opts.exprc.ident_c    = ident_c;
	opts.exprc.arrayidx_c = arrayidx_c;
	opts.exprc.uop_c      = uop_c;
	opts.exprc.preop_c    = prepostop_c;
	opts.exprc.postop_c   = prepostop_c;
	opts.exprc.ass_c      = ass_c;

	/* No callback, so nothing is pushed and operands see the enclosing mode */
	opts.exprc.funccall_c = NULL;    /* TODO */
	opts.exprc.constval_c = NULL;
	opts.exprc.string_c   = NULL;
	opts.exprc.dotfield_c = NULL;

	ast_expr_traverse(e, &opts);

	if (!stack_isempty(ts.stack))    /* sanity: pushes and pops must balance */
		warning("[%s() bug] stack not empty!!!\n", __func__);
	stack_free(ts.stack);
}


#if 0

/**
 * Returns the base variable of an lvalue expression, if one exists.
 * @param e the lvalue expression
 * @return the symbol, or NULL if none (e.g. comes from a complex expression)
 *
 * NOTE(review): disabled draft (inside #if 0). Issues to fix before
 * re-enabling:
 *  - the recursive calls pass an extra, undeclared argument (isptr);
 *  - ARRAYIDX is tested here as a BOP operator, whereas ast2tac() below
 *    switches on it as an expression type;
 *  - control falls off the end after exit_error(), which is fine only if
 *    exit_error() never returns.
 */
symbol lvalue_var(astexpr e)
{
	if (e->type == IDENT)
		return e->u.sym;
	if (e->type == DOTFIELD)
		return ( lvalue_var(e->left, isptr) );
	if (e->type == PTRFIELD)
		return NULL;     /* Could search for the base pointer ... */
	if (e->type==UOP && e->opid==UOP_star)
		return NULL;     /* Could search for the base pointer ... */
	if (e->type==BOP && e->opid==ARRAYIDX)
		return (e->left->type == IDENT) ? e->left->u.sym : NULL;
	if (e->type==UOP && e->opid==UOP_paren)
		return ( lvalue_var(e->left, isptr) );
	exit_error(1, "[%s bug!?]: not an lvalue\n", __func__);
}


/* Disabled draft (inside #if 0): appends to l the accesses caused by
 * assigning to the lvalue e.
 * NOTE(review): it has no return type (it should be void), and the local
 * tmp is unused.
 */
acclist_assign_to_lvalue(astexpr e, acclist_t *l)
{
	symbol tmp, var = lvalue_var(e);
	
	switch (e->type)
	{
		case IDENT:
			acclist_add(ACC_WRITE, var, l);
			break;
		case DOTFIELD:
			expr2acclist(e->left, l);
			if (var) /* base variable; approximate by an assignment to it */
				acclist_add(ACC_WRITE, var, l);
			break;
		case UOP:
			if (e->opid == UOP_paren)
			{
				acclist_assign_to_lvalue(e->left, l);
				break;
			}
			if (e->opid != UOP_star)
				goto ERROR;
			/* fallthrough: *x is handled like a PTRFIELD */
		case PTRFIELD:
			expr2acclist(e->left, l);
			/* nothing to do */
			break;
		case ARRAYIDX:
		{
			/* This is wrong because the var will be found to be READ also */
			expr2acclist(e->right, l);
			if (var) /* base variable; approximate by an assignment to it */
				acclist_add(ACC_WRITE, var, l);
			else
				expr2acclist(e->left, l);  /* nothing else to do... */
			break;
		}
		default:
		ERROR:
			exit_error(1, "[bug!?]: not an lvalue\n");
	}
}

/* NOTE(review): unfinished draft, disabled by the surrounding #if 0.
 * It apparently aims to lower an expression to three-address code and
 * return the temporary that holds its value (see the IDENT, BOP, PREOP/
 * POSTOP and ASS cases). Most other cases are still the visitor code of
 * ast_expr_traverse(). The names trop, tmp1 and tmp2 are undeclared. Several
 * paths (including the early "return;") yield no symbol.
 */
symbol ast2tac(astexpr tree)
{
	if (!trop->doexpr)
		return;
	
	switch (tree->type)
	{
		case IDENT:
			tmp1 = tac_tmp_new();
			tac_emit(TAC_SET, tmp1, tree->u.sym, NULL, TAC_NONE);
			return tmp1;
		case CONSTVAL:
			visit_expr(trop, constval_c, tree);
			break;
		case STRING:
			visit_expr(trop, string_c, tree);
			break;
		case FUNCCALL:
			pre_visit_expr(trop, funccall_c, tree);
			ast_expr_traverse(tree->left, trop);
			if (tree->right)
				ast_expr_traverse(tree->right, trop);
			post_visit_expr(trop, funccall_c, tree);
			break;
		case ARRAYIDX:
			pre_visit_expr(trop, arrayidx_c, tree);
			ast_expr_traverse(tree->left, trop);
			ast_expr_traverse(tree->right, trop);
			post_visit_expr(trop, arrayidx_c, tree);
			break;
		case DOTFIELD:
			pre_visit_expr(trop, dotfield_c, tree);
			ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, dotfield_c, tree);
			break;
		case PTRFIELD:
			pre_visit_expr(trop, ptrfield_c, tree);
			ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, ptrfield_c, tree);
			break;
		case BRACEDINIT:
			pre_visit_expr(trop, bracedinit_c, tree);
			ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, bracedinit_c, tree);
			break;
		case CASTEXPR:
			pre_visit_expr(trop, castexpr_c, tree);
			ast_decl_traverse(tree->u.dtype, trop);
			ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, castexpr_c, tree);
			break;
		case CONDEXPR:
			pre_visit_expr(trop, condexpr_c, tree);
			ast_expr_traverse(tree->u.cond, trop);
			ast_expr_traverse(tree->left, trop);
			ast_expr_traverse(tree->right, trop);
			post_visit_expr(trop, condexpr_c, tree);
			break;
		case UOP:
			pre_visit_expr(trop, uop_c, tree);
			if (tree->opid == UOP_sizeoftype || tree->opid == UOP_typetrick)
				ast_decl_traverse(tree->u.dtype, trop);
			else
				ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, uop_c, tree);
			break;
		case BOP:
			tmp1 = tac_tmp_new();
			/* Well here we should differentiate the logical operators so as to
			 * evaluate the 2nd operand only if the first one is not conclusive
			 */
			tac_emit(TAC_SET,tmp1,ast2tac(tree->left),ast2tac(tree->right),TAC_BOP);
			return tmp1;
		case PREOP:
		case POSTOP:
			tmp1 = ast2tac(tree->left);
			tmp2 = tac_tmp_new();
			/* NOTE(review): both branches emit the same instruction, so -- is
			 * not distinguished from ++ yet */
			if (tree->opid == UOP_inc)
				tac_emit(TAC_SET, tmp2, tmp1, tac_num(1), TAC_BOP);
			else
				tac_emit(TAC_SET, tmp2, tmp1, tac_num(1), TAC_BOP);
			tac_assign_to_lvalue(tree->left, tmp2);
			return (tree->type == PREOP ? tmp2 : tmp1);
		case ASS:
			tmp1 = ast2tac(tree->right);
			tac_assign_to_lvalue(tree->left, tmp1);
			return tmp1;
			
		case DESIGNATED:
			pre_visit_expr(trop, designated_c, tree);
			ast_expr_traverse(tree->left, trop);
			ast_expr_traverse(tree->right, trop);
			post_visit_expr(trop, designated_c, tree);
			break;
		case IDXDES:
			pre_visit_expr(trop, idxdes_c, tree);
			ast_expr_traverse(tree->left, trop);
			post_visit_expr(trop, idxdes_c, tree);
			break;
		case DOTDES:
			visit_expr(trop, dotdes_c, tree);
			break;
		case COMMALIST:
		case SPACELIST:
			pre_visit_expr(trop, list_c, tree);
			ast_expr_traverse(tree->left, trop);
			ast_expr_traverse(tree->right, trop);
			post_visit_expr(trop, list_c, tree);
			break;
		default:
			fprintf(stderr, "[ast_expr_traverse]: b u g !!\n");
	}
}


/* Returns true if e is a possible l-value expression: an identifier, a
 * field access (. or ->), a dereference, a subscript, or any of these in
 * parentheses. Disabled draft (inside #if 0).
 * NOTE(review): as in lvalue_var(), ARRAYIDX is tested as a BOP operator
 * here, whereas ast2tac() treats it as an expression type; confirm which
 * is right before re-enabling.
 */
bool expr_is_lvalue(astexpr e)
{
	if (e->type == IDENT || e->type == DOTFIELD || e->type == PTRFIELD ||
	    (e->type==UOP && e->opid==UOP_star) || 
	    (e->type==BOP && e->opid==ARRAYIDX))
		return true;   /*FIXME: we should check if the last 4 yield arrays */ 
	return (e->type==UOP && e->opid==UOP_paren) ? expr_is_lvalue(e->left) : false;
}

#endif
