1 //===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements C runtime wrappers around the VulkanRuntime. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <iostream> 14 #include <mutex> 15 #include <numeric> 16 #include <string> 17 #include <vector> 18 19 #include "VulkanRuntime.h" 20 21 // Explicitly export entry points to the vulkan-runtime-wrapper. 22 23 #ifdef _WIN32 24 #define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport) 25 #else 26 #define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default"))) 27 #endif // _WIN32 28 29 namespace { 30 31 class VulkanModule; 32 33 // Class to be a thing that can be returned from `mgpuModuleGetFunction`. 34 struct VulkanFunction { 35 VulkanModule *module; 36 std::string name; 37 38 VulkanFunction(VulkanModule *module, const char *name) 39 : module(module), name(name) {} 40 }; 41 42 // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage 43 // allocation of pointers returned from `mgpuModuleGetFunction`. 44 class VulkanModule { 45 public: 46 VulkanModule(const uint8_t *ptr, size_t sizeInBytes) 47 : blob(ptr, ptr + sizeInBytes) {} 48 ~VulkanModule() = default; 49 50 VulkanFunction *getFunction(const char *name) { 51 return functions.emplace_back(std::make_unique<VulkanFunction>(this, name)) 52 .get(); 53 } 54 55 uint8_t *blobData() { return blob.data(); } 56 size_t blobSizeInBytes() const { return blob.size(); } 57 58 private: 59 std::vector<uint8_t> blob; 60 std::vector<std::unique_ptr<VulkanFunction>> functions; 61 }; 62 63 class VulkanRuntimeManager { 64 public: 65 VulkanRuntimeManager() = default; 66 VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; 67 VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; 68 ~VulkanRuntimeManager() = default; 69 70 void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex, 71 const VulkanHostMemoryBuffer &memBuffer) { 72 std::lock_guard<std::mutex> lock(mutex); 73 vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer); 74 } 75 76 void setEntryPoint(const char *entryPoint) { 77 std::lock_guard<std::mutex> lock(mutex); 78 vulkanRuntime.setEntryPoint(entryPoint); 79 } 80 81 void setNumWorkGroups(NumWorkGroups numWorkGroups) { 82 std::lock_guard<std::mutex> lock(mutex); 83 vulkanRuntime.setNumWorkGroups(numWorkGroups); 84 } 85 86 void setShaderModule(uint8_t *shader, uint32_t size) { 87 std::lock_guard<std::mutex> lock(mutex); 88 vulkanRuntime.setShaderModule(shader, size); 89 } 90 91 void runOnVulkan() { 92 std::lock_guard<std::mutex> lock(mutex); 93 if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) || 94 failed(vulkanRuntime.updateHostMemoryBuffers()) || 95 failed(vulkanRuntime.destroy())) { 96 std::cerr << "runOnVulkan failed"; 97 } 98 } 99 100 private: 101 VulkanRuntime vulkanRuntime; 102 std::mutex mutex; 103 }; 104 105 } // namespace 106 107 template <typename T, int N> 108 struct MemRefDescriptor { 109 T *allocated; 110 T *aligned; 111 int64_t offset; 112 int64_t sizes[N]; 113 int64_t strides[N]; 114 }; 115 116 extern "C" { 117 118 //===----------------------------------------------------------------------===// 119 // 120 // Wrappers intended for mlir-runner. Uses of GPU dialect operations get 121 // lowered to calls to these functions by GPUToLLVMConversionPass. 122 // 123 //===----------------------------------------------------------------------===// 124 125 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() { 126 return new VulkanRuntimeManager(); 127 } 128 129 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) { 130 delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager); 131 } 132 133 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) { 134 // Currently a no-op as the other operations are synchronous. 135 } 136 137 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data, 138 size_t gpuBlobSize) { 139 // gpuBlobSize is the size of the data in bytes. 140 return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize); 141 } 142 143 VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) { 144 delete static_cast<VulkanModule *>(vkModule); 145 } 146 147 VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule, 148 const char *name) { 149 if (!vkModule) 150 abort(); 151 return static_cast<VulkanModule *>(vkModule)->getFunction(name); 152 } 153 154 VULKAN_WRAPPER_SYMBOL_EXPORT void 155 mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, 156 size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, 157 size_t /*smem*/, void *vkRuntimeManager, void **params, 158 void ** /*extra*/, size_t paramsCount) { 159 auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); 160 161 // GpuToLLVMConversionPass with the kernelBarePtrCallConv and 162 // kernelIntersperseSizeCallConv options will set up the params array like: 163 // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... } 164 const size_t paramsPerMemRef = 2; 165 if (paramsCount % paramsPerMemRef != 0) { 166 abort(); // This would indicate a serious calling convention mismatch. 167 } 168 const DescriptorSetIndex setIndex = 0; 169 BindingIndex bindIndex = 0; 170 for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) { 171 void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]); 172 size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]); 173 VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr, 174 static_cast<uint32_t>(memrefBufferSize)}; 175 manager->setResourceData(setIndex, bindIndex, memBuffer); 176 ++bindIndex; 177 } 178 179 manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX), 180 static_cast<uint32_t>(gridY), 181 static_cast<uint32_t>(gridZ)}); 182 183 auto function = static_cast<VulkanFunction *>(vkKernel); 184 // Expected size should be in bytes. 185 manager->setShaderModule( 186 function->module->blobData(), 187 static_cast<uint32_t>(function->module->blobSizeInBytes())); 188 manager->setEntryPoint(function->name.c_str()); 189 190 manager->runOnVulkan(); 191 } 192 193 //===----------------------------------------------------------------------===// 194 // 195 // Miscellaneous utility functions that can be directly used by tests. 196 // 197 //===----------------------------------------------------------------------===// 198 199 /// Fills the given 1D float memref with the given float value. 200 VULKAN_WRAPPER_SYMBOL_EXPORT void 201 _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT 202 float value) { 203 std::fill_n(ptr->allocated, ptr->sizes[0], value); 204 } 205 206 /// Fills the given 2D float memref with the given float value. 207 VULKAN_WRAPPER_SYMBOL_EXPORT void 208 _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT 209 float value) { 210 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 211 } 212 213 /// Fills the given 3D float memref with the given float value. 214 VULKAN_WRAPPER_SYMBOL_EXPORT void 215 _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT 216 float value) { 217 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 218 value); 219 } 220 221 /// Fills the given 1D int memref with the given int value. 222 VULKAN_WRAPPER_SYMBOL_EXPORT void 223 _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT 224 int32_t value) { 225 std::fill_n(ptr->allocated, ptr->sizes[0], value); 226 } 227 228 /// Fills the given 2D int memref with the given int value. 229 VULKAN_WRAPPER_SYMBOL_EXPORT void 230 _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT 231 int32_t value) { 232 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 233 } 234 235 /// Fills the given 3D int memref with the given int value. 236 VULKAN_WRAPPER_SYMBOL_EXPORT void 237 _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT 238 int32_t value) { 239 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 240 value); 241 } 242 243 /// Fills the given 1D int memref with the given int8 value. 244 VULKAN_WRAPPER_SYMBOL_EXPORT void 245 _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT 246 int8_t value) { 247 std::fill_n(ptr->allocated, ptr->sizes[0], value); 248 } 249 250 /// Fills the given 2D int memref with the given int8 value. 251 VULKAN_WRAPPER_SYMBOL_EXPORT void 252 _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT 253 int8_t value) { 254 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 255 } 256 257 /// Fills the given 3D int memref with the given int8 value. 258 VULKAN_WRAPPER_SYMBOL_EXPORT void 259 _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT 260 int8_t value) { 261 std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 262 value); 263 } 264 } 265