/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

#include <cuda.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <inttypes.h>
#include <unistd.h>
#include <sys/stat.h>
#include "cudautils.h"
#include "ptx.h"

/*
 * Reads the whole file named by devpart into a freshly malloc'ed buffer.
 *
 * The returned buffer holds exactly the bytes of the file; it is NOT
 * null-terminated and its length is not reported. The caller owns the
 * buffer and must free() it.
 *
 * Any failure (open, fstat, allocation, read error or short read) prints
 * a diagnostic and terminates the process with EXIT_FAILURE.
 */
static
char *readdevpart(char *devpart)
{
	int fd;
	struct stat sb;
	char *file_contents;
	size_t size, done = 0;
	ssize_t n;

	fd = open(devpart, O_RDONLY);
	if (fd == -1) {
		perror("open");
		exit(EXIT_FAILURE);
	}

	/* Query the descriptor we already hold, not the path, so that the size
	 * refers to the very file that is going to be read.
	 */
	if (fstat(fd, &sb) == -1) {
		perror("fstat");
		close(fd);
		exit(EXIT_FAILURE);
	}
	size = (size_t) sb.st_size;

	/* malloc(0) is allowed to return NULL; always ask for at least 1 byte
	 * so that NULL unambiguously means "out of memory".
	 */
	file_contents = (char *) malloc(size ? size : 1);
	if (file_contents == NULL) {
		perror("malloc");
		close(fd);
		exit(EXIT_FAILURE);
	}

	/* read() may legitimately return fewer bytes than requested; loop
	 * until the whole file has been consumed.
	 */
	while (done < size) {
		n = read(fd, file_contents + done, size - done);
		if (n == -1) {
			perror("read");
			free(file_contents);
			close(fd);
			exit(EXIT_FAILURE);
		}
		if (n == 0) {   /* File is shorter than fstat reported */
			fprintf(stderr, "read: unexpected end of file\n");
			free(file_contents);
			close(fd);
			exit(EXIT_FAILURE);
		}
		done += (size_t) n;
	}
	close(fd);

	return file_contents;
}


/*
 * Builds the path of the device library to link with and stores it in the
 * buffer pointed to by *libdevpart.
 *
 * buf         : text to search (ptxtostr passes the first line of the PTX
 *               file). If it contains the marker "// ompi-link-with:", the
 *               characters following the marker, up to the end of that
 *               line, are taken as the library file name. Otherwise the
 *               default "libdevpart.a" is used. A NULL buf is treated as
 *               "no marker".
 * libdevpart  : address of the destination buffer; the buffer must be able
 *               to hold at least 255 bytes. The result has the form
 *               LibDir/devices/cuda/<file name> and is truncated to 254
 *               characters plus the terminator. If libdevpart or
 *               *libdevpart is NULL nothing is written.
 *
 * buf itself is never modified.
 */
void locate_libdevpart(char *buf, char **libdevpart)
{
	static const char marker[] = "// ompi-link-with:";
	char *tag, *name;

	if (libdevpart == NULL || *libdevpart == NULL)
		return;

	tag = (buf != NULL) ? strstr(buf, marker) : NULL;
	if (tag != NULL)
	{
		/* The file name starts right after the marker and ends at the line
		 * break; '\r' is excluded too because the PTX file is read in
		 * binary mode and may have CRLF line endings. Using a precision
		 * with %s avoids making (and having to check) a temporary copy.
		 */
		name = tag + sizeof(marker) - 1;
		snprintf(*libdevpart, 255, "%s/devices/cuda/%.*s", LibDir,
		         (int) strcspn(name, "\r\n"), name);
	}
	else
		snprintf(*libdevpart, 255, "%s/devices/cuda/libdevpart.a", LibDir);
}


/*
 * Loads a PTX file into a null-terminated, malloc'ed string.
 *
 * ptx_filename : path of the PTX file.
 * libdevpart   : forwarded to locate_libdevpart() together with the first
 *                line of the file, so that the path of the library to link
 *                with is stored in *libdevpart.
 *
 * Returns the file contents (the caller must free() them), or NULL if
 * ptx_filename is NULL or the file cannot be opened. Every other failure
 * (empty/unreadable file, seek/tell error, out of memory, short read)
 * prints a diagnostic and terminates the process with exit(1).
 */
char *ptxtostr(char *ptx_filename, char **libdevpart) 
{
	char *buf, file_buf[1024];
	long file_size;
	FILE *fp;

	if (ptx_filename == NULL)
		return NULL;

	fp = fopen(ptx_filename, "rb");
	if (fp == NULL)
		return NULL;

	/* The first line may carry the "ompi-link-with" directive */
	if (!fgets(file_buf, sizeof(file_buf), fp))
	{
		fprintf(stderr,
		       "[%s] ptxtostr: error: fgets failed; exiting.\n",
		       modulename);
		exit(1);
	}

	locate_libdevpart(file_buf, libdevpart);

	/* First find the # bytes required (= file size); ftell returns a long
	 * and -1 on error, so it must be checked before it is used as a size.
	 */
	if (fseek(fp, 0, SEEK_END) != 0 || (file_size = ftell(fp)) < 0)
	{
		fprintf(stderr,
		       "[%s] ptxtostr: error: cannot determine file size; exiting.\n",
		       modulename);
		exit(1);
	}

	/* One extra byte for the terminating null character */
	buf = malloc((size_t) file_size + 1);
	if (buf == NULL)
	{
		fprintf(stderr,
		       "[%s] ptxtostr: error: out of memory; exiting.\n",
		       modulename);
		exit(1);
	}

	/* Reset the file pointer */
	if (fseek(fp, 0, SEEK_SET) != 0)
	{
		fprintf(stderr,
		       "[%s] ptxtostr: error: fseek failed; exiting.\n",
		       modulename);
		exit(1);
	}

	/* Read the whole file and store it to the buffer; anything less than
	 * the full size is an error, since a partial PTX is useless.
	 */
	if (fread(buf, sizeof(char), (size_t) file_size, fp) != (size_t) file_size)
	{
		fprintf(stderr,
		       "[%s] ptxtostr: error: fread failed; exiting.\n",
		       modulename);
		exit(1);
	}

	fclose(fp);

	/* Null-terminate the string */
	buf[file_size] = '\0';

	return buf;
}

/* Size, in bytes, of each of the JIT info and error log buffers that
 * ptx_compile_and_load() sets up in DEBUG builds.
 */
#define JIT_LOGSIZE 8192

#if defined(DEBUG)
	/* Number of JIT options handed to cuLinkCreate() in DEBUG builds */
	#define NOPTS 5
#endif

/*
 * JIT-links a PTX string with the device library and loads the result.
 *
 * ptx    : null-terminated PTX source (e.g. as returned by ptxtostr()).
 * module : receives the loaded CUDA module.
 * kernel : kernel descriptor; kernel->libdevpart is the path of the library
 *          to link with and kernel->name the function to look up. On return
 *          kernel->function holds the function handle. kernel->linkstate is
 *          used for the linker invocation and is destroyed before returning.
 *
 * In DEBUG builds the linker is created with wall-time and info/error log
 * options; the collected values are not printed here.
 *
 * A NULL ptx, module or kernel prints a diagnostic and terminates the
 * process with exit(1); ptxtostr() returns NULL when the PTX file cannot be
 * opened, and passing that on would otherwise crash in strlen().
 */
void ptx_compile_and_load(char *ptx, CUmodule *module, cuda_kernel_t *kernel)
{
#if !defined(LibDir)
	#error "LibDir is not defined"
#endif
	void *cubin_out;
	size_t cubin_out_size;
	
#if defined(DEBUG)
	char error_log[JIT_LOGSIZE], info_log[JIT_LOGSIZE];
	float walltime;
	CUjit_option options[NOPTS] = 
	{
		CU_JIT_WALL_TIME, CU_JIT_INFO_LOG_BUFFER,
		CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER,
		CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES
	};
	/* Values are given in the same order as the options above */
	void *optionVals[NOPTS] = {
		(void *) &walltime, (void *) info_log, (void *) (long) JIT_LOGSIZE,
		(void *) error_log, (void *) (long) JIT_LOGSIZE
	};
#endif

	if (ptx == NULL || module == NULL || kernel == NULL)
	{
		fprintf(stderr,
		       "[%s] ptx_compile_and_load: error: NULL argument; exiting.\n",
		       modulename);
		exit(1);
	}

	/* Create a pending linker invocation */
#if defined(DEBUG)
	cuda_do(cuLinkCreate(NOPTS, options, optionVals, &(kernel->linkstate)));
#else
	cuda_do(cuLinkCreate(0, 0, 0, &(kernel->linkstate)));
#endif

	/* Add PTX file as string (size includes the terminating null) */
	cuda_do(cuLinkAddData(kernel->linkstate, CU_JIT_INPUT_PTX, (void *) ptx,
		                  strlen(ptx) + 1, 0, 0, 0, 0));

	/* Link with devpart library */
	cuda_do(cuLinkAddFile(kernel->linkstate, CU_JIT_INPUT_LIBRARY, 
		                 (void *) kernel->libdevpart, 0, 0, 0));

	/* Per the CUDA driver API, the image returned by cuLinkComplete belongs
	 * to the link state, so it has to be loaded before cuLinkDestroy.
	 */
	cuda_do(cuLinkComplete(kernel->linkstate, &cubin_out, &cubin_out_size));
	cuda_do(cuModuleLoadData(module, cubin_out));
	cuda_do(cuModuleGetFunction(&(kernel->function), *module, kernel->name));
	cuda_do(cuLinkDestroy(kernel->linkstate));
}
