xref: /llvm-project/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
1 //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements C runtime wrappers around the VulkanRuntime.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>
18 
19 #include "VulkanRuntime.h"
20 
// Explicitly export entry points to the vulkan-runtime-wrapper.
//
// VULKAN_WRAPPER_SYMBOL_EXPORT is placed in front of every `extern "C"`
// function defined in this file; it expands to the compiler-specific
// annotation selected below.

#ifdef _WIN32
// Windows toolchains: mark the function as exported from the DLL.
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
// Other toolchains: request default symbol visibility for the function.
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
28 
29 namespace {
30 
class VulkanModule;

// Handle type behind the pointer returned from `mgpuModuleGetFunction`: it
// pairs a kernel entry-point name with the module the name was requested
// from.
struct VulkanFunction {
  // Module this handle belongs to. Held as a plain pointer; nothing in this
  // struct deletes it.
  VulkanModule *module;
  // Private copy of the entry-point name supplied at construction.
  std::string name;

  // Records the owning module and copies the NUL-terminated `entryPoint`.
  VulkanFunction(VulkanModule *owner, const char *entryPoint)
      : module{owner}, name{entryPoint} {}
};
41 
42 // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage
43 // allocation of pointers returned from `mgpuModuleGetFunction`.
44 class VulkanModule {
45 public:
46   VulkanModule(const uint8_t *ptr, size_t sizeInBytes)
47       : blob(ptr, ptr + sizeInBytes) {}
48   ~VulkanModule() = default;
49 
50   VulkanFunction *getFunction(const char *name) {
51     return functions.emplace_back(std::make_unique<VulkanFunction>(this, name))
52         .get();
53   }
54 
55   uint8_t *blobData() { return blob.data(); }
56   size_t blobSizeInBytes() const { return blob.size(); }
57 
58 private:
59   std::vector<uint8_t> blob;
60   std::vector<std::unique_ptr<VulkanFunction>> functions;
61 };
62 
63 class VulkanRuntimeManager {
64 public:
65   VulkanRuntimeManager() = default;
66   VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
67   VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
68   ~VulkanRuntimeManager() = default;
69 
70   void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
71                        const VulkanHostMemoryBuffer &memBuffer) {
72     std::lock_guard<std::mutex> lock(mutex);
73     vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
74   }
75 
76   void setEntryPoint(const char *entryPoint) {
77     std::lock_guard<std::mutex> lock(mutex);
78     vulkanRuntime.setEntryPoint(entryPoint);
79   }
80 
81   void setNumWorkGroups(NumWorkGroups numWorkGroups) {
82     std::lock_guard<std::mutex> lock(mutex);
83     vulkanRuntime.setNumWorkGroups(numWorkGroups);
84   }
85 
86   void setShaderModule(uint8_t *shader, uint32_t size) {
87     std::lock_guard<std::mutex> lock(mutex);
88     vulkanRuntime.setShaderModule(shader, size);
89   }
90 
91   void runOnVulkan() {
92     std::lock_guard<std::mutex> lock(mutex);
93     if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
94         failed(vulkanRuntime.updateHostMemoryBuffers()) ||
95         failed(vulkanRuntime.destroy())) {
96       std::cerr << "runOnVulkan failed";
97     }
98   }
99 
100 private:
101   VulkanRuntime vulkanRuntime;
102   std::mutex mutex;
103 };
104 
105 } // namespace
106 
/// Plain aggregate describing a memref of `Rank` dimensions whose elements
/// have type `ElementT`. The `_mlir_ciface_fillResource*` helpers in this file
/// receive a pointer to one of these.
///
/// NOTE(review): the field order presumably has to match the descriptor that
/// the MLIR lowering passes to `_mlir_ciface_*` functions, so it must not be
/// rearranged - confirm against the lowering documentation.
template <typename ElementT, int Rank>
struct MemRefDescriptor {
  /// Pointer named "allocated" in the descriptor.
  ElementT *allocated;
  /// Pointer named "aligned" in the descriptor.
  ElementT *aligned;
  /// Offset field; presumably counted in elements - confirm.
  int64_t offset;
  /// One size entry per dimension.
  int64_t sizes[Rank];
  /// One stride entry per dimension.
  int64_t strides[Rank];
};
115 
116 extern "C" {
117 
118 //===----------------------------------------------------------------------===//
119 //
120 // Wrappers intended for mlir-runner. Uses of GPU dialect operations get
121 // lowered to calls to these functions by GPUToLLVMConversionPass.
122 //
123 //===----------------------------------------------------------------------===//
124 
125 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
126   return new VulkanRuntimeManager();
127 }
128 
129 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
130   delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
131 }
132 
133 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) {
134   // Currently a no-op as the other operations are synchronous.
135 }
136 
137 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
138                                                   size_t gpuBlobSize) {
139   // gpuBlobSize is the size of the data in bytes.
140   return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize);
141 }
142 
143 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
144   delete static_cast<VulkanModule *>(vkModule);
145 }
146 
147 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
148                                                          const char *name) {
149   if (!vkModule)
150     abort();
151   return static_cast<VulkanModule *>(vkModule)->getFunction(name);
152 }
153 
154 VULKAN_WRAPPER_SYMBOL_EXPORT void
155 mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
156                  size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
157                  size_t /*smem*/, void *vkRuntimeManager, void **params,
158                  void ** /*extra*/, size_t paramsCount) {
159   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
160 
161   // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
162   // kernelIntersperseSizeCallConv options will set up the params array like:
163   // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
164   const size_t paramsPerMemRef = 2;
165   if (paramsCount % paramsPerMemRef != 0) {
166     abort(); // This would indicate a serious calling convention mismatch.
167   }
168   const DescriptorSetIndex setIndex = 0;
169   BindingIndex bindIndex = 0;
170   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
171     void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
172     size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
173     VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
174                                      static_cast<uint32_t>(memrefBufferSize)};
175     manager->setResourceData(setIndex, bindIndex, memBuffer);
176     ++bindIndex;
177   }
178 
179   manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX),
180                                           static_cast<uint32_t>(gridY),
181                                           static_cast<uint32_t>(gridZ)});
182 
183   auto function = static_cast<VulkanFunction *>(vkKernel);
184   // Expected size should be in bytes.
185   manager->setShaderModule(
186       function->module->blobData(),
187       static_cast<uint32_t>(function->module->blobSizeInBytes()));
188   manager->setEntryPoint(function->name.c_str());
189 
190   manager->runOnVulkan();
191 }
192 
193 //===----------------------------------------------------------------------===//
194 //
195 // Miscellaneous utility functions that can be directly used by tests.
196 //
197 //===----------------------------------------------------------------------===//
198 
199 /// Fills the given 1D float memref with the given float value.
200 VULKAN_WRAPPER_SYMBOL_EXPORT void
201 _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
202                                  float value) {
203   std::fill_n(ptr->allocated, ptr->sizes[0], value);
204 }
205 
206 /// Fills the given 2D float memref with the given float value.
207 VULKAN_WRAPPER_SYMBOL_EXPORT void
208 _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
209                                  float value) {
210   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
211 }
212 
213 /// Fills the given 3D float memref with the given float value.
214 VULKAN_WRAPPER_SYMBOL_EXPORT void
215 _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
216                                  float value) {
217   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
218               value);
219 }
220 
221 /// Fills the given 1D int memref with the given int value.
222 VULKAN_WRAPPER_SYMBOL_EXPORT void
223 _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
224                                int32_t value) {
225   std::fill_n(ptr->allocated, ptr->sizes[0], value);
226 }
227 
228 /// Fills the given 2D int memref with the given int value.
229 VULKAN_WRAPPER_SYMBOL_EXPORT void
230 _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
231                                int32_t value) {
232   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
233 }
234 
235 /// Fills the given 3D int memref with the given int value.
236 VULKAN_WRAPPER_SYMBOL_EXPORT void
237 _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
238                                int32_t value) {
239   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
240               value);
241 }
242 
243 /// Fills the given 1D int memref with the given int8 value.
244 VULKAN_WRAPPER_SYMBOL_EXPORT void
245 _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
246                                 int8_t value) {
247   std::fill_n(ptr->allocated, ptr->sizes[0], value);
248 }
249 
250 /// Fills the given 2D int memref with the given int8 value.
251 VULKAN_WRAPPER_SYMBOL_EXPORT void
252 _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
253                                 int8_t value) {
254   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
255 }
256 
257 /// Fills the given 3D int memref with the given int8 value.
258 VULKAN_WRAPPER_SYMBOL_EXPORT void
259 _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
260                                 int8_t value) {
261   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
262               value);
263 }
264 }
265