xref: /llvm-project/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
1e7e3c45bSAndrea Faulds //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
2e7e3c45bSAndrea Faulds //
3e7e3c45bSAndrea Faulds // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4e7e3c45bSAndrea Faulds // See https://llvm.org/LICENSE.txt for license information.
5e7e3c45bSAndrea Faulds // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6e7e3c45bSAndrea Faulds //
7e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
8e7e3c45bSAndrea Faulds //
9e7e3c45bSAndrea Faulds // Implements C runtime wrappers around the VulkanRuntime.
10e7e3c45bSAndrea Faulds //
11e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
12e7e3c45bSAndrea Faulds 
13e7e3c45bSAndrea Faulds #include <iostream>
14e7e3c45bSAndrea Faulds #include <mutex>
15e7e3c45bSAndrea Faulds #include <numeric>
16e7e3c45bSAndrea Faulds #include <string>
17e7e3c45bSAndrea Faulds #include <vector>
18e7e3c45bSAndrea Faulds 
19e7e3c45bSAndrea Faulds #include "VulkanRuntime.h"
20e7e3c45bSAndrea Faulds 
21e7e3c45bSAndrea Faulds // Explicitly export entry points to the vulkan-runtime-wrapper.
22e7e3c45bSAndrea Faulds 
23e7e3c45bSAndrea Faulds #ifdef _WIN32
24e7e3c45bSAndrea Faulds #define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
25e7e3c45bSAndrea Faulds #else
26e7e3c45bSAndrea Faulds #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
27e7e3c45bSAndrea Faulds #endif // _WIN32
28e7e3c45bSAndrea Faulds 
29e7e3c45bSAndrea Faulds namespace {
30e7e3c45bSAndrea Faulds 
31e7e3c45bSAndrea Faulds class VulkanModule;
32e7e3c45bSAndrea Faulds 
33e7e3c45bSAndrea Faulds // Class to be a thing that can be returned from `mgpuModuleGetFunction`.
34e7e3c45bSAndrea Faulds struct VulkanFunction {
35e7e3c45bSAndrea Faulds   VulkanModule *module;
36e7e3c45bSAndrea Faulds   std::string name;
37e7e3c45bSAndrea Faulds 
38e7e3c45bSAndrea Faulds   VulkanFunction(VulkanModule *module, const char *name)
39e7e3c45bSAndrea Faulds       : module(module), name(name) {}
40e7e3c45bSAndrea Faulds };
41e7e3c45bSAndrea Faulds 
42e7e3c45bSAndrea Faulds // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage
43e7e3c45bSAndrea Faulds // allocation of pointers returned from `mgpuModuleGetFunction`.
44e7e3c45bSAndrea Faulds class VulkanModule {
45e7e3c45bSAndrea Faulds public:
46e7e3c45bSAndrea Faulds   VulkanModule(const uint8_t *ptr, size_t sizeInBytes)
47e7e3c45bSAndrea Faulds       : blob(ptr, ptr + sizeInBytes) {}
48e7e3c45bSAndrea Faulds   ~VulkanModule() = default;
49e7e3c45bSAndrea Faulds 
50e7e3c45bSAndrea Faulds   VulkanFunction *getFunction(const char *name) {
51e7e3c45bSAndrea Faulds     return functions.emplace_back(std::make_unique<VulkanFunction>(this, name))
52e7e3c45bSAndrea Faulds         .get();
53e7e3c45bSAndrea Faulds   }
54e7e3c45bSAndrea Faulds 
55e7e3c45bSAndrea Faulds   uint8_t *blobData() { return blob.data(); }
56e7e3c45bSAndrea Faulds   size_t blobSizeInBytes() const { return blob.size(); }
57e7e3c45bSAndrea Faulds 
58e7e3c45bSAndrea Faulds private:
59e7e3c45bSAndrea Faulds   std::vector<uint8_t> blob;
60e7e3c45bSAndrea Faulds   std::vector<std::unique_ptr<VulkanFunction>> functions;
61e7e3c45bSAndrea Faulds };
62e7e3c45bSAndrea Faulds 
63e7e3c45bSAndrea Faulds class VulkanRuntimeManager {
64e7e3c45bSAndrea Faulds public:
65e7e3c45bSAndrea Faulds   VulkanRuntimeManager() = default;
66e7e3c45bSAndrea Faulds   VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
67e7e3c45bSAndrea Faulds   VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
68e7e3c45bSAndrea Faulds   ~VulkanRuntimeManager() = default;
69e7e3c45bSAndrea Faulds 
70e7e3c45bSAndrea Faulds   void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
71e7e3c45bSAndrea Faulds                        const VulkanHostMemoryBuffer &memBuffer) {
72e7e3c45bSAndrea Faulds     std::lock_guard<std::mutex> lock(mutex);
73e7e3c45bSAndrea Faulds     vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
74e7e3c45bSAndrea Faulds   }
75e7e3c45bSAndrea Faulds 
76e7e3c45bSAndrea Faulds   void setEntryPoint(const char *entryPoint) {
77e7e3c45bSAndrea Faulds     std::lock_guard<std::mutex> lock(mutex);
78e7e3c45bSAndrea Faulds     vulkanRuntime.setEntryPoint(entryPoint);
79e7e3c45bSAndrea Faulds   }
80e7e3c45bSAndrea Faulds 
81e7e3c45bSAndrea Faulds   void setNumWorkGroups(NumWorkGroups numWorkGroups) {
82e7e3c45bSAndrea Faulds     std::lock_guard<std::mutex> lock(mutex);
83e7e3c45bSAndrea Faulds     vulkanRuntime.setNumWorkGroups(numWorkGroups);
84e7e3c45bSAndrea Faulds   }
85e7e3c45bSAndrea Faulds 
86e7e3c45bSAndrea Faulds   void setShaderModule(uint8_t *shader, uint32_t size) {
87e7e3c45bSAndrea Faulds     std::lock_guard<std::mutex> lock(mutex);
88e7e3c45bSAndrea Faulds     vulkanRuntime.setShaderModule(shader, size);
89e7e3c45bSAndrea Faulds   }
90e7e3c45bSAndrea Faulds 
91e7e3c45bSAndrea Faulds   void runOnVulkan() {
92e7e3c45bSAndrea Faulds     std::lock_guard<std::mutex> lock(mutex);
93e7e3c45bSAndrea Faulds     if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
94e7e3c45bSAndrea Faulds         failed(vulkanRuntime.updateHostMemoryBuffers()) ||
95e7e3c45bSAndrea Faulds         failed(vulkanRuntime.destroy())) {
96e7e3c45bSAndrea Faulds       std::cerr << "runOnVulkan failed";
97e7e3c45bSAndrea Faulds     }
98e7e3c45bSAndrea Faulds   }
99e7e3c45bSAndrea Faulds 
100e7e3c45bSAndrea Faulds private:
101e7e3c45bSAndrea Faulds   VulkanRuntime vulkanRuntime;
102e7e3c45bSAndrea Faulds   std::mutex mutex;
103e7e3c45bSAndrea Faulds };
104e7e3c45bSAndrea Faulds 
105e7e3c45bSAndrea Faulds } // namespace
106e7e3c45bSAndrea Faulds 
107e7e3c45bSAndrea Faulds template <typename T, int N>
108e7e3c45bSAndrea Faulds struct MemRefDescriptor {
109e7e3c45bSAndrea Faulds   T *allocated;
110e7e3c45bSAndrea Faulds   T *aligned;
111e7e3c45bSAndrea Faulds   int64_t offset;
112e7e3c45bSAndrea Faulds   int64_t sizes[N];
113e7e3c45bSAndrea Faulds   int64_t strides[N];
114e7e3c45bSAndrea Faulds };
115e7e3c45bSAndrea Faulds 
116e7e3c45bSAndrea Faulds extern "C" {
117e7e3c45bSAndrea Faulds 
118e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
119e7e3c45bSAndrea Faulds //
120*eb206e9eSAndrea Faulds // Wrappers intended for mlir-runner. Uses of GPU dialect operations get
121e7e3c45bSAndrea Faulds // lowered to calls to these functions by GPUToLLVMConversionPass.
122e7e3c45bSAndrea Faulds //
123e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
124e7e3c45bSAndrea Faulds 
125e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
126e7e3c45bSAndrea Faulds   return new VulkanRuntimeManager();
127e7e3c45bSAndrea Faulds }
128e7e3c45bSAndrea Faulds 
129e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
130e7e3c45bSAndrea Faulds   delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
131e7e3c45bSAndrea Faulds }
132e7e3c45bSAndrea Faulds 
133e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) {
134e7e3c45bSAndrea Faulds   // Currently a no-op as the other operations are synchronous.
135e7e3c45bSAndrea Faulds }
136e7e3c45bSAndrea Faulds 
137e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
138e7e3c45bSAndrea Faulds                                                   size_t gpuBlobSize) {
139e7e3c45bSAndrea Faulds   // gpuBlobSize is the size of the data in bytes.
140e7e3c45bSAndrea Faulds   return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize);
141e7e3c45bSAndrea Faulds }
142e7e3c45bSAndrea Faulds 
143e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
144e7e3c45bSAndrea Faulds   delete static_cast<VulkanModule *>(vkModule);
145e7e3c45bSAndrea Faulds }
146e7e3c45bSAndrea Faulds 
147e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
148e7e3c45bSAndrea Faulds                                                          const char *name) {
149e7e3c45bSAndrea Faulds   if (!vkModule)
150e7e3c45bSAndrea Faulds     abort();
151e7e3c45bSAndrea Faulds   return static_cast<VulkanModule *>(vkModule)->getFunction(name);
152e7e3c45bSAndrea Faulds }
153e7e3c45bSAndrea Faulds 
154e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
155e7e3c45bSAndrea Faulds mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
156e7e3c45bSAndrea Faulds                  size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
157e7e3c45bSAndrea Faulds                  size_t /*smem*/, void *vkRuntimeManager, void **params,
158e7e3c45bSAndrea Faulds                  void ** /*extra*/, size_t paramsCount) {
159e7e3c45bSAndrea Faulds   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
160e7e3c45bSAndrea Faulds 
161e7e3c45bSAndrea Faulds   // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
162e7e3c45bSAndrea Faulds   // kernelIntersperseSizeCallConv options will set up the params array like:
163e7e3c45bSAndrea Faulds   // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
164e7e3c45bSAndrea Faulds   const size_t paramsPerMemRef = 2;
165e7e3c45bSAndrea Faulds   if (paramsCount % paramsPerMemRef != 0) {
166e7e3c45bSAndrea Faulds     abort(); // This would indicate a serious calling convention mismatch.
167e7e3c45bSAndrea Faulds   }
168e7e3c45bSAndrea Faulds   const DescriptorSetIndex setIndex = 0;
169e7e3c45bSAndrea Faulds   BindingIndex bindIndex = 0;
170e7e3c45bSAndrea Faulds   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
171e7e3c45bSAndrea Faulds     void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
172e7e3c45bSAndrea Faulds     size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
173e7e3c45bSAndrea Faulds     VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
174e7e3c45bSAndrea Faulds                                      static_cast<uint32_t>(memrefBufferSize)};
175e7e3c45bSAndrea Faulds     manager->setResourceData(setIndex, bindIndex, memBuffer);
176e7e3c45bSAndrea Faulds     ++bindIndex;
177e7e3c45bSAndrea Faulds   }
178e7e3c45bSAndrea Faulds 
179e7e3c45bSAndrea Faulds   manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX),
180e7e3c45bSAndrea Faulds                                           static_cast<uint32_t>(gridY),
181e7e3c45bSAndrea Faulds                                           static_cast<uint32_t>(gridZ)});
182e7e3c45bSAndrea Faulds 
183e7e3c45bSAndrea Faulds   auto function = static_cast<VulkanFunction *>(vkKernel);
184e7e3c45bSAndrea Faulds   // Expected size should be in bytes.
185e7e3c45bSAndrea Faulds   manager->setShaderModule(
186e7e3c45bSAndrea Faulds       function->module->blobData(),
187e7e3c45bSAndrea Faulds       static_cast<uint32_t>(function->module->blobSizeInBytes()));
188e7e3c45bSAndrea Faulds   manager->setEntryPoint(function->name.c_str());
189e7e3c45bSAndrea Faulds 
190e7e3c45bSAndrea Faulds   manager->runOnVulkan();
191e7e3c45bSAndrea Faulds }
192e7e3c45bSAndrea Faulds 
193e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
194e7e3c45bSAndrea Faulds //
195e7e3c45bSAndrea Faulds // Miscellaneous utility functions that can be directly used by tests.
196e7e3c45bSAndrea Faulds //
197e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===//
198e7e3c45bSAndrea Faulds 
199e7e3c45bSAndrea Faulds /// Fills the given 1D float memref with the given float value.
200e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
201e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
202e7e3c45bSAndrea Faulds                                  float value) {
203e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0], value);
204e7e3c45bSAndrea Faulds }
205e7e3c45bSAndrea Faulds 
206e7e3c45bSAndrea Faulds /// Fills the given 2D float memref with the given float value.
207e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
208e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
209e7e3c45bSAndrea Faulds                                  float value) {
210e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
211e7e3c45bSAndrea Faulds }
212e7e3c45bSAndrea Faulds 
213e7e3c45bSAndrea Faulds /// Fills the given 3D float memref with the given float value.
214e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
215e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
216e7e3c45bSAndrea Faulds                                  float value) {
217e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
218e7e3c45bSAndrea Faulds               value);
219e7e3c45bSAndrea Faulds }
220e7e3c45bSAndrea Faulds 
221e7e3c45bSAndrea Faulds /// Fills the given 1D int memref with the given int value.
222e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
223e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
224e7e3c45bSAndrea Faulds                                int32_t value) {
225e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0], value);
226e7e3c45bSAndrea Faulds }
227e7e3c45bSAndrea Faulds 
228e7e3c45bSAndrea Faulds /// Fills the given 2D int memref with the given int value.
229e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
230e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
231e7e3c45bSAndrea Faulds                                int32_t value) {
232e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
233e7e3c45bSAndrea Faulds }
234e7e3c45bSAndrea Faulds 
235e7e3c45bSAndrea Faulds /// Fills the given 3D int memref with the given int value.
236e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
237e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
238e7e3c45bSAndrea Faulds                                int32_t value) {
239e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
240e7e3c45bSAndrea Faulds               value);
241e7e3c45bSAndrea Faulds }
242e7e3c45bSAndrea Faulds 
243e7e3c45bSAndrea Faulds /// Fills the given 1D int memref with the given int8 value.
244e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
245e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
246e7e3c45bSAndrea Faulds                                 int8_t value) {
247e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0], value);
248e7e3c45bSAndrea Faulds }
249e7e3c45bSAndrea Faulds 
250e7e3c45bSAndrea Faulds /// Fills the given 2D int memref with the given int8 value.
251e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
252e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
253e7e3c45bSAndrea Faulds                                 int8_t value) {
254e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
255e7e3c45bSAndrea Faulds }
256e7e3c45bSAndrea Faulds 
257e7e3c45bSAndrea Faulds /// Fills the given 3D int memref with the given int8 value.
258e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void
259e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
260e7e3c45bSAndrea Faulds                                 int8_t value) {
261e7e3c45bSAndrea Faulds   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
262e7e3c45bSAndrea Faulds               value);
263e7e3c45bSAndrea Faulds }
264e7e3c45bSAndrea Faulds }
265