//===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C runtime wrappers around the VulkanRuntime.
//
//===----------------------------------------------------------------------===//

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>

#include "VulkanRuntime.h"

// Explicitly export entry points to the vulkan-runtime-wrapper.

#ifdef _WIN32
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // _WIN32

namespace {

class VulkanModule;

// Class to be a thing that can be returned from `mgpuModuleGetFunction`.
34e7e3c45bSAndrea Faulds struct VulkanFunction { 35e7e3c45bSAndrea Faulds VulkanModule *module; 36e7e3c45bSAndrea Faulds std::string name; 37e7e3c45bSAndrea Faulds 38e7e3c45bSAndrea Faulds VulkanFunction(VulkanModule *module, const char *name) 39e7e3c45bSAndrea Faulds : module(module), name(name) {} 40e7e3c45bSAndrea Faulds }; 41e7e3c45bSAndrea Faulds 42e7e3c45bSAndrea Faulds // Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage 43e7e3c45bSAndrea Faulds // allocation of pointers returned from `mgpuModuleGetFunction`. 44e7e3c45bSAndrea Faulds class VulkanModule { 45e7e3c45bSAndrea Faulds public: 46e7e3c45bSAndrea Faulds VulkanModule(const uint8_t *ptr, size_t sizeInBytes) 47e7e3c45bSAndrea Faulds : blob(ptr, ptr + sizeInBytes) {} 48e7e3c45bSAndrea Faulds ~VulkanModule() = default; 49e7e3c45bSAndrea Faulds 50e7e3c45bSAndrea Faulds VulkanFunction *getFunction(const char *name) { 51e7e3c45bSAndrea Faulds return functions.emplace_back(std::make_unique<VulkanFunction>(this, name)) 52e7e3c45bSAndrea Faulds .get(); 53e7e3c45bSAndrea Faulds } 54e7e3c45bSAndrea Faulds 55e7e3c45bSAndrea Faulds uint8_t *blobData() { return blob.data(); } 56e7e3c45bSAndrea Faulds size_t blobSizeInBytes() const { return blob.size(); } 57e7e3c45bSAndrea Faulds 58e7e3c45bSAndrea Faulds private: 59e7e3c45bSAndrea Faulds std::vector<uint8_t> blob; 60e7e3c45bSAndrea Faulds std::vector<std::unique_ptr<VulkanFunction>> functions; 61e7e3c45bSAndrea Faulds }; 62e7e3c45bSAndrea Faulds 63e7e3c45bSAndrea Faulds class VulkanRuntimeManager { 64e7e3c45bSAndrea Faulds public: 65e7e3c45bSAndrea Faulds VulkanRuntimeManager() = default; 66e7e3c45bSAndrea Faulds VulkanRuntimeManager(const VulkanRuntimeManager &) = delete; 67e7e3c45bSAndrea Faulds VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete; 68e7e3c45bSAndrea Faulds ~VulkanRuntimeManager() = default; 69e7e3c45bSAndrea Faulds 70e7e3c45bSAndrea Faulds void setResourceData(DescriptorSetIndex setIndex, BindingIndex 
bindIndex, 71e7e3c45bSAndrea Faulds const VulkanHostMemoryBuffer &memBuffer) { 72e7e3c45bSAndrea Faulds std::lock_guard<std::mutex> lock(mutex); 73e7e3c45bSAndrea Faulds vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer); 74e7e3c45bSAndrea Faulds } 75e7e3c45bSAndrea Faulds 76e7e3c45bSAndrea Faulds void setEntryPoint(const char *entryPoint) { 77e7e3c45bSAndrea Faulds std::lock_guard<std::mutex> lock(mutex); 78e7e3c45bSAndrea Faulds vulkanRuntime.setEntryPoint(entryPoint); 79e7e3c45bSAndrea Faulds } 80e7e3c45bSAndrea Faulds 81e7e3c45bSAndrea Faulds void setNumWorkGroups(NumWorkGroups numWorkGroups) { 82e7e3c45bSAndrea Faulds std::lock_guard<std::mutex> lock(mutex); 83e7e3c45bSAndrea Faulds vulkanRuntime.setNumWorkGroups(numWorkGroups); 84e7e3c45bSAndrea Faulds } 85e7e3c45bSAndrea Faulds 86e7e3c45bSAndrea Faulds void setShaderModule(uint8_t *shader, uint32_t size) { 87e7e3c45bSAndrea Faulds std::lock_guard<std::mutex> lock(mutex); 88e7e3c45bSAndrea Faulds vulkanRuntime.setShaderModule(shader, size); 89e7e3c45bSAndrea Faulds } 90e7e3c45bSAndrea Faulds 91e7e3c45bSAndrea Faulds void runOnVulkan() { 92e7e3c45bSAndrea Faulds std::lock_guard<std::mutex> lock(mutex); 93e7e3c45bSAndrea Faulds if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) || 94e7e3c45bSAndrea Faulds failed(vulkanRuntime.updateHostMemoryBuffers()) || 95e7e3c45bSAndrea Faulds failed(vulkanRuntime.destroy())) { 96e7e3c45bSAndrea Faulds std::cerr << "runOnVulkan failed"; 97e7e3c45bSAndrea Faulds } 98e7e3c45bSAndrea Faulds } 99e7e3c45bSAndrea Faulds 100e7e3c45bSAndrea Faulds private: 101e7e3c45bSAndrea Faulds VulkanRuntime vulkanRuntime; 102e7e3c45bSAndrea Faulds std::mutex mutex; 103e7e3c45bSAndrea Faulds }; 104e7e3c45bSAndrea Faulds 105e7e3c45bSAndrea Faulds } // namespace 106e7e3c45bSAndrea Faulds 107e7e3c45bSAndrea Faulds template <typename T, int N> 108e7e3c45bSAndrea Faulds struct MemRefDescriptor { 109e7e3c45bSAndrea Faulds T *allocated; 110e7e3c45bSAndrea Faulds T 
*aligned; 111e7e3c45bSAndrea Faulds int64_t offset; 112e7e3c45bSAndrea Faulds int64_t sizes[N]; 113e7e3c45bSAndrea Faulds int64_t strides[N]; 114e7e3c45bSAndrea Faulds }; 115e7e3c45bSAndrea Faulds 116e7e3c45bSAndrea Faulds extern "C" { 117e7e3c45bSAndrea Faulds 118e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===// 119e7e3c45bSAndrea Faulds // 120*eb206e9eSAndrea Faulds // Wrappers intended for mlir-runner. Uses of GPU dialect operations get 121e7e3c45bSAndrea Faulds // lowered to calls to these functions by GPUToLLVMConversionPass. 122e7e3c45bSAndrea Faulds // 123e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===// 124e7e3c45bSAndrea Faulds 125e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() { 126e7e3c45bSAndrea Faulds return new VulkanRuntimeManager(); 127e7e3c45bSAndrea Faulds } 128e7e3c45bSAndrea Faulds 129e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) { 130e7e3c45bSAndrea Faulds delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager); 131e7e3c45bSAndrea Faulds } 132e7e3c45bSAndrea Faulds 133e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) { 134e7e3c45bSAndrea Faulds // Currently a no-op as the other operations are synchronous. 135e7e3c45bSAndrea Faulds } 136e7e3c45bSAndrea Faulds 137e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data, 138e7e3c45bSAndrea Faulds size_t gpuBlobSize) { 139e7e3c45bSAndrea Faulds // gpuBlobSize is the size of the data in bytes. 
140e7e3c45bSAndrea Faulds return new VulkanModule(static_cast<const uint8_t *>(data), gpuBlobSize); 141e7e3c45bSAndrea Faulds } 142e7e3c45bSAndrea Faulds 143e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) { 144e7e3c45bSAndrea Faulds delete static_cast<VulkanModule *>(vkModule); 145e7e3c45bSAndrea Faulds } 146e7e3c45bSAndrea Faulds 147e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule, 148e7e3c45bSAndrea Faulds const char *name) { 149e7e3c45bSAndrea Faulds if (!vkModule) 150e7e3c45bSAndrea Faulds abort(); 151e7e3c45bSAndrea Faulds return static_cast<VulkanModule *>(vkModule)->getFunction(name); 152e7e3c45bSAndrea Faulds } 153e7e3c45bSAndrea Faulds 154e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 155e7e3c45bSAndrea Faulds mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ, 156e7e3c45bSAndrea Faulds size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/, 157e7e3c45bSAndrea Faulds size_t /*smem*/, void *vkRuntimeManager, void **params, 158e7e3c45bSAndrea Faulds void ** /*extra*/, size_t paramsCount) { 159e7e3c45bSAndrea Faulds auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager); 160e7e3c45bSAndrea Faulds 161e7e3c45bSAndrea Faulds // GpuToLLVMConversionPass with the kernelBarePtrCallConv and 162e7e3c45bSAndrea Faulds // kernelIntersperseSizeCallConv options will set up the params array like: 163e7e3c45bSAndrea Faulds // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... } 164e7e3c45bSAndrea Faulds const size_t paramsPerMemRef = 2; 165e7e3c45bSAndrea Faulds if (paramsCount % paramsPerMemRef != 0) { 166e7e3c45bSAndrea Faulds abort(); // This would indicate a serious calling convention mismatch. 
167e7e3c45bSAndrea Faulds } 168e7e3c45bSAndrea Faulds const DescriptorSetIndex setIndex = 0; 169e7e3c45bSAndrea Faulds BindingIndex bindIndex = 0; 170e7e3c45bSAndrea Faulds for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) { 171e7e3c45bSAndrea Faulds void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]); 172e7e3c45bSAndrea Faulds size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]); 173e7e3c45bSAndrea Faulds VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr, 174e7e3c45bSAndrea Faulds static_cast<uint32_t>(memrefBufferSize)}; 175e7e3c45bSAndrea Faulds manager->setResourceData(setIndex, bindIndex, memBuffer); 176e7e3c45bSAndrea Faulds ++bindIndex; 177e7e3c45bSAndrea Faulds } 178e7e3c45bSAndrea Faulds 179e7e3c45bSAndrea Faulds manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX), 180e7e3c45bSAndrea Faulds static_cast<uint32_t>(gridY), 181e7e3c45bSAndrea Faulds static_cast<uint32_t>(gridZ)}); 182e7e3c45bSAndrea Faulds 183e7e3c45bSAndrea Faulds auto function = static_cast<VulkanFunction *>(vkKernel); 184e7e3c45bSAndrea Faulds // Expected size should be in bytes. 185e7e3c45bSAndrea Faulds manager->setShaderModule( 186e7e3c45bSAndrea Faulds function->module->blobData(), 187e7e3c45bSAndrea Faulds static_cast<uint32_t>(function->module->blobSizeInBytes())); 188e7e3c45bSAndrea Faulds manager->setEntryPoint(function->name.c_str()); 189e7e3c45bSAndrea Faulds 190e7e3c45bSAndrea Faulds manager->runOnVulkan(); 191e7e3c45bSAndrea Faulds } 192e7e3c45bSAndrea Faulds 193e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===// 194e7e3c45bSAndrea Faulds // 195e7e3c45bSAndrea Faulds // Miscellaneous utility functions that can be directly used by tests. 
196e7e3c45bSAndrea Faulds // 197e7e3c45bSAndrea Faulds //===----------------------------------------------------------------------===// 198e7e3c45bSAndrea Faulds 199e7e3c45bSAndrea Faulds /// Fills the given 1D float memref with the given float value. 200e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 201e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT 202e7e3c45bSAndrea Faulds float value) { 203e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0], value); 204e7e3c45bSAndrea Faulds } 205e7e3c45bSAndrea Faulds 206e7e3c45bSAndrea Faulds /// Fills the given 2D float memref with the given float value. 207e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 208e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT 209e7e3c45bSAndrea Faulds float value) { 210e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 211e7e3c45bSAndrea Faulds } 212e7e3c45bSAndrea Faulds 213e7e3c45bSAndrea Faulds /// Fills the given 3D float memref with the given float value. 214e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 215e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT 216e7e3c45bSAndrea Faulds float value) { 217e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 218e7e3c45bSAndrea Faulds value); 219e7e3c45bSAndrea Faulds } 220e7e3c45bSAndrea Faulds 221e7e3c45bSAndrea Faulds /// Fills the given 1D int memref with the given int value. 
222e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 223e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT 224e7e3c45bSAndrea Faulds int32_t value) { 225e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0], value); 226e7e3c45bSAndrea Faulds } 227e7e3c45bSAndrea Faulds 228e7e3c45bSAndrea Faulds /// Fills the given 2D int memref with the given int value. 229e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 230e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT 231e7e3c45bSAndrea Faulds int32_t value) { 232e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 233e7e3c45bSAndrea Faulds } 234e7e3c45bSAndrea Faulds 235e7e3c45bSAndrea Faulds /// Fills the given 3D int memref with the given int value. 236e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 237e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT 238e7e3c45bSAndrea Faulds int32_t value) { 239e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 240e7e3c45bSAndrea Faulds value); 241e7e3c45bSAndrea Faulds } 242e7e3c45bSAndrea Faulds 243e7e3c45bSAndrea Faulds /// Fills the given 1D int memref with the given int8 value. 244e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 245e7e3c45bSAndrea Faulds _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT 246e7e3c45bSAndrea Faulds int8_t value) { 247e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0], value); 248e7e3c45bSAndrea Faulds } 249e7e3c45bSAndrea Faulds 250e7e3c45bSAndrea Faulds /// Fills the given 2D int memref with the given int8 value. 
251e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 252e7e3c45bSAndrea Faulds _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT 253e7e3c45bSAndrea Faulds int8_t value) { 254e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); 255e7e3c45bSAndrea Faulds } 256e7e3c45bSAndrea Faulds 257e7e3c45bSAndrea Faulds /// Fills the given 3D int memref with the given int8 value. 258e7e3c45bSAndrea Faulds VULKAN_WRAPPER_SYMBOL_EXPORT void 259e7e3c45bSAndrea Faulds _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT 260e7e3c45bSAndrea Faulds int8_t value) { 261e7e3c45bSAndrea Faulds std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2], 262e7e3c45bSAndrea Faulds value); 263e7e3c45bSAndrea Faulds } 264e7e3c45bSAndrea Faulds } 265