//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//

#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing " << api;
}

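// Convenience wrapper: report the failing Vulkan call and return failure()
// from the enclosing function (which must return LogicalResult) whenever a
// call does not yield VK_SUCCESS.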
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }

using namespace mlir;

void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}

void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}

void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}

void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}

void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}

void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}

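// Translate the SPIR-V storage class attached to a resource into the Vulkan
// descriptor type used when populating descriptor sets. Only the
// StorageBuffer and Uniform storage classes are handled here.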
LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}

LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}

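// Validate the configured inputs (at least one resource and a non-empty
// shader binary) and precompute the total amount of device memory required.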
LogicalResult VulkanRuntime::initRuntime() {
  if (resourceData.empty()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}

LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  // Free and destroy.
  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}

LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding
  // must have a layout binding attached to a descriptor set layout, and
  // each descriptor set layout must be bound to a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}

LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = nullptr;

  std::vector<const char *> extNames;
#if defined(__APPLE__)
  // enumerate MoltenVK for Vulkan 1.0
  instanceCreateInfo.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
  // add KHR portability instance extensions
  extNames.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
  extNames.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
#else
  instanceCreateInfo.flags = 0;
#endif // __APPLE__
  instanceCreateInfo.enabledExtensionCount =
      static_cast<uint32_t>(extNames.size());
  instanceCreateInfo.ppEnabledExtensionNames = extNames.data();

  RETURN_ON_VULKAN_ERROR(
      vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
      "vkCreateInstance");
  return success();
}

LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
  // with this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This allows faster access for GPUs with dedicated device
  // memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}

LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, nullptr);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}

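// For every registered resource, allocate a host-visible staging buffer and a
// device-local buffer, copy the user-provided data into the staging memory,
// and record the VkDescriptorBufferInfo that is later consumed by the
// descriptor writes.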
LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported ";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate device memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              nullptr,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}

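// Record and submit a one-off command buffer that copies every resource
// between its staging (host) buffer and its device-local buffer, in the
// direction selected by `deviceToHost`, then block until the queue is idle.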
LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
      nullptr,
      commandPool,
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,
      1,
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
      nullptr,
      0,
      nullptr,
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO,
      nullptr,
      0,
      nullptr,
      nullptr,
      1,
      &commandBuffer,
      0,
      nullptr,
  };
  submitInfo.pCommandBuffers = &commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}

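// Wrap the SPIR-V blob registered via setShaderModule() in a VkShaderModule.
// The cast below assumes well-formed SPIR-V: a 4-byte aligned buffer whose
// size is a multiple of sizeof(uint32_t), as required for pCode.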
LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
                                              nullptr, &shaderModule),
                         "vkCreateShaderModule");
  return success();
}

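// Build one VkDescriptorSetLayoutBinding per device memory buffer, grouped by
// descriptor set index; the resulting map is consumed by
// createDescriptorSetLayout().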
void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}

LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in the layout set.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
                                    nullptr, &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}

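// The pipeline layout aggregates every descriptor set layout created above;
// push constants are not used by this runtime.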
LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo,
                                                nullptr, &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}

LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set entry point.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = nullptr;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = nullptr;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
                                                  &computePipelineCreateInfo,
                                                  nullptr, &pipeline),
                         "vkCreateComputePipelines");
  return success();
}

LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set populate descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  descriptorPoolCreateInfo.flags = 0;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo,
                                                nullptr, &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}

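// Allocate one descriptor set per descriptor set layout out of the pool
// created above.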
LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // The number of descriptor sets matches the number of descriptor set
  // layouts.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}

LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      // Descriptor set.
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}

LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}

LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}

LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  // Commands begin.
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, nullptr);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  // Commands end.
  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}

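// Submit the recorded compute command buffer(s) to the queue. No fences or
// semaphores are used here; run() synchronizes with vkQueueWaitIdle() instead.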
LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = nullptr;
  submitInfo.pWaitDstStageMask = nullptr;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
                         "vkQueueSubmit");
  return success();
}

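// Copy the results produced on the device back into the caller-provided host
// buffers registered via setResourceData().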
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy back the data to the staging buffer.
  (void)copyResource(/*deviceToHost=*/true);

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
      }
    }
  }
  return success();
}
897