/*
  OMPi OpenMP Compiler
  == Copyright since 2001 the OMPi Team
  == Dept. of Computer Science & Engineering, University of Ioannina

  This file is part of OMPi.

  OMPi is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  OMPi is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with OMPi; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/

/* This is the host-side part of the module; it should be compiled
 * to a shared library. It is dynamically linked to the host runtime at runtime.
 */


// #define DBGPRN_FORCE
// #define DBGPRN_BLOCK
#define DBGPRN_FILTER DBG_DEVICES

#include "rt_common.h"
#include "assorted.h"
#include "context.h"
#include "vkgpu.h"
#include <vulkan/vulkan.h>
#include <pthread.h>
#include <assert.h>

/* Capacity of physical_dev_indexes, i.e. the maximum number of capable GPUs
 * that vkgpu_get_num_gpus() will record.
 */
#define VK_MAX_NUM_GPUS 128


/* Lock object; set up as a spin lock (ORT_LOCK_SPIN) by vkgpus_init_all() */
void *vklock;

int available_gpus;         /* Number of capable Vulkan GPUs */
vk_gpu_t *vk_gpus;       /* Global GPU bookkeeping array */
VkInstance instance;        /* Vulkan instance created by vkgpu_create_instance() */
bool _hm_init_called = false; /* When true, vkgpus_finalize() destroys the instance */

static uint32_t total_num_devices; /* Total number of Vulkan GPUs */
/* For each capable GPU (0 .. available_gpus-1), its index within alldevs */
static int physical_dev_indexes[VK_MAX_NUM_GPUS];
/* Properties, features and memory properties of the enumerated devices */
static vk_gpu_physical_dev_t *alldevs;
/* Raw physical device handles as returned by vkEnumeratePhysicalDevices() */
static VkPhysicalDevice *allpdevs;


/* 
 * Dumps a shader file to a newly allocated buffer and stores its size in *size
 * (used only if no kernel bundling is active).
 *
 * Returns NULL if the file cannot be opened or its size cannot be determined.
 * Exits if the file is empty or cannot be read completely.
 * The returned buffer is owned by the caller.
 */
static char *read_shader(char *filename, size_t *size) 
{
	FILE *file = fopen(filename, "rb");
	char *buffer;
	long  nbytes;

	if (!file) return NULL;

	/* Determine the file size; ftell() reports failure by returning -1,
	 * which must not be turned into a huge size_t.
	 */
	if (fseek(file, 0, SEEK_END) != 0 
	    || (nbytes = ftell(file)) < 0 
	    || fseek(file, 0, SEEK_SET) != 0)
	{
		fclose(file);
		return NULL;
	}
	*size = (size_t) nbytes;

	buffer = (char*) smalloc(*size);

	/* fread() returns an unsigned count, so compare against the expected
	 * size; a short read would hand a truncated shader to the driver.
	 */
	if (*size == 0 || fread(buffer, 1, *size, file) != *size)
	{
		fprintf(stderr, "[vulkan] read_shader: error: fread failed; exiting.\n");
		fclose(file);
		exit(1);
	}

	fclose(file);

	return buffer;
}


/* 
 * Creates a Vulkan instance 
 */
void vkgpu_create_instance(char *app_name) 
{
	VkApplicationInfo    app_info = {};
	VkInstanceCreateInfo create_info = {};

	app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
	app_info.pApplicationName = app_name;
	app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
	app_info.pEngineName = "No Engine";
	app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
	app_info.apiVersion = VK_API_VERSION_1_0;

	create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
	create_info.pApplicationInfo = &app_info;

	if (vkCreateInstance(&create_info, NULL, &instance) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create instance!\n");
		exit(EXIT_FAILURE);
	}
}


/* 
 * Creates the Vulkan instance, enumerates the physical devices and returns
 * the number of available (capable) Vulkan GPUs.
 *
 * A device counts as capable if its name does not contain "llvmpipe" or
 * "lavapipe" and it reports the geometryShader feature. The index of each
 * capable device is recorded in physical_dev_indexes; at most
 * VK_MAX_NUM_GPUS devices are recorded.
 *
 * devices, pdevices and total_dev_count are optional outputs (may be NULL).
 * Arrays that are handed out through devices / pdevices become the caller's
 * responsibility; arrays that are not requested are released here.
 * Returns 0 (and leaves the outputs untouched) if enumeration fails.
 */
int vkgpu_get_num_gpus(vk_gpu_physical_dev_t **devices, VkPhysicalDevice **pdevices, 
                       uint32_t *total_dev_count)
{
	VkResult result;
	int ngpus = 0;
	uint32_t i, device_count = 0;
	vk_gpu_physical_dev_t *devs;
	VkPhysicalDevice *physical_devs;
	char *devname;

	vkgpu_create_instance("vkhostpart");
	DBGPRN((stderr, "[vulkan] vkgpu: %s create instance OK\n", __FUNCTION__));

	result = vkEnumeratePhysicalDevices(instance, &device_count, NULL);
	if (result != VK_SUCCESS || device_count == 0) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to find GPUs with Vulkan support\n");
		return 0;
	}

	devs = (vk_gpu_physical_dev_t *) smalloc(device_count * sizeof(vk_gpu_physical_dev_t));
	physical_devs = (VkPhysicalDevice *) smalloc(device_count * sizeof(VkPhysicalDevice));

	/* The second call may fail too; VK_INCOMPLETE still yields device_count
	 * valid handles, so only real errors (or an empty list) are fatal here.
	 */
	result = vkEnumeratePhysicalDevices(instance, &device_count, physical_devs);
	if ((result != VK_SUCCESS && result != VK_INCOMPLETE) || device_count == 0)
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to find GPUs with Vulkan support\n");
		free(devs);
		free(physical_devs);
		return 0;
	}

	/* Count only suitable devices (those supporting geometry shaders) */
	for (i = 0; i < device_count && ngpus < VK_MAX_NUM_GPUS; i++)
	{
		devs[i].physical_device = physical_devs[i];
		vkGetPhysicalDeviceProperties(devs[i].physical_device, &(devs[i].properties));
		vkGetPhysicalDeviceFeatures(devs[i].physical_device, &(devs[i].features));
		vkGetPhysicalDeviceMemoryProperties(devs[i].physical_device, &(devs[i].memory_properties));
		devname = devs[i].properties.deviceName;

		/* Skip devices whose name identifies them as llvmpipe/lavapipe */
		if (strstr(devname, "llvmpipe") || strstr(devname, "lavapipe"))
			continue;

		if (devs[i].features.geometryShader)
			physical_dev_indexes[ngpus++] = (int) i;
	}

	/* Hand the arrays over to the caller, or release them if not wanted */
	if (devices != NULL)
		*devices = devs;
	else
		free(devs);

	if (pdevices != NULL)
		*pdevices = physical_devs;
	else
		free(physical_devs);

	if (total_dev_count != NULL)
		*total_dev_count = device_count;

	DBGPRN((stderr, "[vulkan] vkgpu: %s device count %d\n", __FUNCTION__, ngpus));
	return ngpus;
}


/*
 * Resets the bookkeeping of a GPU entry: marks it as uninitialized, empties
 * its shader cache, sets up its context storage and fills in the three
 * specialization map entries used for the workgroup sizes.
 */
static void _prepare_gpu(vk_gpu_t *gpu)
{
	int slot, dim;

	gpu->status = DEVICE_UNINITIALIZED;
	gpu->nshaders = 0;
#ifndef VK_ENABLE_CONCURRENCY
	gpu->context = NULL;
#endif

	/* Empty every shader cache slot */
	slot = 0;
	while (slot < VK_CACHE_SIZE)
	{
		gpu->shader_cache[slot].lock_init = false;
		gpu->shader_cache[slot].nargs = 0;
		gpu->shader_cache[slot].name = NULL;
		slot++;
	}

#ifdef VK_ENABLE_CONCURRENCY
	pthread_key_create(&(gpu->context_key), NULL);
#endif
	vkgpu_ctx_test_init(gpu);
	
	/* Workgroup sizes (x, y, z): one uint32_t per dimension, packed back to back */
	for (dim = 0; dim < 3; dim++)
	{
		gpu->dimensions[dim].size = sizeof(uint32_t);
		gpu->dimensions[dim].offset = dim * sizeof(uint32_t);
		gpu->dimensions[dim].constantID = dim;
	}
}


/**
 * @brief Enumerate and initialize all Vulkan GPUs
 * 
 */
void vkgpus_init_all(bool prepare_gpus)
{
	int i;

	if (prepare_gpus)
		init_lock(&vklock, ORT_LOCK_SPIN);

	if (vk_gpus == NULL)
	{
		if (available_gpus == 0)
			available_gpus = vkgpu_get_num_gpus(&alldevs, &allpdevs, &total_num_devices);
		vk_gpus = (vk_gpu_t *) smalloc(available_gpus * sizeof(vk_gpu_t));
	}

	for (i = 0; i < available_gpus; i++)
	{
		vk_gpus[i].pdev = alldevs[physical_dev_indexes[i]];
		if (prepare_gpus)
			_prepare_gpu(&(vk_gpus[i]));
	}
}

/**
 * @brief Finalize all Vulkan GPUs
 * 
 * Destroys the Vulkan instance (only if _hm_init_called is set), releases
 * the global device arrays and resets the device counters.
 */
void vkgpus_finalize(void)
{
	if (_hm_init_called)
		vkDestroyInstance(instance, NULL);

	/* free(NULL) is a no-op, so the arrays need no individual guards */
	free(vk_gpus);
	free(alldevs);
	free(allpdevs);
	vk_gpus = NULL;
	alldevs = NULL;
	allpdevs = NULL;

	available_gpus = 0;
	total_num_devices = 0;
}


/*
 * Returns the index of a memory type of the given GPU that is allowed by
 * typeFilter (bit i set <=> memory type i is acceptable) and offers all the
 * requested property flags. If no type offers `properties`, the search is
 * repeated with the `fallback` flags. Exits if neither search succeeds.
 */
static uint32_t _find_memory_type(vk_gpu_t *gpu, uint32_t typeFilter, VkMemoryPropertyFlags properties, 
                                  VkMemoryPropertyFlags fallback) 
{
	uint32_t i;

	/* The mask is built with an unsigned shift: there may be up to 32 memory
	 * types and shifting a signed 1 by 31 bits is undefined behavior.
	 */
	for (i = 0; i < gpu->pdev.memory_properties.memoryTypeCount; i++)
		if ((typeFilter & ((uint32_t) 1 << i)) && 
		   (gpu->pdev.memory_properties.memoryTypes[i].propertyFlags & properties) == properties)
			return i;

	for (i = 0; i < gpu->pdev.memory_properties.memoryTypeCount; i++)
		if ((typeFilter & ((uint32_t) 1 << i)) && 
		    (gpu->pdev.memory_properties.memoryTypes[i].propertyFlags & fallback) == fallback)
			return i;

	fprintf(stderr, "[vulkan] vkgpu: %s: error: Failed to find suitable memory type!\n",
	                 __FUNCTION__);
	exit(EXIT_FAILURE);
}

/* Chooses the preferred and the fallback memory property flags for an
 * allocation, according to the mapping type that triggered it:
 *   - to / alloc:          host-visible, device-local memory
 *   - from / tofrom / any other value: host-visible, coherent, cached memory
 * The fallback is always host-visible, coherent memory.
 */
static
void _set_mem_property_flags(int map_type, VkMemoryPropertyFlags *property_flags, 
                             VkMemoryPropertyFlags *fallback_property_flags)
{
	const VkMemoryPropertyFlags host_coherent = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT 
	                                          | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;

	if (map_type == MAP_TYPE_TO || map_type == MAP_TYPE_ALLOC)
		*property_flags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT 
		                | VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
	else
		*property_flags = host_coherent | VK_MEMORY_PROPERTY_HOST_CACHED_BIT;

	*fallback_property_flags = host_coherent;
}

/**
 * @brief Allocates memory on a specific Vulkan GPU
 * 
 * Creates a storage buffer of the requested size, allocates device memory
 * that satisfies the buffer's requirements and binds the two together.
 * Exits if any of these steps fails.
 * 
 * @param gpu      the Vulkan GPU
 * @param size     the allocated memory size
 * @param map_type the mapping type that triggered this allocation (to/from/tofrom/alloc)
 * @return     pointer to a newly allocated vk_devptr_t containing the 
 *             VkBuffer and VkDeviceMemory items
 */
void *vkgpu_alloc(vk_gpu_t *gpu, VkDeviceSize size, int map_type) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	vk_devptr_t          *res;
	VkBuffer              buffer;
	VkDeviceMemory        devmem;
	VkBufferCreateInfo    buffer_info = {};
	VkMemoryRequirements  mem_requirements;
	VkMemoryAllocateInfo  alloc_info = {};
	VkMemoryPropertyFlags property_flags, fallback_property_flags;

	/* (1) Set parameters for the new buffer */
	buffer_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
	buffer_info.size = size;
	buffer_info.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; /* Could be a parameter */
	buffer_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;

	/* (2) Create the buffer */
	if (vkCreateBuffer(ctx->device, &buffer_info, NULL, &buffer) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create buffer!\n");
		exit(EXIT_FAILURE);
	}

	vkGetBufferMemoryRequirements(ctx->device, buffer, &mem_requirements);

	/* (3) Set parameters for the memory allocation */
	_set_mem_property_flags(map_type, &property_flags, &fallback_property_flags);
	alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
	alloc_info.allocationSize = mem_requirements.size;
	alloc_info.memoryTypeIndex = _find_memory_type(gpu, mem_requirements.memoryTypeBits, 
	                                               property_flags, fallback_property_flags);

	/* (4) Allocate memory */
	if (vkAllocateMemory(ctx->device, &alloc_info, NULL, &devmem) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to allocate buffer memory!\n");
		exit(EXIT_FAILURE);
	}

	/* (5) Bind the memory to the created buffer; an unbound buffer must not
	 *     be handed out, so a failure here is fatal like the ones above
	 */
	if (vkBindBufferMemory(ctx->device, buffer, devmem, 0) != VK_SUCCESS)
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to bind buffer memory!\n");
		exit(EXIT_FAILURE);
	}

	/* (6) Store the pair of VkBuffer and VkDeviceMemory in a host-side item */
	res = (vk_devptr_t *) smalloc(sizeof(vk_devptr_t));
	res->buffer = buffer;
	res->mem = devmem; 

	DBGPRN((stderr, "[vulkan] vkgpu: %s res:%p buffer:%p mem:%p\n", __FUNCTION__, res, res->buffer, res->mem));

	return res;
}

/**
 * @brief Deallocates memory on a Vulkan GPU
 * 
 * Destroys the buffer and frees the device memory stored in the given item.
 * NOTE(review): the vk_devptr_t item itself is not released here -- confirm
 * that the caller owns it.
 * 
 * @param gpu  the Vulkan GPU
 * @param addr the memory address (a vk_devptr_t item)
 */
void vkgpu_free(vk_gpu_t *gpu, void *addr)
{
	vk_devptr_t  *devptr = (vk_devptr_t *) addr;
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);

	vkDestroyBuffer(ctx->device, devptr->buffer, NULL);
	vkFreeMemory(ctx->device, devptr->mem, NULL);
}

/* Offsets recorded by handle_shader_args() for the mapped arguments; indexed
 * by the position of the argument in the shader's argument array.
 */
static unsigned long offsets[256];

/* Iterates the args array and copies the arguments to a new array.
 *
 * The args array is consumed in order: first the target declare arguments,
 * then (only if devdata is non-NULL) the firstprivate arguments, followed by
 * the mapped arguments, where pointers alternate with their offsets.
 * The collected pointers are stored in a newly allocated shader->args array
 * and shader->nargs is increased by the number of arguments collected.
 * Exits if the arguments do not fit in the fixed-size temporary arrays.
 *
 * NOTE(review): the function assumes shader->nargs is 0 on entry (as set by
 * vkgpu_new_shader() / vkgpu_cleanup_shader()) -- confirm with the callers.
 */
static void handle_shader_args(vk_gpu_t *gpu, vk_shader_t *shader, void *devdata, 
                               size_t devdata_size, int *num_args, void **args)
{
	void *new_shader_args[256];  /* Array used for storing kernel args temporarily */
	const int max_args = (int) (sizeof(new_shader_args) / sizeof(new_shader_args[0]));
	const int max_offsets = (int) (sizeof(offsets) / sizeof(offsets[0]));
	void *ptr_var;
	unsigned long offset_var = 0UL;
	int ndeclargs, nfirstprivate, i;
	int off = 0;
	int offset_idx = 0;

	for (i = 0; i < max_offsets; i++)
		offsets[i] = 0;

	/* (1) Handle target declare arguments */
	ndeclargs = num_args[ARGS_NUMDECL];
	if (ndeclargs > max_args)
		goto too_many_args;
	for (i = 0; i < ndeclargs; i++, shader->nargs++)
	{
		ptr_var = (void*) args[off++];
		DBGPRN((stderr, "[vulkan] vkgpu: %s DECL %p %d\n", __FUNCTION__, ptr_var, i));
		new_shader_args[i] = (void*) ptr_var;
	}
	
	if (devdata != NULL)
	{
		/* (2) Handle firstprivate arguments */
		nfirstprivate = num_args[ARGS_NUMFIP];
		if (nfirstprivate > max_args - ndeclargs)
			goto too_many_args;
		
		for (i = ndeclargs; i < ndeclargs + nfirstprivate; i++, shader->nargs++)
		{
			ptr_var = (void *) args[off++];
			DBGPRN((stderr, "[vulkan] vkgpu: %s FIP %p\n", __FUNCTION__, ptr_var));
			new_shader_args[i] = (void*) ptr_var;
		}

		/* (3) Handle mapped arguments: even iterations store the pending 
		 *     pointer and read its offset, odd iterations record that 
		 *     offset and read the next pointer
		 */
		ptr_var = (void *) args[off++];
		DBGPRN((stderr, "[vulkan] vkgpu: %s MAP %p %d\n", __FUNCTION__, ptr_var, shader->nargs));
		offset_idx = shader->nargs;
		for (i = 0; i < num_args[ARGS_NUMMAPPED]; i++, off++)
		{
			if ((i % 2) == 1)
			{
				if (offset_idx >= max_offsets)
					goto too_many_args;
				offsets[offset_idx++] = offset_var;
				ptr_var = (void*) args[off];
				DBGPRN((stderr, "[vulkan] vkgpu: %s MAP %p %s\n", __FUNCTION__, 
				                ptr_var, (ptr_var == NULL) ? "end" : ""));
			}
			else
			{
				if (shader->nargs >= max_args)
					goto too_many_args;
				new_shader_args[shader->nargs++] = (void*) ptr_var;
				offset_var = (unsigned long) args[off];
			}
		}
	}

	/* Never copy more entries than the temporary array can hold */
	if (shader->nargs > max_args)
		goto too_many_args;

	/* (4) Copy new arguments */
	shader->args = (void**) smalloc(shader->nargs*sizeof(void*));
	for (i = 0; i < shader->nargs; i++)
	{
		shader->args[i] = new_shader_args[i];
		new_shader_args[i] = NULL;
	}

	DBGPRN((stderr, "[vulkan] vkgpu: %s shader arguments copied, %d in total\n", __FUNCTION__, shader->nargs));
	return;

too_many_args:
	fprintf(stderr, "[vulkan] vkgpu: %s: error: too many shader arguments (max %d)\n",
	                __FUNCTION__, max_args);
	exit(EXIT_FAILURE);
}


/*
 * Creates the descriptor set layout (DSL) of a shader: `size` storage buffer
 * bindings (0 .. size-1), each with a single descriptor that is visible to
 * the compute stage. The result is stored in shader->descriptor_set_layout.
 * Exits if the layout cannot be created.
 */
static void create_descriptor_set_layout(vk_gpu_t *gpu, vk_shader_t *shader, int size) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	int i;
	VkResult result;
	VkDescriptorSetLayoutBinding   *layout_bindings;
	VkDescriptorSetLayoutCreateInfo layout_info = {};

	/* (1) Create all bindings for the new DSL (size = #args) */
	layout_bindings = (VkDescriptorSetLayoutBinding *) smalloc(size * sizeof(VkDescriptorSetLayoutBinding));

	for (i = 0; i < size; i++)
	{
		layout_bindings[i].binding = i;
		layout_bindings[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
		layout_bindings[i].descriptorCount = 1;
		layout_bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
		layout_bindings[i].pImmutableSamplers = NULL; /* do not leave it uninitialized */
	}

	/* (2) Set parameters for the new DSL */
	layout_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
	layout_info.bindingCount = size;
	layout_info.pBindings = layout_bindings;

	/* (3) Create the DSL; the bindings array is only needed during the call,
	 *     so it is released right afterwards instead of being leaked
	 */
	result = vkCreateDescriptorSetLayout(ctx->device, &layout_info, NULL, 
	                                     &(shader->descriptor_set_layout));
	free(layout_bindings);

	if (result != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create descriptor set layout!\n");
		exit(EXIT_FAILURE);
	}

	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/*
 * Creates the shader module, the pipeline layout and the compute pipeline of
 * a shader. The shader code is obtained first if it is not available yet:
 * from the bundled binaries, or from the shader file (after unbundling and
 * compiling the sources, if sources are bundled). lx, ly and lz are the
 * local sizes handed to the shader as specialization constants 0, 1 and 2.
 * Exits on any failure.
 */
static void create_compute_pipeline(vk_gpu_t *gpu, vk_shader_t *shader,
                                    int lx, int ly, int lz) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	uint32_t local_sizes[3] = {lx, ly, lz};
	VkSpecializationInfo            specialization_info = {};
	VkShaderModuleCreateInfo        create_info = {};
	VkPipelineShaderStageCreateInfo shader_stage_info = {};
	VkPipelineLayoutCreateInfo      pipeline_layout_info = {};
	VkComputePipelineCreateInfo     pipeline_info = {};
	VkPushConstantRange             push_constant_range = {};

	if (shader->code == NULL)
	{
		switch (ort_bundling_type())
		{
			case BUNDLE_BINS:
			{
				bubin_t *entry = ort_bubins_search(shader->filename);
				/* Guard against a failed lookup before dereferencing */
				if (entry == NULL)
				{
					fprintf(stderr, "[vulkan] vkgpu: error: Failed to find bundled shader: %s\n", 
					                shader->filename);
					exit(EXIT_FAILURE);
				}
				shader->code = (char *) entry->data;
				shader->code_size = entry->size;
				break;
			}
			case BUNDLE_SRCS:
				ort_bubins_unbundle_and_compile(shader->sources_filename);
				/* fall through */
			default:   /* no bundling */
				shader->code = read_shader(shader->filename, &(shader->code_size));
				if (shader->code == NULL)
				{
					fprintf(stderr, "[vulkan] vkgpu: error: Failed to open shader: %s\n", 
					                shader->filename);
					exit(EXIT_FAILURE);
				}
		}
	}
	DBGPRN((stderr, "[vulkan] vkgpu: %s shader read %s\n", __FUNCTION__, shader->filename));

	/* (1) Set the specialization info (pertains to the local sizes) */
	specialization_info.mapEntryCount = 3; /* 3 dimensions */
	specialization_info.pMapEntries = gpu->dimensions;
	specialization_info.dataSize = sizeof(local_sizes);
	specialization_info.pData = local_sizes;

	/* (2) Set the parameters for the new shader module */
	create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
	create_info.codeSize = shader->code_size;
	create_info.pCode = (const uint32_t*) shader->code;

	/* (3) Create the shader module */
	if (vkCreateShaderModule(ctx->device, &create_info, NULL, &(shader->compute_shader_module)) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create shader module!\n");
		exit(EXIT_FAILURE);
	}

	/* (4) Set shader stage info */
	shader_stage_info.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
	shader_stage_info.stage = VK_SHADER_STAGE_COMPUTE_BIT;
	shader_stage_info.module = shader->compute_shader_module;
	shader_stage_info.pName = shader->name;
	shader_stage_info.pSpecializationInfo = &specialization_info;

	/* (5) Set push constant range and pipeline layout info.
	 *     The range must cover every byte that is pushed later on:
	 *     record_command_buffer() pushes two ints (global id, thread limit).
	 */
	push_constant_range.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
	push_constant_range.offset = 0;
	push_constant_range.size = 2 * sizeof(int);
	pipeline_layout_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
	pipeline_layout_info.setLayoutCount = 1;
	pipeline_layout_info.pSetLayouts = &(shader->descriptor_set_layout);
	pipeline_layout_info.pushConstantRangeCount = 1;
	pipeline_layout_info.pPushConstantRanges = &push_constant_range;

	/* (6) Create the pipeline layout */
	if (vkCreatePipelineLayout(ctx->device, &pipeline_layout_info, NULL, 
	                           &(shader->pipeline_layout)) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create pipeline layout!\n");
		exit(EXIT_FAILURE);
	}

	/* (7) Create the compute pipeline */
	pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
	pipeline_info.stage = shader_stage_info;
	pipeline_info.layout = shader->pipeline_layout;

	if (vkCreateComputePipelines(ctx->device, VK_NULL_HANDLE, 1, &pipeline_info, NULL,
	                             &(shader->compute_pipeline)) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create compute pipeline!\n");
		exit(EXIT_FAILURE);
	}

	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/*
 * Creates the descriptor pool of a shader: it can provide a single set and
 * has one storage buffer pool size entry per shader argument. The result is
 * stored in shader->descriptor_pool. Exits if the pool cannot be created.
 */
static void create_descriptor_pool(vk_gpu_t *gpu, vk_shader_t *shader) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	int i;
	VkResult result;
	VkDescriptorPoolSize      *pool_sizes;
	VkDescriptorPoolCreateInfo pool_info = {};
	
	/* (1) Set descriptor pool sizes for each argument */
	pool_sizes = (VkDescriptorPoolSize *) smalloc(shader->nargs * sizeof(VkDescriptorPoolSize));
	for (i = 0; i < shader->nargs; i++)
	{
		pool_sizes[i].type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
		pool_sizes[i].descriptorCount = 1;
	}

	/* (2) Set the parameters for the new descriptor pool */
	pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
	pool_info.poolSizeCount = shader->nargs;
	pool_info.pPoolSizes = pool_sizes;
	pool_info.maxSets = 1;

	/* (3) Create the descriptor pool; the sizes array is only needed during
	 *     the call, so it is released right afterwards instead of being leaked
	 */
	result = vkCreateDescriptorPool(ctx->device, &pool_info, NULL, &(shader->descriptor_pool));
	free(pool_sizes);

	if (result != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to create descriptor pool!\n");
		exit(EXIT_FAILURE);
	}

	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/*
 * Allocates the descriptor set of a shader from its descriptor pool and
 * makes binding i refer to the whole buffer of shader argument i.
 * The buffer infos and write descriptors are kept in shader->buffer_infos
 * and shader->desc_sets, which must both be NULL on entry.
 * Exits on any failure.
 */
static void create_descriptor_set(vk_gpu_t *gpu, vk_shader_t *shader) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	int i;
	VkDescriptorSetAllocateInfo alloc_info = {};

	assert((shader->buffer_infos == NULL) && (shader->desc_sets == NULL));

	shader->buffer_infos = (VkDescriptorBufferInfo*) calloc(shader->nargs, sizeof(VkDescriptorBufferInfo));
	shader->desc_sets = (VkWriteDescriptorSet*) calloc(shader->nargs, sizeof(VkWriteDescriptorSet));

	/* calloc() may legitimately return NULL for 0 elements; any other NULL 
	 * means that we ran out of memory and must not be dereferenced below
	 */
	if (shader->nargs > 0 && (shader->buffer_infos == NULL || shader->desc_sets == NULL))
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to allocate descriptor set items!\n");
		exit(EXIT_FAILURE);
	}

	/* (1) Set the parameters for the new descriptor set */
	alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
	alloc_info.descriptorPool = shader->descriptor_pool;
	alloc_info.descriptorSetCount = 1;
	alloc_info.pSetLayouts = &(shader->descriptor_set_layout);

	DBGPRN((stderr, "[vulkan] vkgpu: %s create descriptor set in\n", __FUNCTION__));

	if (vkAllocateDescriptorSets(ctx->device, &alloc_info, &(shader->descriptor_set)) != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: error: Failed to allocate descriptor set!\n");
		exit(EXIT_FAILURE);
	}

	/* (2) For each argument, create a write descriptor. 
	 *     The recorded offsets are only printed; every binding starts at 
	 *     offset 0 and covers the whole buffer.
	 */
	for (i = 0; i < shader->nargs; i++)
	{
		shader->buffer_infos[i].buffer = ((vk_devptr_t *) shader->args[i])->buffer;
		DBGPRN((stderr, "[vulkan] vkgpu: %s buffer %p, offset %lu\n", __FUNCTION__, 
		                ((vk_devptr_t *) shader->args[i])->buffer, offsets[i]));
		shader->buffer_infos[i].offset = 0;
		shader->buffer_infos[i].range = VK_WHOLE_SIZE;

		shader->desc_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
		shader->desc_sets[i].dstSet = shader->descriptor_set;
		shader->desc_sets[i].dstBinding = i;
		shader->desc_sets[i].dstArrayElement = 0;
		shader->desc_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
		shader->desc_sets[i].descriptorCount = 1;
		shader->desc_sets[i].pBufferInfo = &(shader->buffer_infos[i]);
		shader->desc_sets[i].pImageInfo = NULL;
		shader->desc_sets[i].pTexelBufferView = NULL;
	}

	DBGPRN((stderr, "[vulkan] vkgpu: %s create descriptor set items OK\n", __FUNCTION__));

	/* (3) Set the descriptor sets */
	vkUpdateDescriptorSets(ctx->device, shader->nargs, shader->desc_sets, 0, NULL);

	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/*
 * Records the command buffer of the GPU's context for one launch of the
 * given shader: pushes the constants (the GPU's global id and the thread 
 * limit), binds the compute pipeline and the descriptor set, and dispatches
 * workgroupX x workgroupY x workgroupZ workgroups.
 */
static void record_command_buffer(vk_gpu_t *gpu, vk_shader_t *shader, int thread_limit,
                                 int workgroupX, int workgroupY, int workgroupZ) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	VkCommandBuffer cmdbuf = ctx->command_buffer;
	VkCommandBufferBeginInfo begin_info = {
		.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
		.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT
	};
	/* Values handed to the shader as push constants */
	struct push_constants_ {
		int global_id;
		int thread_limit;
	} push_constants = { gpu->global_id, thread_limit };

	/* (1) Start recording from a clean command buffer */
	vkResetCommandBuffer(cmdbuf, 0);
	vkBeginCommandBuffer(cmdbuf, &begin_info);

	/* (2) Record the constants, the bindings and the dispatch */
	vkCmdPushConstants(cmdbuf, shader->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 
	                   0, 2 * sizeof(int), &push_constants);
	vkCmdBindPipeline(cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, shader->compute_pipeline);
	vkCmdBindDescriptorSets(cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, shader->pipeline_layout, 
	                        0, 1, &(shader->descriptor_set), 0, NULL);
	vkCmdDispatch(cmdbuf, workgroupX, workgroupY, workgroupZ);

	/* (3) Finish recording */
	vkEndCommandBuffer(cmdbuf);
	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/*
 * Submits the recorded command buffer of the GPU's context to its compute
 * queue and waits until the queue becomes idle. Exits if the submission fails.
 */
static void run_compute_shader(vk_gpu_t *gpu, vk_shader_t *shader) 
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	/* A single command buffer, submitted without a fence */
	VkSubmitInfo submit_info = {
		.sType              = VK_STRUCTURE_TYPE_SUBMIT_INFO,
		.commandBufferCount = 1,
		.pCommandBuffers    = &(ctx->command_buffer)
	};
	VkResult status = vkQueueSubmit(ctx->compute_queue, 1, &submit_info, VK_NULL_HANDLE);

	if (status != VK_SUCCESS) 
	{
		fprintf(stderr, "[vulkan] vkgpu: %s: error: Failed to submit compute shader!\n",
		                __FUNCTION__);
		exit(EXIT_FAILURE);
	}

	/* Block until the shader has finished */
	vkQueueWaitIdle(ctx->compute_queue);
	DBGPRN((stderr, "[vulkan] vkgpu: %s OK\n", __FUNCTION__));
}


/**
 * @brief Creates a new shader for execution on a Vulkan GPU.
 * 
 * The shader cache of the GPU is searched for an entry with the same host
 * function; if one is found, its id is returned and nothing else changes.
 * Otherwise the next cache slot (in round-robin order) is reset and 
 * assigned to the new shader.
 * 
 * @param gpu              the Vulkan GPU
 * @param shader_name      the name of the shader
 * @param shader_filename  the filename of the shader
 * @param sources_filename the filename of the shader sources
 * @param host_func        the host function
 * @return the new shader id 
 */
int vkgpu_new_shader(vk_gpu_t *gpu, char *shader_name, char *shader_filename, 
                     char *sources_filename, void *host_func)
{
	vk_shader_t *shader;
	int i, shader_id = 0;
	bool cached = false;

	/* We suppose that the shader is not cached, thus given a new id */
	shader_id = gpu->nshaders % VK_CACHE_SIZE;

	if (gpu->nshaders <= VK_CACHE_SIZE)
	{
		for (i = 0; i < gpu->nshaders; i++)
		{
			if (gpu->shader_cache[i].host_func == host_func)
			{
				shader_id = i;
				cached = true;
			}
		}
	}
	else
	{
		for (i = 1; i <= VK_CACHE_SIZE; i++)
		{
			if (gpu->shader_cache[MODSUB(gpu->nshaders, i)].host_func == host_func)
			{
				shader_id = MODSUB(gpu->nshaders, i);
				cached = true;
			}
		}
	}
		
	shader = (vk_shader_t *) &(gpu->shader_cache[shader_id]);

	if (!cached)
	{
		/* The lock of a slot is initialized exactly once, when the slot is
		 * used for the first time. Doing it here (and not before the search)
		 * avoids touching shader_cache[VK_CACHE_SIZE], which is past the end
		 * of the cache, and avoids re-initializing a lock on cache hits.
		 * NOTE(review): the lock_init field of the slot is left untouched --
		 * confirm how the rest of the module uses it.
		 */
		if (gpu->nshaders < VK_CACHE_SIZE)
			pthread_mutex_init(&(shader->lock), NULL);

		shader->nargs = 0;
		shader->args = NULL;
		shader->name = shader_name;
		shader->filename = shader_filename;
		shader->sources_filename = sources_filename;
		shader->host_func = host_func;
		shader->buffer_infos = NULL;
		shader->desc_sets = NULL;
		/* Vulkan handles are reset with VK_NULL_HANDLE (they are not always
		 * pointers); the pool is included since cleanup destroys it as well
		 */
		shader->descriptor_pool = VK_NULL_HANDLE;
		shader->descriptor_set = VK_NULL_HANDLE;
		shader->descriptor_set_layout = VK_NULL_HANDLE;
		shader->compute_pipeline = VK_NULL_HANDLE;
		shader->pipeline_layout = VK_NULL_HANDLE;
		shader->compute_shader_module = VK_NULL_HANDLE;
		shader->code = NULL;
		shader->code_size = 0;
		gpu->nshaders++;
		DBGPRN((stderr, "[vulkan] %s shader created with id %d\n", __FUNCTION__, shader_id));
	}
	else
	{
		DBGPRN((stderr, "[vulkan] %s shader cached (id = %d)\n", __FUNCTION__, shader_id));
	}

	return shader_id;
}


/**
 * @brief Cleans up a Vulkan shader belonging to a specific GPU
 * 
 * Releases the argument array and the descriptor bookkeeping of the shader
 * and destroys its descriptor pool, descriptor set layout, pipeline, 
 * pipeline layout and shader module. All released items are reset, so that
 * cleaning up the same shader again does not destroy anything twice.
 * 
 * @param gpu       the Vulkan GPU owning the shader
 * @param shader_id the shader id
 */
void vkgpu_cleanup_shader(vk_gpu_t *gpu, int shader_id)
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	vk_shader_t *shader = &(gpu->shader_cache[shader_id]);

	/* free(NULL) is a no-op, so no guards are needed */
	free(shader->args);
	shader->args = NULL;
	shader->nargs = 0;

	free(shader->buffer_infos);
	shader->buffer_infos = NULL;

	free(shader->desc_sets);
	shader->desc_sets = NULL;

	/* Destroying the pool also frees the descriptor set allocated from it */
	vkDestroyDescriptorPool(ctx->device, shader->descriptor_pool, NULL);
	vkDestroyDescriptorSetLayout(ctx->device, shader->descriptor_set_layout, NULL);
	vkDestroyPipeline(ctx->device, shader->compute_pipeline, NULL);
	vkDestroyPipelineLayout(ctx->device, shader->pipeline_layout, NULL);
	vkDestroyShaderModule(ctx->device, shader->compute_shader_module, NULL);

	/* Do not keep dangling handles around (destroying VK_NULL_HANDLE is 
	 * allowed, destroying the same handle twice is not)
	 */
	shader->descriptor_pool = VK_NULL_HANDLE;
	shader->descriptor_set = VK_NULL_HANDLE;
	shader->descriptor_set_layout = VK_NULL_HANDLE;
	shader->compute_pipeline = VK_NULL_HANDLE;
	shader->pipeline_layout = VK_NULL_HANDLE;
	shader->compute_shader_module = VK_NULL_HANDLE;
}


/* Decodes the given ULL and sets the workgroup dimensions (X, Y, Z).
 * The decoded X value is only used when Y or Z has been given too; in any 
 * other case X falls back to num_1d. Missing Y and Z values become 1.
 */
static void set_launch_dimensions(vk_dimensions_t *workgroup_dims, int num_1d, 
                                  unsigned long long dimensions_3)
{
	unsigned long x, y, z;
	unsigned long dim_x, dim_y, dim_z;

	_ull_decode3(dimensions_3, &x, &y, &z);

	if (x != 0 && (y != 0 || z != 0))
		dim_x = x;
	else
		dim_x = (unsigned long) num_1d;
	dim_y = (y != 0) ? y : 1UL;
	dim_z = (z != 0) ? z : 1UL;

	workgroup_dims->X = dim_x;
	workgroup_dims->Y = dim_y;
	workgroup_dims->Z = dim_z;

	DBGPRN((stderr, "[vulkan] vkgpu: %s %dx%dx%d\n", __FUNCTION__, 
	                workgroup_dims->X, workgroup_dims->Y, workgroup_dims->Z));
}


/** 
 * @brief Copies data from the host to the device memory of a Vulkan GPU
 * 
 * The device memory range is mapped, filled from the host buffer and 
 * unmapped again. Nothing is done for a zero size. Exits if the memory 
 * cannot be mapped.
 * 
 * @param gpu         the Vulkan GPU
 * @param devmem      the device memory
 * @param devoffset   the device memory offset
 * @param hostaddr    the host address
 * @param hostoffset  the host address offset
 * @param size        the size of the data to be copied
 */
void vkgpu_host2dev(vk_gpu_t *gpu, VkDeviceMemory devmem, VkDeviceSize devoffset, 
                    void *hostaddr, size_t hostoffset, size_t size)
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	void *data = NULL;
	DBGPRN((stderr, "[vulkan] vkgpu: %s %p %p %zu\n", 
	                __FUNCTION__, hostaddr, devmem, size));

	/* Mapping an empty range is not allowed and there is nothing to copy */
	if (size == 0)
		return;

	/* (1) Map the device memory; on failure `data` would not be valid */
	if (vkMapMemory(ctx->device, devmem, devoffset, (VkDeviceSize) size, 0, 
	                &data) != VK_SUCCESS)
	{
		fprintf(stderr, "[vulkan] vkgpu: %s: error: Failed to map device memory!\n",
		                __FUNCTION__);
		exit(EXIT_FAILURE);
	}

	/* (2) Copy the data from the host to the device 
	 *     (byte arithmetic is done on a char pointer, not on void *)
	 */
	memcpy(data, (char *) hostaddr + hostoffset, size);

	/* (3) Unmap the device memory */
	vkUnmapMemory(ctx->device, devmem);
}


/** 
 * @brief Copies data from the device memory of a Vulkan GPU to the host
 * 
 * The device memory range is mapped, copied to the host buffer and 
 * unmapped again. Nothing is done for a zero size. Exits if the memory 
 * cannot be mapped.
 * 
 * @param gpu         the Vulkan GPU
 * @param devmem      the device memory
 * @param devoffset   the device memory offset
 * @param hostaddr    the host address
 * @param hostoffset  the host address offset
 * @param size        the size of the data to be copied
 */
void vkgpu_dev2host(vk_gpu_t *gpu, VkDeviceMemory devmem, VkDeviceSize devoffset, 
                    void *hostaddr, size_t hostoffset, size_t size)
{
	vk_gpu_ctx_t *ctx = vkgpu_ctx_get(gpu);
	void *data = NULL;
	DBGPRN((stderr, "[vulkan] vkgpu: %s %p %p %zu\n", 
	                __FUNCTION__, hostaddr, devmem, size));

	/* Mapping an empty range is not allowed and there is nothing to copy */
	if (size == 0)
		return;

	/* (1) Map the device memory; on failure `data` would not be valid */
	if (vkMapMemory(ctx->device, devmem, devoffset, (VkDeviceSize) size, 0, 
	                &data) != VK_SUCCESS)
	{
		fprintf(stderr, "[vulkan] vkgpu: %s: error: Failed to map device memory!\n",
		                __FUNCTION__);
		exit(EXIT_FAILURE);
	}

	/* (2) Copy the data from the device to the host 
	 *     (byte arithmetic is done on a char pointer, not on void *)
	 */
	memcpy((char *) hostaddr + hostoffset, data, size);

	/* (3) Unmap the device memory */
	vkUnmapMemory(ctx->device, devmem);
}


/**
 * @brief Launches a Vulkan shader on a specific Vulkan GPU
 * 
 * @param gpu          the Vulkan GPU
 * @param shader_id    the shader id
 * @param devdata      the devdata structure
 * @param devdata_size the size of the devdata structure
 * @param num_args     array holding number of arguments (decl, fip, mapped)
 * @param args         the arguments array
 * @param num_teams    number of teams to be launched (fallback)
 * @param num_threads  number of threads to be launched (fallback)
 * @param teamdims     number of teams to be launched for 3 dimensions (encoded)
 * @param thrdims      number of threads to be launched for 3 dimensions (encoded)
 * @param thread_limit an upper limit for the # of threads
 * @return 0 after the shader has run, 1 if the shader id is invalid
 */
int vkgpu_launch_shader(vk_gpu_t *gpu, int shader_id, void *devdata, size_t devdata_size, 
                        int *num_args, void **args, uint32_t num_teams, int num_threads, 
                        unsigned long long teamdims, unsigned long long thrdims, 
                        int thread_limit)
{
	vk_shader_t *shader;
	vk_dimensions_t *workgroup_sizes, *invocation_sizes;
	DBGPRN((stderr, "[vulkan] vkgpu: %s\n", __FUNCTION__));

	/* The id indexes the shader cache, so it must be non-negative and lie 
	 * within the cache; nshaders alone is not enough, since it keeps growing 
	 * after the cache has wrapped around.
	 */
	if (shader_id < 0 || shader_id >= gpu->nshaders || shader_id >= VK_CACHE_SIZE)
	{
		fprintf(stderr, "%s: error: invalid shader ID\n", __FUNCTION__);
		return 1;
	}

	shader = (vk_shader_t *) &(gpu->shader_cache[shader_id]);

	workgroup_sizes = &(shader->workgroups);
	invocation_sizes = &(shader->invocations);
	
	/* (1) Set the workgroup and invocation dimension sizes */
	set_launch_dimensions(workgroup_sizes, num_teams, teamdims);
	set_launch_dimensions(invocation_sizes, num_threads, thrdims);

	/* (2) Copy all arguments to a new array */
	pthread_mutex_lock(&(shader->lock));
	handle_shader_args(gpu, shader, devdata, devdata_size, num_args, args);
	pthread_mutex_unlock(&(shader->lock));

	/* (3) Create the descriptor set layout & compute pipeline */
	create_descriptor_set_layout(gpu, shader, shader->nargs);
	create_compute_pipeline(gpu, shader, invocation_sizes->X, invocation_sizes->Y, invocation_sizes->Z);

	/* (4) Create the descriptor pool & set */
	create_descriptor_pool(gpu, shader);
	create_descriptor_set(gpu, shader);

	/* (5) Create the command buffer based on the given workgroup sizes */
	record_command_buffer(gpu, shader, thread_limit, workgroup_sizes->X, 
	                      workgroup_sizes->Y, workgroup_sizes->Z);

	/* (6) Run the shader */
	run_compute_shader(gpu, shader);

	gpu->num_launched_shaders++;
	return 0;
}


/**
 * @brief Finalizes a Vulkan GPU
 * 
 * Runs the context teardown check of the GPU and, when concurrency support
 * is compiled in, deletes the key that holds the GPU's context.
 * 
 * @param gpu the Vulkan GPU
 */
void vkgpu_finalize_device(vk_gpu_t *gpu) 
{
	vkgpu_ctx_test_destroy(gpu);

#if defined(VK_ENABLE_CONCURRENCY)
	pthread_key_delete(gpu->context_key);
#endif
}
