1a825fb2cSChristian Sigg //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===// 2a825fb2cSChristian Sigg // 3a825fb2cSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a825fb2cSChristian Sigg // See https://llvm.org/LICENSE.txt for license information. 5a825fb2cSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a825fb2cSChristian Sigg // 7a825fb2cSChristian Sigg //===----------------------------------------------------------------------===// 8a825fb2cSChristian Sigg // 9a825fb2cSChristian Sigg // Implements C wrappers around the ROCM library for easy linking in ORC jit. 10a825fb2cSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to 11a825fb2cSChristian Sigg // run on GPUs. 12a825fb2cSChristian Sigg // 13a825fb2cSChristian Sigg //===----------------------------------------------------------------------===// 14a825fb2cSChristian Sigg 15a825fb2cSChristian Sigg #include <cassert> 16a825fb2cSChristian Sigg #include <numeric> 17a825fb2cSChristian Sigg 18a825fb2cSChristian Sigg #include "mlir/ExecutionEngine/CRunnerUtils.h" 19a825fb2cSChristian Sigg #include "llvm/ADT/ArrayRef.h" 20a825fb2cSChristian Sigg 21a825fb2cSChristian Sigg #include "hip/hip_runtime.h" 22a825fb2cSChristian Sigg 23a825fb2cSChristian Sigg #define HIP_REPORT_IF_ERROR(expr) \ 24a825fb2cSChristian Sigg [](hipError_t result) { \ 25a825fb2cSChristian Sigg if (!result) \ 26a825fb2cSChristian Sigg return; \ 27a825fb2cSChristian Sigg const char *name = hipGetErrorName(result); \ 28a825fb2cSChristian Sigg if (!name) \ 29a825fb2cSChristian Sigg name = "<unknown>"; \ 30a825fb2cSChristian Sigg fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ 31a825fb2cSChristian Sigg }(expr) 32a825fb2cSChristian Sigg 3384718d37SKrzysztof Drewniak thread_local static int32_t defaultDevice = 0; 3484718d37SKrzysztof Drewniak 35ebfea261SNishant Patel extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) { 36a825fb2cSChristian Sigg hipModule_t module = nullptr; 37a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data)); 38a825fb2cSChristian Sigg return module; 39a825fb2cSChristian Sigg } 40a825fb2cSChristian Sigg 415093413aSFabian Mora extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) { 425093413aSFabian Mora assert(false && "This function is not available in HIP."); 435093413aSFabian Mora return nullptr; 445093413aSFabian Mora } 455093413aSFabian Mora 46a825fb2cSChristian Sigg extern "C" void mgpuModuleUnload(hipModule_t module) { 47a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleUnload(module)); 48a825fb2cSChristian Sigg } 49a825fb2cSChristian Sigg 50a825fb2cSChristian Sigg extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module, 51a825fb2cSChristian Sigg const char *name) { 52a825fb2cSChristian Sigg hipFunction_t function = nullptr; 53a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name)); 54a825fb2cSChristian Sigg return function; 55a825fb2cSChristian Sigg } 56a825fb2cSChristian Sigg 57a825fb2cSChristian Sigg // The wrapper uses intptr_t instead of ROCM's unsigned int to match 58a825fb2cSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the 59a825fb2cSChristian Sigg // generated MLIR code. 60a825fb2cSChristian Sigg extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, 61a825fb2cSChristian Sigg intptr_t gridY, intptr_t gridZ, 62a825fb2cSChristian Sigg intptr_t blockX, intptr_t blockY, 63a825fb2cSChristian Sigg intptr_t blockZ, int32_t smem, 64a825fb2cSChristian Sigg hipStream_t stream, void **params, 65ebfea261SNishant Patel void **extra, size_t /*paramsCount*/) { 66a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 67a825fb2cSChristian Sigg blockX, blockY, blockZ, smem, 68a825fb2cSChristian Sigg stream, params, extra)); 69a825fb2cSChristian Sigg } 70a825fb2cSChristian Sigg 71a825fb2cSChristian Sigg extern "C" hipStream_t mgpuStreamCreate() { 72a825fb2cSChristian Sigg hipStream_t stream = nullptr; 73a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); 74a825fb2cSChristian Sigg return stream; 75a825fb2cSChristian Sigg } 76a825fb2cSChristian Sigg 77a825fb2cSChristian Sigg extern "C" void mgpuStreamDestroy(hipStream_t stream) { 78a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamDestroy(stream)); 79a825fb2cSChristian Sigg } 80a825fb2cSChristian Sigg 81a825fb2cSChristian Sigg extern "C" void mgpuStreamSynchronize(hipStream_t stream) { 82a825fb2cSChristian Sigg return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream)); 83a825fb2cSChristian Sigg } 84a825fb2cSChristian Sigg 85a825fb2cSChristian Sigg extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) { 86a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0)); 87a825fb2cSChristian Sigg } 88a825fb2cSChristian Sigg 89a825fb2cSChristian Sigg extern "C" hipEvent_t mgpuEventCreate() { 90a825fb2cSChristian Sigg hipEvent_t event = nullptr; 91a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming)); 92a825fb2cSChristian Sigg return event; 93a825fb2cSChristian Sigg } 94a825fb2cSChristian Sigg 95a825fb2cSChristian Sigg extern "C" void mgpuEventDestroy(hipEvent_t event) { 96a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventDestroy(event)); 97a825fb2cSChristian Sigg } 98a825fb2cSChristian Sigg 99a825fb2cSChristian Sigg extern "C" void mgpuEventSynchronize(hipEvent_t event) { 100a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventSynchronize(event)); 101a825fb2cSChristian Sigg } 102a825fb2cSChristian Sigg 103a825fb2cSChristian Sigg extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) { 104a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventRecord(event, stream)); 105a825fb2cSChristian Sigg } 106a825fb2cSChristian Sigg 1071002a1d0SNishant Patel extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/, 1081002a1d0SNishant Patel bool /*isHostShared*/) { 109a825fb2cSChristian Sigg void *ptr; 110a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes)); 111a825fb2cSChristian Sigg return ptr; 112a825fb2cSChristian Sigg } 113a825fb2cSChristian Sigg 114a825fb2cSChristian Sigg extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) { 115a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipFree(ptr)); 116a825fb2cSChristian Sigg } 117a825fb2cSChristian Sigg 118361458b1SLoren Maggiore extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, 119a825fb2cSChristian Sigg hipStream_t stream) { 120a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR( 121a825fb2cSChristian Sigg hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream)); 122a825fb2cSChristian Sigg } 123a825fb2cSChristian Sigg 124361458b1SLoren Maggiore extern "C" void mgpuMemset32(void *dst, int value, size_t count, 125361458b1SLoren Maggiore hipStream_t stream) { 126361458b1SLoren Maggiore HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst), 127361458b1SLoren Maggiore value, count, stream)); 128361458b1SLoren Maggiore } 129*9f8f1d98SUmang Yadav 130*9f8f1d98SUmang Yadav extern "C" void mgpuMemset16(void *dst, int short value, size_t count, 131*9f8f1d98SUmang Yadav hipStream_t stream) { 132*9f8f1d98SUmang Yadav HIP_REPORT_IF_ERROR(hipMemsetD16Async(reinterpret_cast<hipDeviceptr_t>(dst), 133*9f8f1d98SUmang Yadav value, count, stream)); 134*9f8f1d98SUmang Yadav } 135*9f8f1d98SUmang Yadav 136a825fb2cSChristian Sigg /// Helper functions for writing mlir example code 137a825fb2cSChristian Sigg 138a825fb2cSChristian Sigg // Allows to register byte array with the ROCM runtime. Helpful until we have 139a825fb2cSChristian Sigg // transfer functions implemented. 140a825fb2cSChristian Sigg extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { 141a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0)); 142a825fb2cSChristian Sigg } 143a825fb2cSChristian Sigg 144a825fb2cSChristian Sigg // Allows to register a MemRef with the ROCm runtime. Helpful until we have 145a825fb2cSChristian Sigg // transfer functions implemented. 146a825fb2cSChristian Sigg extern "C" void 147a825fb2cSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, 148a825fb2cSChristian Sigg int64_t elementSizeBytes) { 149a825fb2cSChristian Sigg 150a825fb2cSChristian Sigg llvm::SmallVector<int64_t, 4> denseStrides(rank); 151a825fb2cSChristian Sigg llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank); 152a825fb2cSChristian Sigg llvm::ArrayRef<int64_t> strides(sizes.end(), rank); 153a825fb2cSChristian Sigg 154a825fb2cSChristian Sigg std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), 155a825fb2cSChristian Sigg std::multiplies<int64_t>()); 156a825fb2cSChristian Sigg auto sizeBytes = denseStrides.front() * elementSizeBytes; 157a825fb2cSChristian Sigg 158a825fb2cSChristian Sigg // Only densely packed tensors are currently supported. 159a825fb2cSChristian Sigg std::rotate(denseStrides.begin(), denseStrides.begin() + 1, 160a825fb2cSChristian Sigg denseStrides.end()); 161a825fb2cSChristian Sigg denseStrides.back() = 1; 162984b800aSserge-sans-paille assert(strides == llvm::ArrayRef(denseStrides)); 163a825fb2cSChristian Sigg 164a825fb2cSChristian Sigg auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; 165a825fb2cSChristian Sigg mgpuMemHostRegister(ptr, sizeBytes); 166a825fb2cSChristian Sigg } 167a825fb2cSChristian Sigg 1688f7c8a6eSmax // Allows to unregister byte array with the ROCM runtime. Helpful until we have 1698f7c8a6eSmax // transfer functions implemented. 1708f7c8a6eSmax extern "C" void mgpuMemHostUnregister(void *ptr) { 1718f7c8a6eSmax HIP_REPORT_IF_ERROR(hipHostUnregister(ptr)); 1728f7c8a6eSmax } 1738f7c8a6eSmax 1748f7c8a6eSmax // Allows to unregister a MemRef with the ROCm runtime. Helpful until we have 1758f7c8a6eSmax // transfer functions implemented. 1768f7c8a6eSmax extern "C" void 1778f7c8a6eSmax mgpuMemHostUnregisterMemRef(int64_t rank, 1788f7c8a6eSmax StridedMemRefType<char, 1> *descriptor, 1798f7c8a6eSmax int64_t elementSizeBytes) { 1808f7c8a6eSmax auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; 1818f7c8a6eSmax mgpuMemHostUnregister(ptr); 1828f7c8a6eSmax } 1838f7c8a6eSmax 184a825fb2cSChristian Sigg template <typename T> 185a825fb2cSChristian Sigg void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) { 186a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipSetDevice(0)); 187a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR( 188a825fb2cSChristian Sigg hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0)); 189a825fb2cSChristian Sigg } 190a825fb2cSChristian Sigg 191a825fb2cSChristian Sigg extern "C" StridedMemRefType<float, 1> 192a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset, 193a825fb2cSChristian Sigg int64_t size, int64_t stride) { 194a825fb2cSChristian Sigg float *devicePtr = nullptr; 195a825fb2cSChristian Sigg mgpuMemGetDevicePointer(aligned, &devicePtr); 196a825fb2cSChristian Sigg return {devicePtr, devicePtr, offset, {size}, {stride}}; 197a825fb2cSChristian Sigg } 198a825fb2cSChristian Sigg 199a825fb2cSChristian Sigg extern "C" StridedMemRefType<int32_t, 1> 200a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned, 201a825fb2cSChristian Sigg int64_t offset, int64_t size, int64_t stride) { 202a825fb2cSChristian Sigg int32_t *devicePtr = nullptr; 203a825fb2cSChristian Sigg mgpuMemGetDevicePointer(aligned, &devicePtr); 204a825fb2cSChristian Sigg return {devicePtr, devicePtr, offset, {size}, {stride}}; 205a825fb2cSChristian Sigg } 20684718d37SKrzysztof Drewniak 20784718d37SKrzysztof Drewniak extern "C" void mgpuSetDefaultDevice(int32_t device) { 20884718d37SKrzysztof Drewniak defaultDevice = device; 20984718d37SKrzysztof Drewniak HIP_REPORT_IF_ERROR(hipSetDevice(device)); 21084718d37SKrzysztof Drewniak } 211