xref: /llvm-project/mlir/lib/ExecutionEngine/VulkanRuntime.cpp (revision e7e3c45bc70904e24e2b3221ac8521e67eb84668)
1*e7e3c45bSAndrea Faulds //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
2*e7e3c45bSAndrea Faulds //
3*e7e3c45bSAndrea Faulds // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*e7e3c45bSAndrea Faulds // See https://llvm.org/LICENSE.txt for license information.
5*e7e3c45bSAndrea Faulds // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*e7e3c45bSAndrea Faulds //
7*e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
8*e7e3c45bSAndrea Faulds //
9*e7e3c45bSAndrea Faulds // This file provides a library for running a module on a Vulkan device.
10*e7e3c45bSAndrea Faulds // Implements a Vulkan runtime.
11*e7e3c45bSAndrea Faulds //
12*e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
13*e7e3c45bSAndrea Faulds 
14*e7e3c45bSAndrea Faulds #include "VulkanRuntime.h"
15*e7e3c45bSAndrea Faulds 
16*e7e3c45bSAndrea Faulds #include <chrono>
17*e7e3c45bSAndrea Faulds #include <cstring>
18*e7e3c45bSAndrea Faulds // TODO: It's generally bad to access stdout/stderr in a library.
19*e7e3c45bSAndrea Faulds // Figure out a better way for error reporting.
20*e7e3c45bSAndrea Faulds #include <iomanip>
21*e7e3c45bSAndrea Faulds #include <iostream>
22*e7e3c45bSAndrea Faulds 
23*e7e3c45bSAndrea Faulds inline void emitVulkanError(const char *api, VkResult error) {
24*e7e3c45bSAndrea Faulds   std::cerr << " failed with error code " << error << " when executing " << api;
25*e7e3c45bSAndrea Faulds }
26*e7e3c45bSAndrea Faulds 
27*e7e3c45bSAndrea Faulds #define RETURN_ON_VULKAN_ERROR(result, api)                                    \
28*e7e3c45bSAndrea Faulds   if ((result) != VK_SUCCESS) {                                                \
29*e7e3c45bSAndrea Faulds     emitVulkanError(api, (result));                                            \
30*e7e3c45bSAndrea Faulds     return failure();                                                          \
31*e7e3c45bSAndrea Faulds   }
32*e7e3c45bSAndrea Faulds 
33*e7e3c45bSAndrea Faulds using namespace mlir;
34*e7e3c45bSAndrea Faulds 
35*e7e3c45bSAndrea Faulds void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
36*e7e3c45bSAndrea Faulds   numWorkGroups = numberWorkGroups;
37*e7e3c45bSAndrea Faulds }
38*e7e3c45bSAndrea Faulds 
39*e7e3c45bSAndrea Faulds void VulkanRuntime::setResourceStorageClassBindingMap(
40*e7e3c45bSAndrea Faulds     const ResourceStorageClassBindingMap &stClassData) {
41*e7e3c45bSAndrea Faulds   resourceStorageClassData = stClassData;
42*e7e3c45bSAndrea Faulds }
43*e7e3c45bSAndrea Faulds 
44*e7e3c45bSAndrea Faulds void VulkanRuntime::setResourceData(
45*e7e3c45bSAndrea Faulds     const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
46*e7e3c45bSAndrea Faulds     const VulkanHostMemoryBuffer &hostMemBuffer) {
47*e7e3c45bSAndrea Faulds   resourceData[desIndex][bindIndex] = hostMemBuffer;
48*e7e3c45bSAndrea Faulds   resourceStorageClassData[desIndex][bindIndex] =
49*e7e3c45bSAndrea Faulds       SPIRVStorageClass::StorageBuffer;
50*e7e3c45bSAndrea Faulds }
51*e7e3c45bSAndrea Faulds 
52*e7e3c45bSAndrea Faulds void VulkanRuntime::setEntryPoint(const char *entryPointName) {
53*e7e3c45bSAndrea Faulds   entryPoint = entryPointName;
54*e7e3c45bSAndrea Faulds }
55*e7e3c45bSAndrea Faulds 
56*e7e3c45bSAndrea Faulds void VulkanRuntime::setResourceData(const ResourceData &resData) {
57*e7e3c45bSAndrea Faulds   resourceData = resData;
58*e7e3c45bSAndrea Faulds }
59*e7e3c45bSAndrea Faulds 
60*e7e3c45bSAndrea Faulds void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
61*e7e3c45bSAndrea Faulds   binary = shader;
62*e7e3c45bSAndrea Faulds   binarySize = size;
63*e7e3c45bSAndrea Faulds }
64*e7e3c45bSAndrea Faulds 
65*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
66*e7e3c45bSAndrea Faulds     SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
67*e7e3c45bSAndrea Faulds   switch (storageClass) {
68*e7e3c45bSAndrea Faulds   case SPIRVStorageClass::StorageBuffer:
69*e7e3c45bSAndrea Faulds     descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
70*e7e3c45bSAndrea Faulds     break;
71*e7e3c45bSAndrea Faulds   case SPIRVStorageClass::Uniform:
72*e7e3c45bSAndrea Faulds     descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
73*e7e3c45bSAndrea Faulds     break;
74*e7e3c45bSAndrea Faulds   }
75*e7e3c45bSAndrea Faulds   return success();
76*e7e3c45bSAndrea Faulds }
77*e7e3c45bSAndrea Faulds 
78*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
79*e7e3c45bSAndrea Faulds     SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
80*e7e3c45bSAndrea Faulds   switch (storageClass) {
81*e7e3c45bSAndrea Faulds   case SPIRVStorageClass::StorageBuffer:
82*e7e3c45bSAndrea Faulds     bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
83*e7e3c45bSAndrea Faulds     break;
84*e7e3c45bSAndrea Faulds   case SPIRVStorageClass::Uniform:
85*e7e3c45bSAndrea Faulds     bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
86*e7e3c45bSAndrea Faulds     break;
87*e7e3c45bSAndrea Faulds   }
88*e7e3c45bSAndrea Faulds   return success();
89*e7e3c45bSAndrea Faulds }
90*e7e3c45bSAndrea Faulds 
91*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::countDeviceMemorySize() {
92*e7e3c45bSAndrea Faulds   for (const auto &resourceDataMapPair : resourceData) {
93*e7e3c45bSAndrea Faulds     const auto &resourceDataMap = resourceDataMapPair.second;
94*e7e3c45bSAndrea Faulds     for (const auto &resourceDataBindingPair : resourceDataMap) {
95*e7e3c45bSAndrea Faulds       if (resourceDataBindingPair.second.size) {
96*e7e3c45bSAndrea Faulds         memorySize += resourceDataBindingPair.second.size;
97*e7e3c45bSAndrea Faulds       } else {
98*e7e3c45bSAndrea Faulds         std::cerr << "expected buffer size greater than zero for resource data";
99*e7e3c45bSAndrea Faulds         return failure();
100*e7e3c45bSAndrea Faulds       }
101*e7e3c45bSAndrea Faulds     }
102*e7e3c45bSAndrea Faulds   }
103*e7e3c45bSAndrea Faulds   return success();
104*e7e3c45bSAndrea Faulds }
105*e7e3c45bSAndrea Faulds 
106*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::initRuntime() {
107*e7e3c45bSAndrea Faulds   if (resourceData.empty()) {
108*e7e3c45bSAndrea Faulds     std::cerr << "Vulkan runtime needs at least one resource";
109*e7e3c45bSAndrea Faulds     return failure();
110*e7e3c45bSAndrea Faulds   }
111*e7e3c45bSAndrea Faulds   if (!binarySize || !binary) {
112*e7e3c45bSAndrea Faulds     std::cerr << "binary shader size must be greater than zero";
113*e7e3c45bSAndrea Faulds     return failure();
114*e7e3c45bSAndrea Faulds   }
115*e7e3c45bSAndrea Faulds   if (failed(countDeviceMemorySize())) {
116*e7e3c45bSAndrea Faulds     return failure();
117*e7e3c45bSAndrea Faulds   }
118*e7e3c45bSAndrea Faulds   return success();
119*e7e3c45bSAndrea Faulds }
120*e7e3c45bSAndrea Faulds 
121*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::destroy() {
122*e7e3c45bSAndrea Faulds   // According to Vulkan spec:
123*e7e3c45bSAndrea Faulds   // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
124*e7e3c45bSAndrea Faulds   // used to gate the destruction of the device. Prior to destroying a device,
125*e7e3c45bSAndrea Faulds   // an application is responsible for destroying/freeing any Vulkan objects
126*e7e3c45bSAndrea Faulds   // that were created using that device as the first parameter of the
127*e7e3c45bSAndrea Faulds   // corresponding vkCreate* or vkAllocate* command."
128*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
129*e7e3c45bSAndrea Faulds 
130*e7e3c45bSAndrea Faulds   // Free and destroy.
131*e7e3c45bSAndrea Faulds   vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
132*e7e3c45bSAndrea Faulds                        commandBuffers.data());
133*e7e3c45bSAndrea Faulds   vkDestroyQueryPool(device, queryPool, nullptr);
134*e7e3c45bSAndrea Faulds   vkDestroyCommandPool(device, commandPool, nullptr);
135*e7e3c45bSAndrea Faulds   vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
136*e7e3c45bSAndrea Faulds                        descriptorSets.data());
137*e7e3c45bSAndrea Faulds   vkDestroyDescriptorPool(device, descriptorPool, nullptr);
138*e7e3c45bSAndrea Faulds   vkDestroyPipeline(device, pipeline, nullptr);
139*e7e3c45bSAndrea Faulds   vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
140*e7e3c45bSAndrea Faulds   for (auto &descriptorSetLayout : descriptorSetLayouts) {
141*e7e3c45bSAndrea Faulds     vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
142*e7e3c45bSAndrea Faulds   }
143*e7e3c45bSAndrea Faulds   vkDestroyShaderModule(device, shaderModule, nullptr);
144*e7e3c45bSAndrea Faulds 
145*e7e3c45bSAndrea Faulds   // For each descriptor set.
146*e7e3c45bSAndrea Faulds   for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
147*e7e3c45bSAndrea Faulds     auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
148*e7e3c45bSAndrea Faulds     // For each descriptor binding.
149*e7e3c45bSAndrea Faulds     for (auto &memoryBuffer : deviceMemoryBuffers) {
150*e7e3c45bSAndrea Faulds       vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
151*e7e3c45bSAndrea Faulds       vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
152*e7e3c45bSAndrea Faulds       vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
153*e7e3c45bSAndrea Faulds       vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
154*e7e3c45bSAndrea Faulds     }
155*e7e3c45bSAndrea Faulds   }
156*e7e3c45bSAndrea Faulds 
157*e7e3c45bSAndrea Faulds   vkDestroyDevice(device, nullptr);
158*e7e3c45bSAndrea Faulds   vkDestroyInstance(instance, nullptr);
159*e7e3c45bSAndrea Faulds   return success();
160*e7e3c45bSAndrea Faulds }
161*e7e3c45bSAndrea Faulds 
162*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::run() {
163*e7e3c45bSAndrea Faulds   // Create logical device, shader module and memory buffers.
164*e7e3c45bSAndrea Faulds   if (failed(createInstance()) || failed(createDevice()) ||
165*e7e3c45bSAndrea Faulds       failed(createMemoryBuffers()) || failed(createShaderModule())) {
166*e7e3c45bSAndrea Faulds     return failure();
167*e7e3c45bSAndrea Faulds   }
168*e7e3c45bSAndrea Faulds 
169*e7e3c45bSAndrea Faulds   // Descriptor bindings divided into sets. Each descriptor binding
170*e7e3c45bSAndrea Faulds   // must have a layout binding attached into a descriptor set layout.
171*e7e3c45bSAndrea Faulds   // Each layout set must be binded into a pipeline layout.
172*e7e3c45bSAndrea Faulds   initDescriptorSetLayoutBindingMap();
173*e7e3c45bSAndrea Faulds   if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
174*e7e3c45bSAndrea Faulds       // Each descriptor set must be allocated from a descriptor pool.
175*e7e3c45bSAndrea Faulds       failed(createComputePipeline()) || failed(createDescriptorPool()) ||
176*e7e3c45bSAndrea Faulds       failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
177*e7e3c45bSAndrea Faulds       // Create command buffer.
178*e7e3c45bSAndrea Faulds       failed(createCommandPool()) || failed(createQueryPool()) ||
179*e7e3c45bSAndrea Faulds       failed(createComputeCommandBuffer())) {
180*e7e3c45bSAndrea Faulds     return failure();
181*e7e3c45bSAndrea Faulds   }
182*e7e3c45bSAndrea Faulds 
183*e7e3c45bSAndrea Faulds   // Get working queue.
184*e7e3c45bSAndrea Faulds   vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
185*e7e3c45bSAndrea Faulds 
186*e7e3c45bSAndrea Faulds   if (failed(copyResource(/*deviceToHost=*/false)))
187*e7e3c45bSAndrea Faulds     return failure();
188*e7e3c45bSAndrea Faulds 
189*e7e3c45bSAndrea Faulds   auto submitStart = std::chrono::high_resolution_clock::now();
190*e7e3c45bSAndrea Faulds   // Submit command buffer into the queue.
191*e7e3c45bSAndrea Faulds   if (failed(submitCommandBuffersToQueue()))
192*e7e3c45bSAndrea Faulds     return failure();
193*e7e3c45bSAndrea Faulds   auto submitEnd = std::chrono::high_resolution_clock::now();
194*e7e3c45bSAndrea Faulds 
195*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
196*e7e3c45bSAndrea Faulds   auto execEnd = std::chrono::high_resolution_clock::now();
197*e7e3c45bSAndrea Faulds 
198*e7e3c45bSAndrea Faulds   auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
199*e7e3c45bSAndrea Faulds       submitEnd - submitStart);
200*e7e3c45bSAndrea Faulds   auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
201*e7e3c45bSAndrea Faulds       execEnd - submitEnd);
202*e7e3c45bSAndrea Faulds 
203*e7e3c45bSAndrea Faulds   if (queryPool != VK_NULL_HANDLE) {
204*e7e3c45bSAndrea Faulds     uint64_t timestamps[2];
205*e7e3c45bSAndrea Faulds     RETURN_ON_VULKAN_ERROR(
206*e7e3c45bSAndrea Faulds         vkGetQueryPoolResults(
207*e7e3c45bSAndrea Faulds             device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
208*e7e3c45bSAndrea Faulds             /*dataSize=*/sizeof(timestamps),
209*e7e3c45bSAndrea Faulds             /*pData=*/reinterpret_cast<void *>(timestamps),
210*e7e3c45bSAndrea Faulds             /*stride=*/sizeof(uint64_t),
211*e7e3c45bSAndrea Faulds             VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
212*e7e3c45bSAndrea Faulds         "vkGetQueryPoolResults");
213*e7e3c45bSAndrea Faulds     float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
214*e7e3c45bSAndrea Faulds     std::cout << "Compute shader execution time: " << std::setprecision(3)
215*e7e3c45bSAndrea Faulds               << microsec << "us\n";
216*e7e3c45bSAndrea Faulds   }
217*e7e3c45bSAndrea Faulds 
218*e7e3c45bSAndrea Faulds   std::cout << "Command buffer submit time: " << submitDuration.count()
219*e7e3c45bSAndrea Faulds             << "us\nWait idle time: " << execDuration.count() << "us\n";
220*e7e3c45bSAndrea Faulds 
221*e7e3c45bSAndrea Faulds   return success();
222*e7e3c45bSAndrea Faulds }
223*e7e3c45bSAndrea Faulds 
224*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createInstance() {
225*e7e3c45bSAndrea Faulds   VkApplicationInfo applicationInfo = {};
226*e7e3c45bSAndrea Faulds   applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
227*e7e3c45bSAndrea Faulds   applicationInfo.pNext = nullptr;
228*e7e3c45bSAndrea Faulds   applicationInfo.pApplicationName = "MLIR Vulkan runtime";
229*e7e3c45bSAndrea Faulds   applicationInfo.applicationVersion = 0;
230*e7e3c45bSAndrea Faulds   applicationInfo.pEngineName = "mlir";
231*e7e3c45bSAndrea Faulds   applicationInfo.engineVersion = 0;
232*e7e3c45bSAndrea Faulds   applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
233*e7e3c45bSAndrea Faulds 
234*e7e3c45bSAndrea Faulds   VkInstanceCreateInfo instanceCreateInfo = {};
235*e7e3c45bSAndrea Faulds   instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
236*e7e3c45bSAndrea Faulds   instanceCreateInfo.pNext = nullptr;
237*e7e3c45bSAndrea Faulds   instanceCreateInfo.pApplicationInfo = &applicationInfo;
238*e7e3c45bSAndrea Faulds   instanceCreateInfo.enabledLayerCount = 0;
239*e7e3c45bSAndrea Faulds   instanceCreateInfo.ppEnabledLayerNames = nullptr;
240*e7e3c45bSAndrea Faulds 
241*e7e3c45bSAndrea Faulds   std::vector<const char *> extNames;
242*e7e3c45bSAndrea Faulds #if defined(__APPLE__)
243*e7e3c45bSAndrea Faulds   // enumerate MoltenVK for Vulkan 1.0
244*e7e3c45bSAndrea Faulds   instanceCreateInfo.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
245*e7e3c45bSAndrea Faulds   // add KHR portability instance extensions
246*e7e3c45bSAndrea Faulds   extNames.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
247*e7e3c45bSAndrea Faulds   extNames.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
248*e7e3c45bSAndrea Faulds #else
249*e7e3c45bSAndrea Faulds   instanceCreateInfo.flags = 0;
250*e7e3c45bSAndrea Faulds #endif // __APPLE__
251*e7e3c45bSAndrea Faulds   instanceCreateInfo.enabledExtensionCount =
252*e7e3c45bSAndrea Faulds       static_cast<uint32_t>(extNames.size());
253*e7e3c45bSAndrea Faulds   instanceCreateInfo.ppEnabledExtensionNames = extNames.data();
254*e7e3c45bSAndrea Faulds 
255*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(
256*e7e3c45bSAndrea Faulds       vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
257*e7e3c45bSAndrea Faulds       "vkCreateInstance");
258*e7e3c45bSAndrea Faulds   return success();
259*e7e3c45bSAndrea Faulds }
260*e7e3c45bSAndrea Faulds 
261*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createDevice() {
262*e7e3c45bSAndrea Faulds   uint32_t physicalDeviceCount = 0;
263*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(
264*e7e3c45bSAndrea Faulds       vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
265*e7e3c45bSAndrea Faulds       "vkEnumeratePhysicalDevices");
266*e7e3c45bSAndrea Faulds 
267*e7e3c45bSAndrea Faulds   std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
268*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
269*e7e3c45bSAndrea Faulds                                                     &physicalDeviceCount,
270*e7e3c45bSAndrea Faulds                                                     physicalDevices.data()),
271*e7e3c45bSAndrea Faulds                          "vkEnumeratePhysicalDevices");
272*e7e3c45bSAndrea Faulds 
273*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
274*e7e3c45bSAndrea Faulds                          "physicalDeviceCount");
275*e7e3c45bSAndrea Faulds 
276*e7e3c45bSAndrea Faulds   // TODO: find the best device.
277*e7e3c45bSAndrea Faulds   physicalDevice = physicalDevices.front();
278*e7e3c45bSAndrea Faulds   if (failed(getBestComputeQueue()))
279*e7e3c45bSAndrea Faulds     return failure();
280*e7e3c45bSAndrea Faulds 
281*e7e3c45bSAndrea Faulds   const float queuePriority = 1.0f;
282*e7e3c45bSAndrea Faulds   VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
283*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
284*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.pNext = nullptr;
285*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.flags = 0;
286*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
287*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.queueCount = 1;
288*e7e3c45bSAndrea Faulds   deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
289*e7e3c45bSAndrea Faulds 
290*e7e3c45bSAndrea Faulds   // Structure specifying parameters of a newly created device.
291*e7e3c45bSAndrea Faulds   VkDeviceCreateInfo deviceCreateInfo = {};
292*e7e3c45bSAndrea Faulds   deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
293*e7e3c45bSAndrea Faulds   deviceCreateInfo.pNext = nullptr;
294*e7e3c45bSAndrea Faulds   deviceCreateInfo.flags = 0;
295*e7e3c45bSAndrea Faulds   deviceCreateInfo.queueCreateInfoCount = 1;
296*e7e3c45bSAndrea Faulds   deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
297*e7e3c45bSAndrea Faulds   deviceCreateInfo.enabledLayerCount = 0;
298*e7e3c45bSAndrea Faulds   deviceCreateInfo.ppEnabledLayerNames = nullptr;
299*e7e3c45bSAndrea Faulds   deviceCreateInfo.enabledExtensionCount = 0;
300*e7e3c45bSAndrea Faulds   deviceCreateInfo.ppEnabledExtensionNames = nullptr;
301*e7e3c45bSAndrea Faulds   deviceCreateInfo.pEnabledFeatures = nullptr;
302*e7e3c45bSAndrea Faulds 
303*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(
304*e7e3c45bSAndrea Faulds       vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
305*e7e3c45bSAndrea Faulds       "vkCreateDevice");
306*e7e3c45bSAndrea Faulds 
307*e7e3c45bSAndrea Faulds   VkPhysicalDeviceMemoryProperties properties = {};
308*e7e3c45bSAndrea Faulds   vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
309*e7e3c45bSAndrea Faulds 
310*e7e3c45bSAndrea Faulds   // Try to find memory type with following properties:
311*e7e3c45bSAndrea Faulds   // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
312*e7e3c45bSAndrea Faulds   // with this type can be mapped for host access using vkMapMemory;
313*e7e3c45bSAndrea Faulds   // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
314*e7e3c45bSAndrea Faulds   // management commands vkFlushMappedMemoryRanges and
315*e7e3c45bSAndrea Faulds   // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
316*e7e3c45bSAndrea Faulds   // device or make device writes visible to the host, respectively.
317*e7e3c45bSAndrea Faulds   for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
318*e7e3c45bSAndrea Faulds     if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
319*e7e3c45bSAndrea Faulds          properties.memoryTypes[i].propertyFlags) &&
320*e7e3c45bSAndrea Faulds         (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
321*e7e3c45bSAndrea Faulds          properties.memoryTypes[i].propertyFlags) &&
322*e7e3c45bSAndrea Faulds         (memorySize <=
323*e7e3c45bSAndrea Faulds          properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
324*e7e3c45bSAndrea Faulds       hostMemoryTypeIndex = i;
325*e7e3c45bSAndrea Faulds       break;
326*e7e3c45bSAndrea Faulds     }
327*e7e3c45bSAndrea Faulds   }
328*e7e3c45bSAndrea Faulds 
329*e7e3c45bSAndrea Faulds   // Find memory type memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be
330*e7e3c45bSAndrea Faulds   // used on the device. This will allow better performance access for GPU with
331*e7e3c45bSAndrea Faulds   // on device memory.
332*e7e3c45bSAndrea Faulds   for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
333*e7e3c45bSAndrea Faulds     if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
334*e7e3c45bSAndrea Faulds          properties.memoryTypes[i].propertyFlags) &&
335*e7e3c45bSAndrea Faulds         (memorySize <=
336*e7e3c45bSAndrea Faulds          properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
337*e7e3c45bSAndrea Faulds       deviceMemoryTypeIndex = i;
338*e7e3c45bSAndrea Faulds       break;
339*e7e3c45bSAndrea Faulds     }
340*e7e3c45bSAndrea Faulds   }
341*e7e3c45bSAndrea Faulds 
342*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
343*e7e3c45bSAndrea Faulds                           deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
344*e7e3c45bSAndrea Faulds                              ? VK_INCOMPLETE
345*e7e3c45bSAndrea Faulds                              : VK_SUCCESS,
346*e7e3c45bSAndrea Faulds                          "invalid memoryTypeIndex");
347*e7e3c45bSAndrea Faulds   return success();
348*e7e3c45bSAndrea Faulds }
349*e7e3c45bSAndrea Faulds 
350*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::getBestComputeQueue() {
351*e7e3c45bSAndrea Faulds   uint32_t queueFamilyPropertiesCount = 0;
352*e7e3c45bSAndrea Faulds   vkGetPhysicalDeviceQueueFamilyProperties(
353*e7e3c45bSAndrea Faulds       physicalDevice, &queueFamilyPropertiesCount, nullptr);
354*e7e3c45bSAndrea Faulds 
355*e7e3c45bSAndrea Faulds   std::vector<VkQueueFamilyProperties> familyProperties(
356*e7e3c45bSAndrea Faulds       queueFamilyPropertiesCount);
357*e7e3c45bSAndrea Faulds   vkGetPhysicalDeviceQueueFamilyProperties(
358*e7e3c45bSAndrea Faulds       physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());
359*e7e3c45bSAndrea Faulds 
360*e7e3c45bSAndrea Faulds   // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
361*e7e3c45bSAndrea Faulds   // compute operations. Try to find a compute-only queue first if possible.
362*e7e3c45bSAndrea Faulds   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
363*e7e3c45bSAndrea Faulds     auto flags = familyProperties[i].queueFlags;
364*e7e3c45bSAndrea Faulds     if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
365*e7e3c45bSAndrea Faulds       queueFamilyIndex = i;
366*e7e3c45bSAndrea Faulds       queueFamilyProperties = familyProperties[i];
367*e7e3c45bSAndrea Faulds       return success();
368*e7e3c45bSAndrea Faulds     }
369*e7e3c45bSAndrea Faulds   }
370*e7e3c45bSAndrea Faulds 
371*e7e3c45bSAndrea Faulds   // Otherwise use a queue that can also support graphics.
372*e7e3c45bSAndrea Faulds   for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
373*e7e3c45bSAndrea Faulds     auto flags = familyProperties[i].queueFlags;
374*e7e3c45bSAndrea Faulds     if ((flags & VK_QUEUE_COMPUTE_BIT)) {
375*e7e3c45bSAndrea Faulds       queueFamilyIndex = i;
376*e7e3c45bSAndrea Faulds       queueFamilyProperties = familyProperties[i];
377*e7e3c45bSAndrea Faulds       return success();
378*e7e3c45bSAndrea Faulds     }
379*e7e3c45bSAndrea Faulds   }
380*e7e3c45bSAndrea Faulds 
381*e7e3c45bSAndrea Faulds   std::cerr << "cannot find valid queue";
382*e7e3c45bSAndrea Faulds   return failure();
383*e7e3c45bSAndrea Faulds }
384*e7e3c45bSAndrea Faulds 
385*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createMemoryBuffers() {
386*e7e3c45bSAndrea Faulds   // For each descriptor set.
387*e7e3c45bSAndrea Faulds   for (const auto &resourceDataMapPair : resourceData) {
388*e7e3c45bSAndrea Faulds     std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
389*e7e3c45bSAndrea Faulds     const auto descriptorSetIndex = resourceDataMapPair.first;
390*e7e3c45bSAndrea Faulds     const auto &resourceDataMap = resourceDataMapPair.second;
391*e7e3c45bSAndrea Faulds 
392*e7e3c45bSAndrea Faulds     // For each descriptor binding.
393*e7e3c45bSAndrea Faulds     for (const auto &resourceDataBindingPair : resourceDataMap) {
394*e7e3c45bSAndrea Faulds       // Create device memory buffer.
395*e7e3c45bSAndrea Faulds       VulkanDeviceMemoryBuffer memoryBuffer;
396*e7e3c45bSAndrea Faulds       memoryBuffer.bindingIndex = resourceDataBindingPair.first;
397*e7e3c45bSAndrea Faulds       VkDescriptorType descriptorType = {};
398*e7e3c45bSAndrea Faulds       VkBufferUsageFlagBits bufferUsage = {};
399*e7e3c45bSAndrea Faulds 
400*e7e3c45bSAndrea Faulds       // Check that descriptor set has storage class map.
401*e7e3c45bSAndrea Faulds       const auto resourceStorageClassMapIt =
402*e7e3c45bSAndrea Faulds           resourceStorageClassData.find(descriptorSetIndex);
403*e7e3c45bSAndrea Faulds       if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
404*e7e3c45bSAndrea Faulds         std::cerr
405*e7e3c45bSAndrea Faulds             << "cannot find storage class for resource in descriptor set: "
406*e7e3c45bSAndrea Faulds             << descriptorSetIndex;
407*e7e3c45bSAndrea Faulds         return failure();
408*e7e3c45bSAndrea Faulds       }
409*e7e3c45bSAndrea Faulds 
410*e7e3c45bSAndrea Faulds       // Check that specific descriptor binding has storage class.
411*e7e3c45bSAndrea Faulds       const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
412*e7e3c45bSAndrea Faulds       const auto resourceStorageClassIt =
413*e7e3c45bSAndrea Faulds           resourceStorageClassMap.find(resourceDataBindingPair.first);
414*e7e3c45bSAndrea Faulds       if (resourceStorageClassIt == resourceStorageClassMap.end()) {
415*e7e3c45bSAndrea Faulds         std::cerr
416*e7e3c45bSAndrea Faulds             << "cannot find storage class for resource with descriptor index: "
417*e7e3c45bSAndrea Faulds             << resourceDataBindingPair.first;
418*e7e3c45bSAndrea Faulds         return failure();
419*e7e3c45bSAndrea Faulds       }
420*e7e3c45bSAndrea Faulds 
421*e7e3c45bSAndrea Faulds       const auto resourceStorageClassBinding = resourceStorageClassIt->second;
422*e7e3c45bSAndrea Faulds       if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
423*e7e3c45bSAndrea Faulds                                                  descriptorType)) ||
424*e7e3c45bSAndrea Faulds           failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
425*e7e3c45bSAndrea Faulds                                                   bufferUsage))) {
426*e7e3c45bSAndrea Faulds         std::cerr << "storage class for resource with descriptor binding: "
427*e7e3c45bSAndrea Faulds                   << resourceDataBindingPair.first
428*e7e3c45bSAndrea Faulds                   << " in the descriptor set: " << descriptorSetIndex
429*e7e3c45bSAndrea Faulds                   << " is not supported ";
430*e7e3c45bSAndrea Faulds         return failure();
431*e7e3c45bSAndrea Faulds       }
432*e7e3c45bSAndrea Faulds 
433*e7e3c45bSAndrea Faulds       // Set descriptor type for the specific device memory buffer.
434*e7e3c45bSAndrea Faulds       memoryBuffer.descriptorType = descriptorType;
435*e7e3c45bSAndrea Faulds       const auto bufferSize = resourceDataBindingPair.second.size;
436*e7e3c45bSAndrea Faulds       memoryBuffer.bufferSize = bufferSize;
437*e7e3c45bSAndrea Faulds       // Specify memory allocation info.
438*e7e3c45bSAndrea Faulds       VkMemoryAllocateInfo memoryAllocateInfo = {};
439*e7e3c45bSAndrea Faulds       memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
440*e7e3c45bSAndrea Faulds       memoryAllocateInfo.pNext = nullptr;
441*e7e3c45bSAndrea Faulds       memoryAllocateInfo.allocationSize = bufferSize;
442*e7e3c45bSAndrea Faulds       memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;
443*e7e3c45bSAndrea Faulds 
444*e7e3c45bSAndrea Faulds       // Allocate device memory.
445*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
446*e7e3c45bSAndrea Faulds                                               nullptr,
447*e7e3c45bSAndrea Faulds                                               &memoryBuffer.hostMemory),
448*e7e3c45bSAndrea Faulds                              "vkAllocateMemory");
449*e7e3c45bSAndrea Faulds       memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
450*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
451*e7e3c45bSAndrea Faulds                                               nullptr,
452*e7e3c45bSAndrea Faulds                                               &memoryBuffer.deviceMemory),
453*e7e3c45bSAndrea Faulds                              "vkAllocateMemory");
454*e7e3c45bSAndrea Faulds       void *payload;
455*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
456*e7e3c45bSAndrea Faulds                                          bufferSize, 0,
457*e7e3c45bSAndrea Faulds                                          reinterpret_cast<void **>(&payload)),
458*e7e3c45bSAndrea Faulds                              "vkMapMemory");
459*e7e3c45bSAndrea Faulds 
460*e7e3c45bSAndrea Faulds       // Copy host memory into the mapped area.
461*e7e3c45bSAndrea Faulds       std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
462*e7e3c45bSAndrea Faulds       vkUnmapMemory(device, memoryBuffer.hostMemory);
463*e7e3c45bSAndrea Faulds 
464*e7e3c45bSAndrea Faulds       VkBufferCreateInfo bufferCreateInfo = {};
465*e7e3c45bSAndrea Faulds       bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
466*e7e3c45bSAndrea Faulds       bufferCreateInfo.pNext = nullptr;
467*e7e3c45bSAndrea Faulds       bufferCreateInfo.flags = 0;
468*e7e3c45bSAndrea Faulds       bufferCreateInfo.size = bufferSize;
469*e7e3c45bSAndrea Faulds       bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
470*e7e3c45bSAndrea Faulds                                VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
471*e7e3c45bSAndrea Faulds       bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
472*e7e3c45bSAndrea Faulds       bufferCreateInfo.queueFamilyIndexCount = 1;
473*e7e3c45bSAndrea Faulds       bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
474*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
475*e7e3c45bSAndrea Faulds                                             &memoryBuffer.hostBuffer),
476*e7e3c45bSAndrea Faulds                              "vkCreateBuffer");
477*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
478*e7e3c45bSAndrea Faulds                                             &memoryBuffer.deviceBuffer),
479*e7e3c45bSAndrea Faulds                              "vkCreateBuffer");
480*e7e3c45bSAndrea Faulds 
481*e7e3c45bSAndrea Faulds       // Bind buffer and device memory.
482*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.hostBuffer,
483*e7e3c45bSAndrea Faulds                                                 memoryBuffer.hostMemory, 0),
484*e7e3c45bSAndrea Faulds                              "vkBindBufferMemory");
485*e7e3c45bSAndrea Faulds       RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
486*e7e3c45bSAndrea Faulds                                                 memoryBuffer.deviceBuffer,
487*e7e3c45bSAndrea Faulds                                                 memoryBuffer.deviceMemory, 0),
488*e7e3c45bSAndrea Faulds                              "vkBindBufferMemory");
489*e7e3c45bSAndrea Faulds 
490*e7e3c45bSAndrea Faulds       // Update buffer info.
491*e7e3c45bSAndrea Faulds       memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
492*e7e3c45bSAndrea Faulds       memoryBuffer.bufferInfo.offset = 0;
493*e7e3c45bSAndrea Faulds       memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
494*e7e3c45bSAndrea Faulds       deviceMemoryBuffers.push_back(memoryBuffer);
495*e7e3c45bSAndrea Faulds     }
496*e7e3c45bSAndrea Faulds 
497*e7e3c45bSAndrea Faulds     // Associate device memory buffers with a descriptor set.
498*e7e3c45bSAndrea Faulds     deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
499*e7e3c45bSAndrea Faulds   }
500*e7e3c45bSAndrea Faulds   return success();
501*e7e3c45bSAndrea Faulds }
502*e7e3c45bSAndrea Faulds 
503*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
504*e7e3c45bSAndrea Faulds   VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
505*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
506*e7e3c45bSAndrea Faulds       nullptr,
507*e7e3c45bSAndrea Faulds       commandPool,
508*e7e3c45bSAndrea Faulds       VK_COMMAND_BUFFER_LEVEL_PRIMARY,
509*e7e3c45bSAndrea Faulds       1,
510*e7e3c45bSAndrea Faulds   };
511*e7e3c45bSAndrea Faulds   VkCommandBuffer commandBuffer;
512*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
513*e7e3c45bSAndrea Faulds                                                   &commandBufferAllocateInfo,
514*e7e3c45bSAndrea Faulds                                                   &commandBuffer),
515*e7e3c45bSAndrea Faulds                          "vkAllocateCommandBuffers");
516*e7e3c45bSAndrea Faulds 
517*e7e3c45bSAndrea Faulds   VkCommandBufferBeginInfo commandBufferBeginInfo = {
518*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
519*e7e3c45bSAndrea Faulds       nullptr,
520*e7e3c45bSAndrea Faulds       0,
521*e7e3c45bSAndrea Faulds       nullptr,
522*e7e3c45bSAndrea Faulds   };
523*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(
524*e7e3c45bSAndrea Faulds       vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
525*e7e3c45bSAndrea Faulds       "vkBeginCommandBuffer");
526*e7e3c45bSAndrea Faulds 
527*e7e3c45bSAndrea Faulds   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
528*e7e3c45bSAndrea Faulds     std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
529*e7e3c45bSAndrea Faulds     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
530*e7e3c45bSAndrea Faulds     for (const auto &memBuffer : deviceMemoryBuffers) {
531*e7e3c45bSAndrea Faulds       VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
532*e7e3c45bSAndrea Faulds       if (deviceToHost)
533*e7e3c45bSAndrea Faulds         vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
534*e7e3c45bSAndrea Faulds                         memBuffer.hostBuffer, 1, &copy);
535*e7e3c45bSAndrea Faulds       else
536*e7e3c45bSAndrea Faulds         vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
537*e7e3c45bSAndrea Faulds                         memBuffer.deviceBuffer, 1, &copy);
538*e7e3c45bSAndrea Faulds     }
539*e7e3c45bSAndrea Faulds   }
540*e7e3c45bSAndrea Faulds 
541*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
542*e7e3c45bSAndrea Faulds                          "vkEndCommandBuffer");
543*e7e3c45bSAndrea Faulds   VkSubmitInfo submitInfo = {
544*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_SUBMIT_INFO,
545*e7e3c45bSAndrea Faulds       nullptr,
546*e7e3c45bSAndrea Faulds       0,
547*e7e3c45bSAndrea Faulds       nullptr,
548*e7e3c45bSAndrea Faulds       nullptr,
549*e7e3c45bSAndrea Faulds       1,
550*e7e3c45bSAndrea Faulds       &commandBuffer,
551*e7e3c45bSAndrea Faulds       0,
552*e7e3c45bSAndrea Faulds       nullptr,
553*e7e3c45bSAndrea Faulds   };
554*e7e3c45bSAndrea Faulds   submitInfo.pCommandBuffers = &commandBuffer;
555*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
556*e7e3c45bSAndrea Faulds                          "vkQueueSubmit");
557*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
558*e7e3c45bSAndrea Faulds 
559*e7e3c45bSAndrea Faulds   vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
560*e7e3c45bSAndrea Faulds   return success();
561*e7e3c45bSAndrea Faulds }
562*e7e3c45bSAndrea Faulds 
563*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createShaderModule() {
564*e7e3c45bSAndrea Faulds   VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
565*e7e3c45bSAndrea Faulds   shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
566*e7e3c45bSAndrea Faulds   shaderModuleCreateInfo.pNext = nullptr;
567*e7e3c45bSAndrea Faulds   shaderModuleCreateInfo.flags = 0;
568*e7e3c45bSAndrea Faulds   // Set size in bytes.
569*e7e3c45bSAndrea Faulds   shaderModuleCreateInfo.codeSize = binarySize;
570*e7e3c45bSAndrea Faulds   // Set pointer to the binary shader.
571*e7e3c45bSAndrea Faulds   shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
572*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
573*e7e3c45bSAndrea Faulds                                               nullptr, &shaderModule),
574*e7e3c45bSAndrea Faulds                          "vkCreateShaderModule");
575*e7e3c45bSAndrea Faulds   return success();
576*e7e3c45bSAndrea Faulds }
577*e7e3c45bSAndrea Faulds 
578*e7e3c45bSAndrea Faulds void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
579*e7e3c45bSAndrea Faulds   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
580*e7e3c45bSAndrea Faulds     std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
581*e7e3c45bSAndrea Faulds     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
582*e7e3c45bSAndrea Faulds     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
583*e7e3c45bSAndrea Faulds 
584*e7e3c45bSAndrea Faulds     // Create a layout binding for each descriptor.
585*e7e3c45bSAndrea Faulds     for (const auto &memBuffer : deviceMemoryBuffers) {
586*e7e3c45bSAndrea Faulds       VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
587*e7e3c45bSAndrea Faulds       descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
588*e7e3c45bSAndrea Faulds       descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
589*e7e3c45bSAndrea Faulds       descriptorSetLayoutBinding.descriptorCount = 1;
590*e7e3c45bSAndrea Faulds       descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
591*e7e3c45bSAndrea Faulds       descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
592*e7e3c45bSAndrea Faulds       descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
593*e7e3c45bSAndrea Faulds     }
594*e7e3c45bSAndrea Faulds     descriptorSetLayoutBindingMap[descriptorSetIndex] =
595*e7e3c45bSAndrea Faulds         descriptorSetLayoutBindings;
596*e7e3c45bSAndrea Faulds   }
597*e7e3c45bSAndrea Faulds }
598*e7e3c45bSAndrea Faulds 
599*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createDescriptorSetLayout() {
600*e7e3c45bSAndrea Faulds   for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
601*e7e3c45bSAndrea Faulds     const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
602*e7e3c45bSAndrea Faulds     const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
603*e7e3c45bSAndrea Faulds     // Each descriptor in a descriptor set must be the same type.
604*e7e3c45bSAndrea Faulds     VkDescriptorType descriptorType =
605*e7e3c45bSAndrea Faulds         deviceMemoryBuffers.front().descriptorType;
606*e7e3c45bSAndrea Faulds     const uint32_t descriptorSize = deviceMemoryBuffers.size();
607*e7e3c45bSAndrea Faulds     const auto descriptorSetLayoutBindingIt =
608*e7e3c45bSAndrea Faulds         descriptorSetLayoutBindingMap.find(descriptorSetIndex);
609*e7e3c45bSAndrea Faulds 
610*e7e3c45bSAndrea Faulds     if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
611*e7e3c45bSAndrea Faulds       std::cerr << "cannot find layout bindings for the set with number: "
612*e7e3c45bSAndrea Faulds                 << descriptorSetIndex;
613*e7e3c45bSAndrea Faulds       return failure();
614*e7e3c45bSAndrea Faulds     }
615*e7e3c45bSAndrea Faulds 
616*e7e3c45bSAndrea Faulds     const auto &descriptorSetLayoutBindings =
617*e7e3c45bSAndrea Faulds         descriptorSetLayoutBindingIt->second;
618*e7e3c45bSAndrea Faulds     // Create descriptor set layout.
619*e7e3c45bSAndrea Faulds     VkDescriptorSetLayout descriptorSetLayout = {};
620*e7e3c45bSAndrea Faulds     VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
621*e7e3c45bSAndrea Faulds 
622*e7e3c45bSAndrea Faulds     descriptorSetLayoutCreateInfo.sType =
623*e7e3c45bSAndrea Faulds         VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
624*e7e3c45bSAndrea Faulds     descriptorSetLayoutCreateInfo.pNext = nullptr;
625*e7e3c45bSAndrea Faulds     descriptorSetLayoutCreateInfo.flags = 0;
626*e7e3c45bSAndrea Faulds     // Amount of descriptor bindings in a layout set.
627*e7e3c45bSAndrea Faulds     descriptorSetLayoutCreateInfo.bindingCount =
628*e7e3c45bSAndrea Faulds         descriptorSetLayoutBindings.size();
629*e7e3c45bSAndrea Faulds     descriptorSetLayoutCreateInfo.pBindings =
630*e7e3c45bSAndrea Faulds         descriptorSetLayoutBindings.data();
631*e7e3c45bSAndrea Faulds     RETURN_ON_VULKAN_ERROR(
632*e7e3c45bSAndrea Faulds         vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
633*e7e3c45bSAndrea Faulds                                     nullptr, &descriptorSetLayout),
634*e7e3c45bSAndrea Faulds         "vkCreateDescriptorSetLayout");
635*e7e3c45bSAndrea Faulds 
636*e7e3c45bSAndrea Faulds     descriptorSetLayouts.push_back(descriptorSetLayout);
637*e7e3c45bSAndrea Faulds     descriptorSetInfoPool.push_back(
638*e7e3c45bSAndrea Faulds         {descriptorSetIndex, descriptorSize, descriptorType});
639*e7e3c45bSAndrea Faulds   }
640*e7e3c45bSAndrea Faulds   return success();
641*e7e3c45bSAndrea Faulds }
642*e7e3c45bSAndrea Faulds 
643*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createPipelineLayout() {
644*e7e3c45bSAndrea Faulds   // Associate descriptor sets with a pipeline layout.
645*e7e3c45bSAndrea Faulds   VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
646*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.sType =
647*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
648*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.pNext = nullptr;
649*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.flags = 0;
650*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
651*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
652*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
653*e7e3c45bSAndrea Faulds   pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
654*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
655*e7e3c45bSAndrea Faulds                                                 &pipelineLayoutCreateInfo,
656*e7e3c45bSAndrea Faulds                                                 nullptr, &pipelineLayout),
657*e7e3c45bSAndrea Faulds                          "vkCreatePipelineLayout");
658*e7e3c45bSAndrea Faulds   return success();
659*e7e3c45bSAndrea Faulds }
660*e7e3c45bSAndrea Faulds 
661*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createComputePipeline() {
662*e7e3c45bSAndrea Faulds   VkPipelineShaderStageCreateInfo stageInfo = {};
663*e7e3c45bSAndrea Faulds   stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
664*e7e3c45bSAndrea Faulds   stageInfo.pNext = nullptr;
665*e7e3c45bSAndrea Faulds   stageInfo.flags = 0;
666*e7e3c45bSAndrea Faulds   stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
667*e7e3c45bSAndrea Faulds   stageInfo.module = shaderModule;
668*e7e3c45bSAndrea Faulds   // Set entry point.
669*e7e3c45bSAndrea Faulds   stageInfo.pName = entryPoint;
670*e7e3c45bSAndrea Faulds   stageInfo.pSpecializationInfo = nullptr;
671*e7e3c45bSAndrea Faulds 
672*e7e3c45bSAndrea Faulds   VkComputePipelineCreateInfo computePipelineCreateInfo = {};
673*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.sType =
674*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
675*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.pNext = nullptr;
676*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.flags = 0;
677*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.stage = stageInfo;
678*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.layout = pipelineLayout;
679*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.basePipelineHandle = nullptr;
680*e7e3c45bSAndrea Faulds   computePipelineCreateInfo.basePipelineIndex = 0;
681*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
682*e7e3c45bSAndrea Faulds                                                   &computePipelineCreateInfo,
683*e7e3c45bSAndrea Faulds                                                   nullptr, &pipeline),
684*e7e3c45bSAndrea Faulds                          "vkCreateComputePipelines");
685*e7e3c45bSAndrea Faulds   return success();
686*e7e3c45bSAndrea Faulds }
687*e7e3c45bSAndrea Faulds 
688*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createDescriptorPool() {
689*e7e3c45bSAndrea Faulds   std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
690*e7e3c45bSAndrea Faulds   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
691*e7e3c45bSAndrea Faulds     // For each descriptor set populate descriptor pool size.
692*e7e3c45bSAndrea Faulds     VkDescriptorPoolSize descriptorPoolSize = {};
693*e7e3c45bSAndrea Faulds     descriptorPoolSize.type = descriptorSetInfo.descriptorType;
694*e7e3c45bSAndrea Faulds     descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
695*e7e3c45bSAndrea Faulds     descriptorPoolSizes.push_back(descriptorPoolSize);
696*e7e3c45bSAndrea Faulds   }
697*e7e3c45bSAndrea Faulds 
698*e7e3c45bSAndrea Faulds   VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
699*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.sType =
700*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
701*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.pNext = nullptr;
702*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.flags = 0;
703*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
704*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
705*e7e3c45bSAndrea Faulds   descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
706*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
707*e7e3c45bSAndrea Faulds                                                 &descriptorPoolCreateInfo,
708*e7e3c45bSAndrea Faulds                                                 nullptr, &descriptorPool),
709*e7e3c45bSAndrea Faulds                          "vkCreateDescriptorPool");
710*e7e3c45bSAndrea Faulds   return success();
711*e7e3c45bSAndrea Faulds }
712*e7e3c45bSAndrea Faulds 
713*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::allocateDescriptorSets() {
714*e7e3c45bSAndrea Faulds   VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
715*e7e3c45bSAndrea Faulds   // Size of descriptor sets and descriptor layout sets is the same.
716*e7e3c45bSAndrea Faulds   descriptorSets.resize(descriptorSetLayouts.size());
717*e7e3c45bSAndrea Faulds   descriptorSetAllocateInfo.sType =
718*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
719*e7e3c45bSAndrea Faulds   descriptorSetAllocateInfo.pNext = nullptr;
720*e7e3c45bSAndrea Faulds   descriptorSetAllocateInfo.descriptorPool = descriptorPool;
721*e7e3c45bSAndrea Faulds   descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
722*e7e3c45bSAndrea Faulds   descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
723*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
724*e7e3c45bSAndrea Faulds                                                   &descriptorSetAllocateInfo,
725*e7e3c45bSAndrea Faulds                                                   descriptorSets.data()),
726*e7e3c45bSAndrea Faulds                          "vkAllocateDescriptorSets");
727*e7e3c45bSAndrea Faulds   return success();
728*e7e3c45bSAndrea Faulds }
729*e7e3c45bSAndrea Faulds 
730*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::setWriteDescriptors() {
731*e7e3c45bSAndrea Faulds   if (descriptorSets.size() != descriptorSetInfoPool.size()) {
732*e7e3c45bSAndrea Faulds     std::cerr << "Each descriptor set must have descriptor set information";
733*e7e3c45bSAndrea Faulds     return failure();
734*e7e3c45bSAndrea Faulds   }
735*e7e3c45bSAndrea Faulds   // For each descriptor set.
736*e7e3c45bSAndrea Faulds   auto descriptorSetIt = descriptorSets.begin();
737*e7e3c45bSAndrea Faulds   // Each descriptor set is associated with descriptor set info.
738*e7e3c45bSAndrea Faulds   for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
739*e7e3c45bSAndrea Faulds     // For each device memory buffer in the descriptor set.
740*e7e3c45bSAndrea Faulds     const auto &deviceMemoryBuffers =
741*e7e3c45bSAndrea Faulds         deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
742*e7e3c45bSAndrea Faulds     for (const auto &memoryBuffer : deviceMemoryBuffers) {
743*e7e3c45bSAndrea Faulds       // Structure describing descriptor sets to write to.
744*e7e3c45bSAndrea Faulds       VkWriteDescriptorSet wSet = {};
745*e7e3c45bSAndrea Faulds       wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
746*e7e3c45bSAndrea Faulds       wSet.pNext = nullptr;
747*e7e3c45bSAndrea Faulds       // Descriptor set.
748*e7e3c45bSAndrea Faulds       wSet.dstSet = *descriptorSetIt;
749*e7e3c45bSAndrea Faulds       wSet.dstBinding = memoryBuffer.bindingIndex;
750*e7e3c45bSAndrea Faulds       wSet.dstArrayElement = 0;
751*e7e3c45bSAndrea Faulds       wSet.descriptorCount = 1;
752*e7e3c45bSAndrea Faulds       wSet.descriptorType = memoryBuffer.descriptorType;
753*e7e3c45bSAndrea Faulds       wSet.pImageInfo = nullptr;
754*e7e3c45bSAndrea Faulds       wSet.pBufferInfo = &memoryBuffer.bufferInfo;
755*e7e3c45bSAndrea Faulds       wSet.pTexelBufferView = nullptr;
756*e7e3c45bSAndrea Faulds       vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
757*e7e3c45bSAndrea Faulds     }
758*e7e3c45bSAndrea Faulds     // Increment descriptor set iterator.
759*e7e3c45bSAndrea Faulds     ++descriptorSetIt;
760*e7e3c45bSAndrea Faulds   }
761*e7e3c45bSAndrea Faulds   return success();
762*e7e3c45bSAndrea Faulds }
763*e7e3c45bSAndrea Faulds 
764*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createCommandPool() {
765*e7e3c45bSAndrea Faulds   VkCommandPoolCreateInfo commandPoolCreateInfo = {};
766*e7e3c45bSAndrea Faulds   commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
767*e7e3c45bSAndrea Faulds   commandPoolCreateInfo.pNext = nullptr;
768*e7e3c45bSAndrea Faulds   commandPoolCreateInfo.flags = 0;
769*e7e3c45bSAndrea Faulds   commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
770*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
771*e7e3c45bSAndrea Faulds                                              /*pAllocator=*/nullptr,
772*e7e3c45bSAndrea Faulds                                              &commandPool),
773*e7e3c45bSAndrea Faulds                          "vkCreateCommandPool");
774*e7e3c45bSAndrea Faulds   return success();
775*e7e3c45bSAndrea Faulds }
776*e7e3c45bSAndrea Faulds 
777*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createQueryPool() {
778*e7e3c45bSAndrea Faulds   // Return directly if timestamp query is not supported.
779*e7e3c45bSAndrea Faulds   if (queueFamilyProperties.timestampValidBits == 0)
780*e7e3c45bSAndrea Faulds     return success();
781*e7e3c45bSAndrea Faulds 
782*e7e3c45bSAndrea Faulds   // Get timestamp period for this physical device.
783*e7e3c45bSAndrea Faulds   VkPhysicalDeviceProperties deviceProperties = {};
784*e7e3c45bSAndrea Faulds   vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
785*e7e3c45bSAndrea Faulds   timestampPeriod = deviceProperties.limits.timestampPeriod;
786*e7e3c45bSAndrea Faulds 
787*e7e3c45bSAndrea Faulds   // Create query pool.
788*e7e3c45bSAndrea Faulds   VkQueryPoolCreateInfo queryPoolCreateInfo = {};
789*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
790*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.pNext = nullptr;
791*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.flags = 0;
792*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
793*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.queryCount = 2;
794*e7e3c45bSAndrea Faulds   queryPoolCreateInfo.pipelineStatistics = 0;
795*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
796*e7e3c45bSAndrea Faulds                                            /*pAllocator=*/nullptr, &queryPool),
797*e7e3c45bSAndrea Faulds                          "vkCreateQueryPool");
798*e7e3c45bSAndrea Faulds 
799*e7e3c45bSAndrea Faulds   return success();
800*e7e3c45bSAndrea Faulds }
801*e7e3c45bSAndrea Faulds 
802*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::createComputeCommandBuffer() {
803*e7e3c45bSAndrea Faulds   VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
804*e7e3c45bSAndrea Faulds   commandBufferAllocateInfo.sType =
805*e7e3c45bSAndrea Faulds       VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
806*e7e3c45bSAndrea Faulds   commandBufferAllocateInfo.pNext = nullptr;
807*e7e3c45bSAndrea Faulds   commandBufferAllocateInfo.commandPool = commandPool;
808*e7e3c45bSAndrea Faulds   commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
809*e7e3c45bSAndrea Faulds   commandBufferAllocateInfo.commandBufferCount = 1;
810*e7e3c45bSAndrea Faulds 
811*e7e3c45bSAndrea Faulds   VkCommandBuffer commandBuffer;
812*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
813*e7e3c45bSAndrea Faulds                                                   &commandBufferAllocateInfo,
814*e7e3c45bSAndrea Faulds                                                   &commandBuffer),
815*e7e3c45bSAndrea Faulds                          "vkAllocateCommandBuffers");
816*e7e3c45bSAndrea Faulds 
817*e7e3c45bSAndrea Faulds   VkCommandBufferBeginInfo commandBufferBeginInfo = {};
818*e7e3c45bSAndrea Faulds   commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
819*e7e3c45bSAndrea Faulds   commandBufferBeginInfo.pNext = nullptr;
820*e7e3c45bSAndrea Faulds   commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
821*e7e3c45bSAndrea Faulds   commandBufferBeginInfo.pInheritanceInfo = nullptr;
822*e7e3c45bSAndrea Faulds 
823*e7e3c45bSAndrea Faulds   // Commands begin.
824*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(
825*e7e3c45bSAndrea Faulds       vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
826*e7e3c45bSAndrea Faulds       "vkBeginCommandBuffer");
827*e7e3c45bSAndrea Faulds 
828*e7e3c45bSAndrea Faulds   if (queryPool != VK_NULL_HANDLE)
829*e7e3c45bSAndrea Faulds     vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);
830*e7e3c45bSAndrea Faulds 
831*e7e3c45bSAndrea Faulds   vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
832*e7e3c45bSAndrea Faulds   vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
833*e7e3c45bSAndrea Faulds                           pipelineLayout, 0, descriptorSets.size(),
834*e7e3c45bSAndrea Faulds                           descriptorSets.data(), 0, nullptr);
835*e7e3c45bSAndrea Faulds   // Get a timestamp before invoking the compute shader.
836*e7e3c45bSAndrea Faulds   if (queryPool != VK_NULL_HANDLE)
837*e7e3c45bSAndrea Faulds     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
838*e7e3c45bSAndrea Faulds                         queryPool, 0);
839*e7e3c45bSAndrea Faulds   vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
840*e7e3c45bSAndrea Faulds                 numWorkGroups.z);
841*e7e3c45bSAndrea Faulds   // Get another timestamp after invoking the compute shader.
842*e7e3c45bSAndrea Faulds   if (queryPool != VK_NULL_HANDLE)
843*e7e3c45bSAndrea Faulds     vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
844*e7e3c45bSAndrea Faulds                         queryPool, 1);
845*e7e3c45bSAndrea Faulds 
846*e7e3c45bSAndrea Faulds   // Commands end.
847*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
848*e7e3c45bSAndrea Faulds                          "vkEndCommandBuffer");
849*e7e3c45bSAndrea Faulds 
850*e7e3c45bSAndrea Faulds   commandBuffers.push_back(commandBuffer);
851*e7e3c45bSAndrea Faulds   return success();
852*e7e3c45bSAndrea Faulds }
853*e7e3c45bSAndrea Faulds 
854*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
855*e7e3c45bSAndrea Faulds   VkSubmitInfo submitInfo = {};
856*e7e3c45bSAndrea Faulds   submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
857*e7e3c45bSAndrea Faulds   submitInfo.pNext = nullptr;
858*e7e3c45bSAndrea Faulds   submitInfo.waitSemaphoreCount = 0;
859*e7e3c45bSAndrea Faulds   submitInfo.pWaitSemaphores = nullptr;
860*e7e3c45bSAndrea Faulds   submitInfo.pWaitDstStageMask = nullptr;
861*e7e3c45bSAndrea Faulds   submitInfo.commandBufferCount = commandBuffers.size();
862*e7e3c45bSAndrea Faulds   submitInfo.pCommandBuffers = commandBuffers.data();
863*e7e3c45bSAndrea Faulds   submitInfo.signalSemaphoreCount = 0;
864*e7e3c45bSAndrea Faulds   submitInfo.pSignalSemaphores = nullptr;
865*e7e3c45bSAndrea Faulds   RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
866*e7e3c45bSAndrea Faulds                          "vkQueueSubmit");
867*e7e3c45bSAndrea Faulds   return success();
868*e7e3c45bSAndrea Faulds }
869*e7e3c45bSAndrea Faulds 
870*e7e3c45bSAndrea Faulds LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
871*e7e3c45bSAndrea Faulds   // First copy back the data to the staging buffer.
872*e7e3c45bSAndrea Faulds   (void)copyResource(/*deviceToHost=*/true);
873*e7e3c45bSAndrea Faulds 
874*e7e3c45bSAndrea Faulds   // For each descriptor set.
875*e7e3c45bSAndrea Faulds   for (auto &resourceDataMapPair : resourceData) {
876*e7e3c45bSAndrea Faulds     auto &resourceDataMap = resourceDataMapPair.second;
877*e7e3c45bSAndrea Faulds     auto &deviceMemoryBuffers =
878*e7e3c45bSAndrea Faulds         deviceMemoryBufferMap[resourceDataMapPair.first];
879*e7e3c45bSAndrea Faulds     // For each device memory buffer in the set.
880*e7e3c45bSAndrea Faulds     for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
881*e7e3c45bSAndrea Faulds       if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
882*e7e3c45bSAndrea Faulds         void *payload;
883*e7e3c45bSAndrea Faulds         auto &hostMemoryBuffer =
884*e7e3c45bSAndrea Faulds             resourceDataMap[deviceMemoryBuffer.bindingIndex];
885*e7e3c45bSAndrea Faulds         RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
886*e7e3c45bSAndrea Faulds                                            deviceMemoryBuffer.hostMemory, 0,
887*e7e3c45bSAndrea Faulds                                            hostMemoryBuffer.size, 0,
888*e7e3c45bSAndrea Faulds                                            reinterpret_cast<void **>(&payload)),
889*e7e3c45bSAndrea Faulds                                "vkMapMemory");
890*e7e3c45bSAndrea Faulds         std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
891*e7e3c45bSAndrea Faulds         vkUnmapMemory(device, deviceMemoryBuffer.hostMemory);
892*e7e3c45bSAndrea Faulds       }
893*e7e3c45bSAndrea Faulds     }
894*e7e3c45bSAndrea Faulds   }
895*e7e3c45bSAndrea Faulds   return success();
896*e7e3c45bSAndrea Faulds }
897