1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file provides a library for running a module on a Vulkan device. 10 // Implements a Vulkan runtime. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VulkanRuntime.h" 15 16 #include <chrono> 17 #include <cstring> 18 // TODO: It's generally bad to access stdout/stderr in a library. 19 // Figure out a better way for error reporting. 20 #include <iomanip> 21 #include <iostream> 22 23 inline void emitVulkanError(const char *api, VkResult error) { 24 std::cerr << " failed with error code " << error << " when executing " << api; 25 } 26 27 #define RETURN_ON_VULKAN_ERROR(result, api) \ 28 if ((result) != VK_SUCCESS) { \ 29 emitVulkanError(api, (result)); \ 30 return failure(); \ 31 } 32 33 using namespace mlir; 34 35 void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) { 36 numWorkGroups = numberWorkGroups; 37 } 38 39 void VulkanRuntime::setResourceStorageClassBindingMap( 40 const ResourceStorageClassBindingMap &stClassData) { 41 resourceStorageClassData = stClassData; 42 } 43 44 void VulkanRuntime::setResourceData( 45 const DescriptorSetIndex desIndex, const BindingIndex bindIndex, 46 const VulkanHostMemoryBuffer &hostMemBuffer) { 47 resourceData[desIndex][bindIndex] = hostMemBuffer; 48 resourceStorageClassData[desIndex][bindIndex] = 49 SPIRVStorageClass::StorageBuffer; 50 } 51 52 void VulkanRuntime::setEntryPoint(const char *entryPointName) { 53 entryPoint = entryPointName; 54 } 55 56 void VulkanRuntime::setResourceData(const ResourceData &resData) { 57 resourceData = resData; 58 } 59 60 void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) { 61 binary = shader; 62 binarySize = size; 63 } 64 65 LogicalResult VulkanRuntime::mapStorageClassToDescriptorType( 66 SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) { 67 switch (storageClass) { 68 case SPIRVStorageClass::StorageBuffer: 69 descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; 70 break; 71 case SPIRVStorageClass::Uniform: 72 descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; 73 break; 74 } 75 return success(); 76 } 77 78 LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag( 79 SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) { 80 switch (storageClass) { 81 case SPIRVStorageClass::StorageBuffer: 82 bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; 83 break; 84 case SPIRVStorageClass::Uniform: 85 bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; 86 break; 87 } 88 return success(); 89 } 90 91 LogicalResult VulkanRuntime::countDeviceMemorySize() { 92 for (const auto &resourceDataMapPair : resourceData) { 93 const auto &resourceDataMap = resourceDataMapPair.second; 94 for (const auto &resourceDataBindingPair : resourceDataMap) { 95 if (resourceDataBindingPair.second.size) { 96 memorySize += resourceDataBindingPair.second.size; 97 } else { 98 std::cerr << "expected buffer size greater than zero for resource data"; 99 return failure(); 100 } 101 } 102 } 103 return success(); 104 } 105 106 LogicalResult VulkanRuntime::initRuntime() { 107 if (resourceData.empty()) { 108 std::cerr << "Vulkan runtime needs at least one resource"; 109 return failure(); 110 } 111 if (!binarySize || !binary) { 112 std::cerr << "binary shader size must be greater than zero"; 113 return failure(); 114 } 115 if (failed(countDeviceMemorySize())) { 116 return failure(); 117 } 118 return success(); 119 } 120 121 LogicalResult VulkanRuntime::destroy() { 122 // According to Vulkan spec: 123 // "To ensure that no work is active on the device, vkDeviceWaitIdle can be 124 // used to gate the destruction of the device. Prior to destroying a device, 125 // an application is responsible for destroying/freeing any Vulkan objects 126 // that were created using that device as the first parameter of the 127 // corresponding vkCreate* or vkAllocate* command." 128 RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle"); 129 130 // Free and destroy. 131 vkFreeCommandBuffers(device, commandPool, commandBuffers.size(), 132 commandBuffers.data()); 133 vkDestroyQueryPool(device, queryPool, nullptr); 134 vkDestroyCommandPool(device, commandPool, nullptr); 135 vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(), 136 descriptorSets.data()); 137 vkDestroyDescriptorPool(device, descriptorPool, nullptr); 138 vkDestroyPipeline(device, pipeline, nullptr); 139 vkDestroyPipelineLayout(device, pipelineLayout, nullptr); 140 for (auto &descriptorSetLayout : descriptorSetLayouts) { 141 vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr); 142 } 143 vkDestroyShaderModule(device, shaderModule, nullptr); 144 145 // For each descriptor set. 146 for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 147 auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 148 // For each descriptor binding. 149 for (auto &memoryBuffer : deviceMemoryBuffers) { 150 vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr); 151 vkFreeMemory(device, memoryBuffer.hostMemory, nullptr); 152 vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr); 153 vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr); 154 } 155 } 156 157 vkDestroyDevice(device, nullptr); 158 vkDestroyInstance(instance, nullptr); 159 return success(); 160 } 161 162 LogicalResult VulkanRuntime::run() { 163 // Create logical device, shader module and memory buffers. 164 if (failed(createInstance()) || failed(createDevice()) || 165 failed(createMemoryBuffers()) || failed(createShaderModule())) { 166 return failure(); 167 } 168 169 // Descriptor bindings divided into sets. Each descriptor binding 170 // must have a layout binding attached into a descriptor set layout. 171 // Each layout set must be binded into a pipeline layout. 172 initDescriptorSetLayoutBindingMap(); 173 if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) || 174 // Each descriptor set must be allocated from a descriptor pool. 175 failed(createComputePipeline()) || failed(createDescriptorPool()) || 176 failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) || 177 // Create command buffer. 178 failed(createCommandPool()) || failed(createQueryPool()) || 179 failed(createComputeCommandBuffer())) { 180 return failure(); 181 } 182 183 // Get working queue. 184 vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue); 185 186 if (failed(copyResource(/*deviceToHost=*/false))) 187 return failure(); 188 189 auto submitStart = std::chrono::high_resolution_clock::now(); 190 // Submit command buffer into the queue. 191 if (failed(submitCommandBuffersToQueue())) 192 return failure(); 193 auto submitEnd = std::chrono::high_resolution_clock::now(); 194 195 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 196 auto execEnd = std::chrono::high_resolution_clock::now(); 197 198 auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>( 199 submitEnd - submitStart); 200 auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>( 201 execEnd - submitEnd); 202 203 if (queryPool != VK_NULL_HANDLE) { 204 uint64_t timestamps[2]; 205 RETURN_ON_VULKAN_ERROR( 206 vkGetQueryPoolResults( 207 device, queryPool, /*firstQuery=*/0, /*queryCount=*/2, 208 /*dataSize=*/sizeof(timestamps), 209 /*pData=*/reinterpret_cast<void *>(timestamps), 210 /*stride=*/sizeof(uint64_t), 211 VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT), 212 "vkGetQueryPoolResults"); 213 float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000; 214 std::cout << "Compute shader execution time: " << std::setprecision(3) 215 << microsec << "us\n"; 216 } 217 218 std::cout << "Command buffer submit time: " << submitDuration.count() 219 << "us\nWait idle time: " << execDuration.count() << "us\n"; 220 221 return success(); 222 } 223 224 LogicalResult VulkanRuntime::createInstance() { 225 VkApplicationInfo applicationInfo = {}; 226 applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; 227 applicationInfo.pNext = nullptr; 228 applicationInfo.pApplicationName = "MLIR Vulkan runtime"; 229 applicationInfo.applicationVersion = 0; 230 applicationInfo.pEngineName = "mlir"; 231 applicationInfo.engineVersion = 0; 232 applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0); 233 234 VkInstanceCreateInfo instanceCreateInfo = {}; 235 instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; 236 instanceCreateInfo.pNext = nullptr; 237 instanceCreateInfo.pApplicationInfo = &applicationInfo; 238 instanceCreateInfo.enabledLayerCount = 0; 239 instanceCreateInfo.ppEnabledLayerNames = nullptr; 240 241 std::vector<const char *> extNames; 242 #if defined(__APPLE__) 243 // enumerate MoltenVK for Vulkan 1.0 244 instanceCreateInfo.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR; 245 // add KHR portability instance extensions 246 extNames.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); 247 extNames.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME); 248 #else 249 instanceCreateInfo.flags = 0; 250 #endif // __APPLE__ 251 instanceCreateInfo.enabledExtensionCount = 252 static_cast<uint32_t>(extNames.size()); 253 instanceCreateInfo.ppEnabledExtensionNames = extNames.data(); 254 255 RETURN_ON_VULKAN_ERROR( 256 vkCreateInstance(&instanceCreateInfo, nullptr, &instance), 257 "vkCreateInstance"); 258 return success(); 259 } 260 261 LogicalResult VulkanRuntime::createDevice() { 262 uint32_t physicalDeviceCount = 0; 263 RETURN_ON_VULKAN_ERROR( 264 vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr), 265 "vkEnumeratePhysicalDevices"); 266 267 std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount); 268 RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance, 269 &physicalDeviceCount, 270 physicalDevices.data()), 271 "vkEnumeratePhysicalDevices"); 272 273 RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE, 274 "physicalDeviceCount"); 275 276 // TODO: find the best device. 277 physicalDevice = physicalDevices.front(); 278 if (failed(getBestComputeQueue())) 279 return failure(); 280 281 const float queuePriority = 1.0f; 282 VkDeviceQueueCreateInfo deviceQueueCreateInfo = {}; 283 deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; 284 deviceQueueCreateInfo.pNext = nullptr; 285 deviceQueueCreateInfo.flags = 0; 286 deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex; 287 deviceQueueCreateInfo.queueCount = 1; 288 deviceQueueCreateInfo.pQueuePriorities = &queuePriority; 289 290 // Structure specifying parameters of a newly created device. 291 VkDeviceCreateInfo deviceCreateInfo = {}; 292 deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; 293 deviceCreateInfo.pNext = nullptr; 294 deviceCreateInfo.flags = 0; 295 deviceCreateInfo.queueCreateInfoCount = 1; 296 deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo; 297 deviceCreateInfo.enabledLayerCount = 0; 298 deviceCreateInfo.ppEnabledLayerNames = nullptr; 299 deviceCreateInfo.enabledExtensionCount = 0; 300 deviceCreateInfo.ppEnabledExtensionNames = nullptr; 301 deviceCreateInfo.pEnabledFeatures = nullptr; 302 303 RETURN_ON_VULKAN_ERROR( 304 vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device), 305 "vkCreateDevice"); 306 307 VkPhysicalDeviceMemoryProperties properties = {}; 308 vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties); 309 310 // Try to find memory type with following properties: 311 // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated 312 // with this type can be mapped for host access using vkMapMemory; 313 // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache 314 // management commands vkFlushMappedMemoryRanges and 315 // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the 316 // device or make device writes visible to the host, respectively. 317 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 318 if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT & 319 properties.memoryTypes[i].propertyFlags) && 320 (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT & 321 properties.memoryTypes[i].propertyFlags) && 322 (memorySize <= 323 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 324 hostMemoryTypeIndex = i; 325 break; 326 } 327 } 328 329 // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be 330 // used on the device. This will allow better performance access for GPU with 331 // on device memory. 332 for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) { 333 if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & 334 properties.memoryTypes[i].propertyFlags) && 335 (memorySize <= 336 properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) { 337 deviceMemoryTypeIndex = i; 338 break; 339 } 340 } 341 342 RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES || 343 deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES) 344 ? VK_INCOMPLETE 345 : VK_SUCCESS, 346 "invalid memoryTypeIndex"); 347 return success(); 348 } 349 350 LogicalResult VulkanRuntime::getBestComputeQueue() { 351 uint32_t queueFamilyPropertiesCount = 0; 352 vkGetPhysicalDeviceQueueFamilyProperties( 353 physicalDevice, &queueFamilyPropertiesCount, nullptr); 354 355 std::vector<VkQueueFamilyProperties> familyProperties( 356 queueFamilyPropertiesCount); 357 vkGetPhysicalDeviceQueueFamilyProperties( 358 physicalDevice, &queueFamilyPropertiesCount, familyProperties.data()); 359 360 // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support 361 // compute operations. Try to find a compute-only queue first if possible. 362 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 363 auto flags = familyProperties[i].queueFlags; 364 if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) { 365 queueFamilyIndex = i; 366 queueFamilyProperties = familyProperties[i]; 367 return success(); 368 } 369 } 370 371 // Otherwise use a queue that can also support graphics. 372 for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) { 373 auto flags = familyProperties[i].queueFlags; 374 if ((flags & VK_QUEUE_COMPUTE_BIT)) { 375 queueFamilyIndex = i; 376 queueFamilyProperties = familyProperties[i]; 377 return success(); 378 } 379 } 380 381 std::cerr << "cannot find valid queue"; 382 return failure(); 383 } 384 385 LogicalResult VulkanRuntime::createMemoryBuffers() { 386 // For each descriptor set. 387 for (const auto &resourceDataMapPair : resourceData) { 388 std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers; 389 const auto descriptorSetIndex = resourceDataMapPair.first; 390 const auto &resourceDataMap = resourceDataMapPair.second; 391 392 // For each descriptor binding. 393 for (const auto &resourceDataBindingPair : resourceDataMap) { 394 // Create device memory buffer. 395 VulkanDeviceMemoryBuffer memoryBuffer; 396 memoryBuffer.bindingIndex = resourceDataBindingPair.first; 397 VkDescriptorType descriptorType = {}; 398 VkBufferUsageFlagBits bufferUsage = {}; 399 400 // Check that descriptor set has storage class map. 401 const auto resourceStorageClassMapIt = 402 resourceStorageClassData.find(descriptorSetIndex); 403 if (resourceStorageClassMapIt == resourceStorageClassData.end()) { 404 std::cerr 405 << "cannot find storage class for resource in descriptor set: " 406 << descriptorSetIndex; 407 return failure(); 408 } 409 410 // Check that specific descriptor binding has storage class. 411 const auto &resourceStorageClassMap = resourceStorageClassMapIt->second; 412 const auto resourceStorageClassIt = 413 resourceStorageClassMap.find(resourceDataBindingPair.first); 414 if (resourceStorageClassIt == resourceStorageClassMap.end()) { 415 std::cerr 416 << "cannot find storage class for resource with descriptor index: " 417 << resourceDataBindingPair.first; 418 return failure(); 419 } 420 421 const auto resourceStorageClassBinding = resourceStorageClassIt->second; 422 if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding, 423 descriptorType)) || 424 failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding, 425 bufferUsage))) { 426 std::cerr << "storage class for resource with descriptor binding: " 427 << resourceDataBindingPair.first 428 << " in the descriptor set: " << descriptorSetIndex 429 << " is not supported "; 430 return failure(); 431 } 432 433 // Set descriptor type for the specific device memory buffer. 434 memoryBuffer.descriptorType = descriptorType; 435 const auto bufferSize = resourceDataBindingPair.second.size; 436 memoryBuffer.bufferSize = bufferSize; 437 // Specify memory allocation info. 438 VkMemoryAllocateInfo memoryAllocateInfo = {}; 439 memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; 440 memoryAllocateInfo.pNext = nullptr; 441 memoryAllocateInfo.allocationSize = bufferSize; 442 memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex; 443 444 // Allocate device memory. 445 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 446 nullptr, 447 &memoryBuffer.hostMemory), 448 "vkAllocateMemory"); 449 memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex; 450 RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 451 nullptr, 452 &memoryBuffer.deviceMemory), 453 "vkAllocateMemory"); 454 void *payload; 455 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0, 456 bufferSize, 0, 457 reinterpret_cast<void **>(&payload)), 458 "vkMapMemory"); 459 460 // Copy host memory into the mapped area. 461 std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize); 462 vkUnmapMemory(device, memoryBuffer.hostMemory); 463 464 VkBufferCreateInfo bufferCreateInfo = {}; 465 bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; 466 bufferCreateInfo.pNext = nullptr; 467 bufferCreateInfo.flags = 0; 468 bufferCreateInfo.size = bufferSize; 469 bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT | 470 VK_BUFFER_USAGE_TRANSFER_SRC_BIT; 471 bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE; 472 bufferCreateInfo.queueFamilyIndexCount = 1; 473 bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex; 474 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr, 475 &memoryBuffer.hostBuffer), 476 "vkCreateBuffer"); 477 RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr, 478 &memoryBuffer.deviceBuffer), 479 "vkCreateBuffer"); 480 481 // Bind buffer and device memory. 482 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer, 483 memoryBuffer.hostMemory, 0), 484 "vkBindBufferMemory"); 485 RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, 486 memoryBuffer.deviceBuffer, 487 memoryBuffer.deviceMemory, 0), 488 "vkBindBufferMemory"); 489 490 // Update buffer info. 491 memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer; 492 memoryBuffer.bufferInfo.offset = 0; 493 memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE; 494 deviceMemoryBuffers.push_back(memoryBuffer); 495 } 496 497 // Associate device memory buffers with a descriptor set. 498 deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers; 499 } 500 return success(); 501 } 502 503 LogicalResult VulkanRuntime::copyResource(bool deviceToHost) { 504 VkCommandBufferAllocateInfo commandBufferAllocateInfo = { 505 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, 506 nullptr, 507 commandPool, 508 VK_COMMAND_BUFFER_LEVEL_PRIMARY, 509 1, 510 }; 511 VkCommandBuffer commandBuffer; 512 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 513 &commandBufferAllocateInfo, 514 &commandBuffer), 515 "vkAllocateCommandBuffers"); 516 517 VkCommandBufferBeginInfo commandBufferBeginInfo = { 518 VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, 519 nullptr, 520 0, 521 nullptr, 522 }; 523 RETURN_ON_VULKAN_ERROR( 524 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 525 "vkBeginCommandBuffer"); 526 527 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 528 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 529 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 530 for (const auto &memBuffer : deviceMemoryBuffers) { 531 VkBufferCopy copy = {0, 0, memBuffer.bufferSize}; 532 if (deviceToHost) 533 vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer, 534 memBuffer.hostBuffer, 1, ©); 535 else 536 vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer, 537 memBuffer.deviceBuffer, 1, ©); 538 } 539 } 540 541 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 542 "vkEndCommandBuffer"); 543 VkSubmitInfo submitInfo = { 544 VK_STRUCTURE_TYPE_SUBMIT_INFO, 545 nullptr, 546 0, 547 nullptr, 548 nullptr, 549 1, 550 &commandBuffer, 551 0, 552 nullptr, 553 }; 554 submitInfo.pCommandBuffers = &commandBuffer; 555 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE), 556 "vkQueueSubmit"); 557 RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle"); 558 559 vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer); 560 return success(); 561 } 562 563 LogicalResult VulkanRuntime::createShaderModule() { 564 VkShaderModuleCreateInfo shaderModuleCreateInfo = {}; 565 shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; 566 shaderModuleCreateInfo.pNext = nullptr; 567 shaderModuleCreateInfo.flags = 0; 568 // Set size in bytes. 569 shaderModuleCreateInfo.codeSize = binarySize; 570 // Set pointer to the binary shader. 571 shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary); 572 RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo, 573 nullptr, &shaderModule), 574 "vkCreateShaderModule"); 575 return success(); 576 } 577 578 void VulkanRuntime::initDescriptorSetLayoutBindingMap() { 579 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 580 std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings; 581 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 582 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 583 584 // Create a layout binding for each descriptor. 585 for (const auto &memBuffer : deviceMemoryBuffers) { 586 VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {}; 587 descriptorSetLayoutBinding.binding = memBuffer.bindingIndex; 588 descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType; 589 descriptorSetLayoutBinding.descriptorCount = 1; 590 descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; 591 descriptorSetLayoutBinding.pImmutableSamplers = nullptr; 592 descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding); 593 } 594 descriptorSetLayoutBindingMap[descriptorSetIndex] = 595 descriptorSetLayoutBindings; 596 } 597 } 598 599 LogicalResult VulkanRuntime::createDescriptorSetLayout() { 600 for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) { 601 const auto descriptorSetIndex = deviceMemoryBufferMapPair.first; 602 const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second; 603 // Each descriptor in a descriptor set must be the same type. 604 VkDescriptorType descriptorType = 605 deviceMemoryBuffers.front().descriptorType; 606 const uint32_t descriptorSize = deviceMemoryBuffers.size(); 607 const auto descriptorSetLayoutBindingIt = 608 descriptorSetLayoutBindingMap.find(descriptorSetIndex); 609 610 if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { 611 std::cerr << "cannot find layout bindings for the set with number: " 612 << descriptorSetIndex; 613 return failure(); 614 } 615 616 const auto &descriptorSetLayoutBindings = 617 descriptorSetLayoutBindingIt->second; 618 // Create descriptor set layout. 619 VkDescriptorSetLayout descriptorSetLayout = {}; 620 VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; 621 622 descriptorSetLayoutCreateInfo.sType = 623 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; 624 descriptorSetLayoutCreateInfo.pNext = nullptr; 625 descriptorSetLayoutCreateInfo.flags = 0; 626 // Amount of descriptor bindings in a layout set. 627 descriptorSetLayoutCreateInfo.bindingCount = 628 descriptorSetLayoutBindings.size(); 629 descriptorSetLayoutCreateInfo.pBindings = 630 descriptorSetLayoutBindings.data(); 631 RETURN_ON_VULKAN_ERROR( 632 vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 633 nullptr, &descriptorSetLayout), 634 "vkCreateDescriptorSetLayout"); 635 636 descriptorSetLayouts.push_back(descriptorSetLayout); 637 descriptorSetInfoPool.push_back( 638 {descriptorSetIndex, descriptorSize, descriptorType}); 639 } 640 return success(); 641 } 642 643 LogicalResult VulkanRuntime::createPipelineLayout() { 644 // Associate descriptor sets with a pipeline layout. 645 VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; 646 pipelineLayoutCreateInfo.sType = 647 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; 648 pipelineLayoutCreateInfo.pNext = nullptr; 649 pipelineLayoutCreateInfo.flags = 0; 650 pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); 651 pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); 652 pipelineLayoutCreateInfo.pushConstantRangeCount = 0; 653 pipelineLayoutCreateInfo.pPushConstantRanges = nullptr; 654 RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, 655 &pipelineLayoutCreateInfo, 656 nullptr, &pipelineLayout), 657 "vkCreatePipelineLayout"); 658 return success(); 659 } 660 661 LogicalResult VulkanRuntime::createComputePipeline() { 662 VkPipelineShaderStageCreateInfo stageInfo = {}; 663 stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; 664 stageInfo.pNext = nullptr; 665 stageInfo.flags = 0; 666 stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; 667 stageInfo.module = shaderModule; 668 // Set entry point. 669 stageInfo.pName = entryPoint; 670 stageInfo.pSpecializationInfo = nullptr; 671 672 VkComputePipelineCreateInfo computePipelineCreateInfo = {}; 673 computePipelineCreateInfo.sType = 674 VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; 675 computePipelineCreateInfo.pNext = nullptr; 676 computePipelineCreateInfo.flags = 0; 677 computePipelineCreateInfo.stage = stageInfo; 678 computePipelineCreateInfo.layout = pipelineLayout; 679 computePipelineCreateInfo.basePipelineHandle = nullptr; 680 computePipelineCreateInfo.basePipelineIndex = 0; 681 RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1, 682 &computePipelineCreateInfo, 683 nullptr, &pipeline), 684 "vkCreateComputePipelines"); 685 return success(); 686 } 687 688 LogicalResult VulkanRuntime::createDescriptorPool() { 689 std::vector<VkDescriptorPoolSize> descriptorPoolSizes; 690 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 691 // For each descriptor set populate descriptor pool size. 692 VkDescriptorPoolSize descriptorPoolSize = {}; 693 descriptorPoolSize.type = descriptorSetInfo.descriptorType; 694 descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; 695 descriptorPoolSizes.push_back(descriptorPoolSize); 696 } 697 698 VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; 699 descriptorPoolCreateInfo.sType = 700 VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; 701 descriptorPoolCreateInfo.pNext = nullptr; 702 descriptorPoolCreateInfo.flags = 0; 703 descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); 704 descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); 705 descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); 706 RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, 707 &descriptorPoolCreateInfo, 708 nullptr, &descriptorPool), 709 "vkCreateDescriptorPool"); 710 return success(); 711 } 712 713 LogicalResult VulkanRuntime::allocateDescriptorSets() { 714 VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; 715 // Size of descriptor sets and descriptor layout sets is the same. 716 descriptorSets.resize(descriptorSetLayouts.size()); 717 descriptorSetAllocateInfo.sType = 718 VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; 719 descriptorSetAllocateInfo.pNext = nullptr; 720 descriptorSetAllocateInfo.descriptorPool = descriptorPool; 721 descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); 722 descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); 723 RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, 724 &descriptorSetAllocateInfo, 725 descriptorSets.data()), 726 "vkAllocateDescriptorSets"); 727 return success(); 728 } 729 730 LogicalResult VulkanRuntime::setWriteDescriptors() { 731 if (descriptorSets.size() != descriptorSetInfoPool.size()) { 732 std::cerr << "Each descriptor set must have descriptor set information"; 733 return failure(); 734 } 735 // For each descriptor set. 736 auto descriptorSetIt = descriptorSets.begin(); 737 // Each descriptor set is associated with descriptor set info. 738 for (const auto &descriptorSetInfo : descriptorSetInfoPool) { 739 // For each device memory buffer in the descriptor set. 740 const auto &deviceMemoryBuffers = 741 deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; 742 for (const auto &memoryBuffer : deviceMemoryBuffers) { 743 // Structure describing descriptor sets to write to. 744 VkWriteDescriptorSet wSet = {}; 745 wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; 746 wSet.pNext = nullptr; 747 // Descriptor set. 748 wSet.dstSet = *descriptorSetIt; 749 wSet.dstBinding = memoryBuffer.bindingIndex; 750 wSet.dstArrayElement = 0; 751 wSet.descriptorCount = 1; 752 wSet.descriptorType = memoryBuffer.descriptorType; 753 wSet.pImageInfo = nullptr; 754 wSet.pBufferInfo = &memoryBuffer.bufferInfo; 755 wSet.pTexelBufferView = nullptr; 756 vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); 757 } 758 // Increment descriptor set iterator. 759 ++descriptorSetIt; 760 } 761 return success(); 762 } 763 764 LogicalResult VulkanRuntime::createCommandPool() { 765 VkCommandPoolCreateInfo commandPoolCreateInfo = {}; 766 commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; 767 commandPoolCreateInfo.pNext = nullptr; 768 commandPoolCreateInfo.flags = 0; 769 commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; 770 RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo, 771 /*pAllocator=*/nullptr, 772 &commandPool), 773 "vkCreateCommandPool"); 774 return success(); 775 } 776 777 LogicalResult VulkanRuntime::createQueryPool() { 778 // Return directly if timestamp query is not supported. 779 if (queueFamilyProperties.timestampValidBits == 0) 780 return success(); 781 782 // Get timestamp period for this physical device. 783 VkPhysicalDeviceProperties deviceProperties = {}; 784 vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties); 785 timestampPeriod = deviceProperties.limits.timestampPeriod; 786 787 // Create query pool. 788 VkQueryPoolCreateInfo queryPoolCreateInfo = {}; 789 queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; 790 queryPoolCreateInfo.pNext = nullptr; 791 queryPoolCreateInfo.flags = 0; 792 queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP; 793 queryPoolCreateInfo.queryCount = 2; 794 queryPoolCreateInfo.pipelineStatistics = 0; 795 RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo, 796 /*pAllocator=*/nullptr, &queryPool), 797 "vkCreateQueryPool"); 798 799 return success(); 800 } 801 802 LogicalResult VulkanRuntime::createComputeCommandBuffer() { 803 VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; 804 commandBufferAllocateInfo.sType = 805 VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; 806 commandBufferAllocateInfo.pNext = nullptr; 807 commandBufferAllocateInfo.commandPool = commandPool; 808 commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; 809 commandBufferAllocateInfo.commandBufferCount = 1; 810 811 VkCommandBuffer commandBuffer; 812 RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, 813 &commandBufferAllocateInfo, 814 &commandBuffer), 815 "vkAllocateCommandBuffers"); 816 817 VkCommandBufferBeginInfo commandBufferBeginInfo = {}; 818 commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; 819 commandBufferBeginInfo.pNext = nullptr; 820 commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; 821 commandBufferBeginInfo.pInheritanceInfo = nullptr; 822 823 // Commands begin. 824 RETURN_ON_VULKAN_ERROR( 825 vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), 826 "vkBeginCommandBuffer"); 827 828 if (queryPool != VK_NULL_HANDLE) 829 vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2); 830 831 vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); 832 vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, 833 pipelineLayout, 0, descriptorSets.size(), 834 descriptorSets.data(), 0, nullptr); 835 // Get a timestamp before invoking the compute shader. 836 if (queryPool != VK_NULL_HANDLE) 837 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, 838 queryPool, 0); 839 vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, 840 numWorkGroups.z); 841 // Get another timestamp after invoking the compute shader. 842 if (queryPool != VK_NULL_HANDLE) 843 vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, 844 queryPool, 1); 845 846 // Commands end. 847 RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), 848 "vkEndCommandBuffer"); 849 850 commandBuffers.push_back(commandBuffer); 851 return success(); 852 } 853 854 LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { 855 VkSubmitInfo submitInfo = {}; 856 submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; 857 submitInfo.pNext = nullptr; 858 submitInfo.waitSemaphoreCount = 0; 859 submitInfo.pWaitSemaphores = nullptr; 860 submitInfo.pWaitDstStageMask = nullptr; 861 submitInfo.commandBufferCount = commandBuffers.size(); 862 submitInfo.pCommandBuffers = commandBuffers.data(); 863 submitInfo.signalSemaphoreCount = 0; 864 submitInfo.pSignalSemaphores = nullptr; 865 RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr), 866 "vkQueueSubmit"); 867 return success(); 868 } 869 870 LogicalResult VulkanRuntime::updateHostMemoryBuffers() { 871 // First copy back the data to the staging buffer. 872 (void)copyResource(/*deviceToHost=*/true); 873 874 // For each descriptor set. 875 for (auto &resourceDataMapPair : resourceData) { 876 auto &resourceDataMap = resourceDataMapPair.second; 877 auto &deviceMemoryBuffers = 878 deviceMemoryBufferMap[resourceDataMapPair.first]; 879 // For each device memory buffer in the set. 880 for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { 881 if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { 882 void *payload; 883 auto &hostMemoryBuffer = 884 resourceDataMap[deviceMemoryBuffer.bindingIndex]; 885 RETURN_ON_VULKAN_ERROR(vkMapMemory(device, 886 deviceMemoryBuffer.hostMemory, 0, 887 hostMemoryBuffer.size, 0, 888 reinterpret_cast<void **>(&payload)), 889 "vkMapMemory"); 890 std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); 891 vkUnmapMemory(device, deviceMemoryBuffer.hostMemory); 892 } 893 } 894 } 895 return success(); 896 } 897