xref: /llvm-project/mlir/lib/ExecutionEngine/VulkanRuntime.h (revision e7e3c45bc70904e24e2b3221ac8521e67eb84668)
1*e7e3c45bSAndrea Faulds //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
2*e7e3c45bSAndrea Faulds //
3*e7e3c45bSAndrea Faulds // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*e7e3c45bSAndrea Faulds // See https://llvm.org/LICENSE.txt for license information.
5*e7e3c45bSAndrea Faulds // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*e7e3c45bSAndrea Faulds //
7*e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
8*e7e3c45bSAndrea Faulds //
9*e7e3c45bSAndrea Faulds // This file declares Vulkan runtime API.
10*e7e3c45bSAndrea Faulds //
11*e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
12*e7e3c45bSAndrea Faulds 
13*e7e3c45bSAndrea Faulds #ifndef VULKAN_RUNTIME_H
14*e7e3c45bSAndrea Faulds #define VULKAN_RUNTIME_H
15*e7e3c45bSAndrea Faulds 
16*e7e3c45bSAndrea Faulds #include "mlir/Support/LLVM.h"
17*e7e3c45bSAndrea Faulds 
18*e7e3c45bSAndrea Faulds #include <unordered_map>
19*e7e3c45bSAndrea Faulds #include <vector>
20*e7e3c45bSAndrea Faulds #include <vulkan/vulkan.h>
21*e7e3c45bSAndrea Faulds 
22*e7e3c45bSAndrea Faulds using namespace mlir;
23*e7e3c45bSAndrea Faulds 
24*e7e3c45bSAndrea Faulds using DescriptorSetIndex = uint32_t;
25*e7e3c45bSAndrea Faulds using BindingIndex = uint32_t;
26*e7e3c45bSAndrea Faulds 
27*e7e3c45bSAndrea Faulds /// Struct containing information regarding to a device memory buffer.
28*e7e3c45bSAndrea Faulds struct VulkanDeviceMemoryBuffer {
29*e7e3c45bSAndrea Faulds   BindingIndex bindingIndex{0};
30*e7e3c45bSAndrea Faulds   VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
31*e7e3c45bSAndrea Faulds   VkDescriptorBufferInfo bufferInfo{};
32*e7e3c45bSAndrea Faulds   VkBuffer hostBuffer{VK_NULL_HANDLE};
33*e7e3c45bSAndrea Faulds   VkDeviceMemory hostMemory{VK_NULL_HANDLE};
34*e7e3c45bSAndrea Faulds   VkBuffer deviceBuffer{VK_NULL_HANDLE};
35*e7e3c45bSAndrea Faulds   VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
36*e7e3c45bSAndrea Faulds   uint32_t bufferSize{0};
37*e7e3c45bSAndrea Faulds };
38*e7e3c45bSAndrea Faulds 
39*e7e3c45bSAndrea Faulds /// Struct containing information regarding to a host memory buffer.
40*e7e3c45bSAndrea Faulds struct VulkanHostMemoryBuffer {
41*e7e3c45bSAndrea Faulds   /// Pointer to a host memory.
42*e7e3c45bSAndrea Faulds   void *ptr{nullptr};
43*e7e3c45bSAndrea Faulds   /// Size of a host memory in bytes.
44*e7e3c45bSAndrea Faulds   uint32_t size{0};
45*e7e3c45bSAndrea Faulds };
46*e7e3c45bSAndrea Faulds 
47*e7e3c45bSAndrea Faulds /// Struct containing the number of local workgroups to dispatch for each
48*e7e3c45bSAndrea Faulds /// dimension.
49*e7e3c45bSAndrea Faulds struct NumWorkGroups {
50*e7e3c45bSAndrea Faulds   uint32_t x{1};
51*e7e3c45bSAndrea Faulds   uint32_t y{1};
52*e7e3c45bSAndrea Faulds   uint32_t z{1};
53*e7e3c45bSAndrea Faulds };
54*e7e3c45bSAndrea Faulds 
55*e7e3c45bSAndrea Faulds /// Struct containing information regarding a descriptor set.
56*e7e3c45bSAndrea Faulds struct DescriptorSetInfo {
57*e7e3c45bSAndrea Faulds   /// Index of a descriptor set in descriptor sets.
58*e7e3c45bSAndrea Faulds   DescriptorSetIndex descriptorSet{0};
59*e7e3c45bSAndrea Faulds   /// Number of descriptors in a set.
60*e7e3c45bSAndrea Faulds   uint32_t descriptorSize{0};
61*e7e3c45bSAndrea Faulds   /// Type of a descriptor set.
62*e7e3c45bSAndrea Faulds   VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
63*e7e3c45bSAndrea Faulds };
64*e7e3c45bSAndrea Faulds 
65*e7e3c45bSAndrea Faulds /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
66*e7e3c45bSAndrea Faulds using ResourceData = std::unordered_map<
67*e7e3c45bSAndrea Faulds     DescriptorSetIndex,
68*e7e3c45bSAndrea Faulds     std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>;
69*e7e3c45bSAndrea Faulds 
70*e7e3c45bSAndrea Faulds /// SPIR-V storage classes.
71*e7e3c45bSAndrea Faulds /// Note that this duplicates spirv::StorageClass but it keeps the Vulkan
72*e7e3c45bSAndrea Faulds /// runtime library detached from SPIR-V dialect, so we can avoid pick up lots
73*e7e3c45bSAndrea Faulds /// of dependencies.
74*e7e3c45bSAndrea Faulds enum class SPIRVStorageClass {
75*e7e3c45bSAndrea Faulds   Uniform = 2,
76*e7e3c45bSAndrea Faulds   StorageBuffer = 12,
77*e7e3c45bSAndrea Faulds };
78*e7e3c45bSAndrea Faulds 
79*e7e3c45bSAndrea Faulds /// StorageClass mapped into a descriptor set and a binding.
80*e7e3c45bSAndrea Faulds using ResourceStorageClassBindingMap =
81*e7e3c45bSAndrea Faulds     std::unordered_map<DescriptorSetIndex,
82*e7e3c45bSAndrea Faulds                        std::unordered_map<BindingIndex, SPIRVStorageClass>>;
83*e7e3c45bSAndrea Faulds 
84*e7e3c45bSAndrea Faulds /// Vulkan runtime.
85*e7e3c45bSAndrea Faulds /// The purpose of this class is to run SPIR-V compute shader on Vulkan
86*e7e3c45bSAndrea Faulds /// device.
87*e7e3c45bSAndrea Faulds /// Before the run, user must provide and set resource data with descriptors,
88*e7e3c45bSAndrea Faulds /// SPIR-V shader, number of work groups and entry point. After the creation of
89*e7e3c45bSAndrea Faulds /// VulkanRuntime, special methods must be called in the following
90*e7e3c45bSAndrea Faulds /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
91*e7e3c45bSAndrea Faulds /// each method in the sequence returns success or failure depends on the Vulkan
92*e7e3c45bSAndrea Faulds /// result code.
93*e7e3c45bSAndrea Faulds class VulkanRuntime {
94*e7e3c45bSAndrea Faulds public:
95*e7e3c45bSAndrea Faulds   explicit VulkanRuntime() = default;
96*e7e3c45bSAndrea Faulds   VulkanRuntime(const VulkanRuntime &) = delete;
97*e7e3c45bSAndrea Faulds   VulkanRuntime &operator=(const VulkanRuntime &) = delete;
98*e7e3c45bSAndrea Faulds 
99*e7e3c45bSAndrea Faulds   /// Sets needed data for Vulkan runtime.
100*e7e3c45bSAndrea Faulds   void setResourceData(const ResourceData &resData);
101*e7e3c45bSAndrea Faulds   void setResourceData(const DescriptorSetIndex desIndex,
102*e7e3c45bSAndrea Faulds                        const BindingIndex bindIndex,
103*e7e3c45bSAndrea Faulds                        const VulkanHostMemoryBuffer &hostMemBuffer);
104*e7e3c45bSAndrea Faulds   void setShaderModule(uint8_t *shader, uint32_t size);
105*e7e3c45bSAndrea Faulds   void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
106*e7e3c45bSAndrea Faulds   void setResourceStorageClassBindingMap(
107*e7e3c45bSAndrea Faulds       const ResourceStorageClassBindingMap &stClassData);
108*e7e3c45bSAndrea Faulds   void setEntryPoint(const char *entryPointName);
109*e7e3c45bSAndrea Faulds 
110*e7e3c45bSAndrea Faulds   /// Runtime initialization.
111*e7e3c45bSAndrea Faulds   LogicalResult initRuntime();
112*e7e3c45bSAndrea Faulds 
113*e7e3c45bSAndrea Faulds   /// Runs runtime.
114*e7e3c45bSAndrea Faulds   LogicalResult run();
115*e7e3c45bSAndrea Faulds 
116*e7e3c45bSAndrea Faulds   /// Updates host memory buffers.
117*e7e3c45bSAndrea Faulds   LogicalResult updateHostMemoryBuffers();
118*e7e3c45bSAndrea Faulds 
119*e7e3c45bSAndrea Faulds   /// Destroys all created vulkan objects and resources.
120*e7e3c45bSAndrea Faulds   LogicalResult destroy();
121*e7e3c45bSAndrea Faulds 
122*e7e3c45bSAndrea Faulds private:
123*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
124*e7e3c45bSAndrea Faulds   // Pipeline creation methods.
125*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
126*e7e3c45bSAndrea Faulds 
127*e7e3c45bSAndrea Faulds   LogicalResult createInstance();
128*e7e3c45bSAndrea Faulds   LogicalResult createDevice();
129*e7e3c45bSAndrea Faulds   LogicalResult getBestComputeQueue();
130*e7e3c45bSAndrea Faulds   LogicalResult createMemoryBuffers();
131*e7e3c45bSAndrea Faulds   LogicalResult createShaderModule();
132*e7e3c45bSAndrea Faulds   void initDescriptorSetLayoutBindingMap();
133*e7e3c45bSAndrea Faulds   LogicalResult createDescriptorSetLayout();
134*e7e3c45bSAndrea Faulds   LogicalResult createPipelineLayout();
135*e7e3c45bSAndrea Faulds   LogicalResult createComputePipeline();
136*e7e3c45bSAndrea Faulds   LogicalResult createDescriptorPool();
137*e7e3c45bSAndrea Faulds   LogicalResult allocateDescriptorSets();
138*e7e3c45bSAndrea Faulds   LogicalResult setWriteDescriptors();
139*e7e3c45bSAndrea Faulds   LogicalResult createCommandPool();
140*e7e3c45bSAndrea Faulds   LogicalResult createQueryPool();
141*e7e3c45bSAndrea Faulds   LogicalResult createComputeCommandBuffer();
142*e7e3c45bSAndrea Faulds   LogicalResult submitCommandBuffersToQueue();
143*e7e3c45bSAndrea Faulds   // Copy resources from host (staging buffer) to device buffer or from device
144*e7e3c45bSAndrea Faulds   // buffer to host buffer.
145*e7e3c45bSAndrea Faulds   LogicalResult copyResource(bool deviceToHost);
146*e7e3c45bSAndrea Faulds 
147*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
148*e7e3c45bSAndrea Faulds   // Helper methods.
149*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
150*e7e3c45bSAndrea Faulds 
151*e7e3c45bSAndrea Faulds   /// Maps storage class to a descriptor type.
152*e7e3c45bSAndrea Faulds   LogicalResult
153*e7e3c45bSAndrea Faulds   mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,
154*e7e3c45bSAndrea Faulds                                   VkDescriptorType &descriptorType);
155*e7e3c45bSAndrea Faulds 
156*e7e3c45bSAndrea Faulds   /// Maps storage class to buffer usage flags.
157*e7e3c45bSAndrea Faulds   LogicalResult
158*e7e3c45bSAndrea Faulds   mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,
159*e7e3c45bSAndrea Faulds                                    VkBufferUsageFlagBits &bufferUsage);
160*e7e3c45bSAndrea Faulds 
161*e7e3c45bSAndrea Faulds   LogicalResult countDeviceMemorySize();
162*e7e3c45bSAndrea Faulds 
163*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
164*e7e3c45bSAndrea Faulds   // Vulkan objects.
165*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
166*e7e3c45bSAndrea Faulds 
167*e7e3c45bSAndrea Faulds   VkInstance instance{VK_NULL_HANDLE};
168*e7e3c45bSAndrea Faulds   VkPhysicalDevice physicalDevice{VK_NULL_HANDLE};
169*e7e3c45bSAndrea Faulds   VkDevice device{VK_NULL_HANDLE};
170*e7e3c45bSAndrea Faulds   VkQueue queue{VK_NULL_HANDLE};
171*e7e3c45bSAndrea Faulds 
172*e7e3c45bSAndrea Faulds   /// Specifies VulkanDeviceMemoryBuffers divided into sets.
173*e7e3c45bSAndrea Faulds   std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>>
174*e7e3c45bSAndrea Faulds       deviceMemoryBufferMap;
175*e7e3c45bSAndrea Faulds 
176*e7e3c45bSAndrea Faulds   /// Specifies shader module.
177*e7e3c45bSAndrea Faulds   VkShaderModule shaderModule{VK_NULL_HANDLE};
178*e7e3c45bSAndrea Faulds 
179*e7e3c45bSAndrea Faulds   /// Specifies layout bindings.
180*e7e3c45bSAndrea Faulds   std::unordered_map<DescriptorSetIndex,
181*e7e3c45bSAndrea Faulds                      std::vector<VkDescriptorSetLayoutBinding>>
182*e7e3c45bSAndrea Faulds       descriptorSetLayoutBindingMap;
183*e7e3c45bSAndrea Faulds 
184*e7e3c45bSAndrea Faulds   /// Specifies layouts of descriptor sets.
185*e7e3c45bSAndrea Faulds   std::vector<VkDescriptorSetLayout> descriptorSetLayouts;
186*e7e3c45bSAndrea Faulds   VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
187*e7e3c45bSAndrea Faulds 
188*e7e3c45bSAndrea Faulds   /// Specifies descriptor sets.
189*e7e3c45bSAndrea Faulds   std::vector<VkDescriptorSet> descriptorSets;
190*e7e3c45bSAndrea Faulds 
191*e7e3c45bSAndrea Faulds   /// Specifies a pool of descriptor set info, each descriptor set must have
192*e7e3c45bSAndrea Faulds   /// information such as type, index and amount of bindings.
193*e7e3c45bSAndrea Faulds   std::vector<DescriptorSetInfo> descriptorSetInfoPool;
194*e7e3c45bSAndrea Faulds   VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
195*e7e3c45bSAndrea Faulds 
196*e7e3c45bSAndrea Faulds   /// Timestamp query.
197*e7e3c45bSAndrea Faulds   VkQueryPool queryPool{VK_NULL_HANDLE};
198*e7e3c45bSAndrea Faulds   // Number of nonoseconds for timestamp to increase 1
199*e7e3c45bSAndrea Faulds   float timestampPeriod{0.f};
200*e7e3c45bSAndrea Faulds 
201*e7e3c45bSAndrea Faulds   /// Computation pipeline.
202*e7e3c45bSAndrea Faulds   VkPipeline pipeline{VK_NULL_HANDLE};
203*e7e3c45bSAndrea Faulds   VkCommandPool commandPool{VK_NULL_HANDLE};
204*e7e3c45bSAndrea Faulds   std::vector<VkCommandBuffer> commandBuffers;
205*e7e3c45bSAndrea Faulds 
206*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
207*e7e3c45bSAndrea Faulds   // Vulkan memory context.
208*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
209*e7e3c45bSAndrea Faulds 
210*e7e3c45bSAndrea Faulds   uint32_t queueFamilyIndex{0};
211*e7e3c45bSAndrea Faulds   VkQueueFamilyProperties queueFamilyProperties{};
212*e7e3c45bSAndrea Faulds   uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
213*e7e3c45bSAndrea Faulds   uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
214*e7e3c45bSAndrea Faulds   VkDeviceSize memorySize{0};
215*e7e3c45bSAndrea Faulds 
216*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
217*e7e3c45bSAndrea Faulds   // Vulkan execution context.
218*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
219*e7e3c45bSAndrea Faulds 
220*e7e3c45bSAndrea Faulds   NumWorkGroups numWorkGroups;
221*e7e3c45bSAndrea Faulds   const char *entryPoint{nullptr};
222*e7e3c45bSAndrea Faulds   uint8_t *binary{nullptr};
223*e7e3c45bSAndrea Faulds   uint32_t binarySize{0};
224*e7e3c45bSAndrea Faulds 
225*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
226*e7e3c45bSAndrea Faulds   // Vulkan resource data and storage classes.
227*e7e3c45bSAndrea Faulds   //===--------------------------------------------------------------------===//
228*e7e3c45bSAndrea Faulds 
229*e7e3c45bSAndrea Faulds   ResourceData resourceData;
230*e7e3c45bSAndrea Faulds   ResourceStorageClassBindingMap resourceStorageClassData;
231*e7e3c45bSAndrea Faulds };
232*e7e3c45bSAndrea Faulds #endif
233