xref: /llvm-project/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp (revision 9f8f1d9890fcbae96b92fa0bee5a6f9e1f953ebd)
1a825fb2cSChristian Sigg //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
2a825fb2cSChristian Sigg //
3a825fb2cSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a825fb2cSChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5a825fb2cSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a825fb2cSChristian Sigg //
7a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
8a825fb2cSChristian Sigg //
9a825fb2cSChristian Sigg // Implements C wrappers around the ROCM library for easy linking in ORC jit.
10a825fb2cSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to
11a825fb2cSChristian Sigg // run on GPUs.
12a825fb2cSChristian Sigg //
13a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
14a825fb2cSChristian Sigg 
15a825fb2cSChristian Sigg #include <cassert>
16a825fb2cSChristian Sigg #include <numeric>
17a825fb2cSChristian Sigg 
18a825fb2cSChristian Sigg #include "mlir/ExecutionEngine/CRunnerUtils.h"
19a825fb2cSChristian Sigg #include "llvm/ADT/ArrayRef.h"
20a825fb2cSChristian Sigg 
21a825fb2cSChristian Sigg #include "hip/hip_runtime.h"
22a825fb2cSChristian Sigg 
// Evaluates the HIP call `expr` and, when it does not return hipSuccess,
// prints the stringified expression together with the HIP error name to
// stderr. Execution continues afterwards; the error is reported, not raised.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t status) {                                                      \
    if (status == hipSuccess)                                                  \
      return;                                                                  \
    const char *errorName = hipGetErrorName(status);                           \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr,                          \
            errorName ? errorName : "<unknown>");                              \
  }(expr)
32a825fb2cSChristian Sigg 
// Per-thread device index recorded by mgpuSetDefaultDevice; starts out as 0.
static thread_local int32_t defaultDevice = 0;
3484718d37SKrzysztof Drewniak 
35ebfea261SNishant Patel extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
36a825fb2cSChristian Sigg   hipModule_t module = nullptr;
37a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
38a825fb2cSChristian Sigg   return module;
39a825fb2cSChristian Sigg }
40a825fb2cSChristian Sigg 
415093413aSFabian Mora extern "C" hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) {
425093413aSFabian Mora   assert(false && "This function is not available in HIP.");
435093413aSFabian Mora   return nullptr;
445093413aSFabian Mora }
455093413aSFabian Mora 
46a825fb2cSChristian Sigg extern "C" void mgpuModuleUnload(hipModule_t module) {
47a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleUnload(module));
48a825fb2cSChristian Sigg }
49a825fb2cSChristian Sigg 
50a825fb2cSChristian Sigg extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
51a825fb2cSChristian Sigg                                                const char *name) {
52a825fb2cSChristian Sigg   hipFunction_t function = nullptr;
53a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
54a825fb2cSChristian Sigg   return function;
55a825fb2cSChristian Sigg }
56a825fb2cSChristian Sigg 
57a825fb2cSChristian Sigg // The wrapper uses intptr_t instead of ROCM's unsigned int to match
58a825fb2cSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the
59a825fb2cSChristian Sigg // generated MLIR code.
60a825fb2cSChristian Sigg extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
61a825fb2cSChristian Sigg                                  intptr_t gridY, intptr_t gridZ,
62a825fb2cSChristian Sigg                                  intptr_t blockX, intptr_t blockY,
63a825fb2cSChristian Sigg                                  intptr_t blockZ, int32_t smem,
64a825fb2cSChristian Sigg                                  hipStream_t stream, void **params,
65ebfea261SNishant Patel                                  void **extra, size_t /*paramsCount*/) {
66a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
67a825fb2cSChristian Sigg                                             blockX, blockY, blockZ, smem,
68a825fb2cSChristian Sigg                                             stream, params, extra));
69a825fb2cSChristian Sigg }
70a825fb2cSChristian Sigg 
71a825fb2cSChristian Sigg extern "C" hipStream_t mgpuStreamCreate() {
72a825fb2cSChristian Sigg   hipStream_t stream = nullptr;
73a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
74a825fb2cSChristian Sigg   return stream;
75a825fb2cSChristian Sigg }
76a825fb2cSChristian Sigg 
77a825fb2cSChristian Sigg extern "C" void mgpuStreamDestroy(hipStream_t stream) {
78a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
79a825fb2cSChristian Sigg }
80a825fb2cSChristian Sigg 
81a825fb2cSChristian Sigg extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
82a825fb2cSChristian Sigg   return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
83a825fb2cSChristian Sigg }
84a825fb2cSChristian Sigg 
85a825fb2cSChristian Sigg extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
86a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
87a825fb2cSChristian Sigg }
88a825fb2cSChristian Sigg 
89a825fb2cSChristian Sigg extern "C" hipEvent_t mgpuEventCreate() {
90a825fb2cSChristian Sigg   hipEvent_t event = nullptr;
91a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
92a825fb2cSChristian Sigg   return event;
93a825fb2cSChristian Sigg }
94a825fb2cSChristian Sigg 
95a825fb2cSChristian Sigg extern "C" void mgpuEventDestroy(hipEvent_t event) {
96a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventDestroy(event));
97a825fb2cSChristian Sigg }
98a825fb2cSChristian Sigg 
99a825fb2cSChristian Sigg extern "C" void mgpuEventSynchronize(hipEvent_t event) {
100a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
101a825fb2cSChristian Sigg }
102a825fb2cSChristian Sigg 
103a825fb2cSChristian Sigg extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
104a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
105a825fb2cSChristian Sigg }
106a825fb2cSChristian Sigg 
1071002a1d0SNishant Patel extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/,
1081002a1d0SNishant Patel                               bool /*isHostShared*/) {
109a825fb2cSChristian Sigg   void *ptr;
110a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
111a825fb2cSChristian Sigg   return ptr;
112a825fb2cSChristian Sigg }
113a825fb2cSChristian Sigg 
114a825fb2cSChristian Sigg extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
115a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipFree(ptr));
116a825fb2cSChristian Sigg }
117a825fb2cSChristian Sigg 
118361458b1SLoren Maggiore extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
119a825fb2cSChristian Sigg                            hipStream_t stream) {
120a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(
121a825fb2cSChristian Sigg       hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
122a825fb2cSChristian Sigg }
123a825fb2cSChristian Sigg 
124361458b1SLoren Maggiore extern "C" void mgpuMemset32(void *dst, int value, size_t count,
125361458b1SLoren Maggiore                              hipStream_t stream) {
126361458b1SLoren Maggiore   HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
127361458b1SLoren Maggiore                                         value, count, stream));
128361458b1SLoren Maggiore }
129*9f8f1d98SUmang Yadav 
130*9f8f1d98SUmang Yadav extern "C" void mgpuMemset16(void *dst, int short value, size_t count,
131*9f8f1d98SUmang Yadav                              hipStream_t stream) {
132*9f8f1d98SUmang Yadav   HIP_REPORT_IF_ERROR(hipMemsetD16Async(reinterpret_cast<hipDeviceptr_t>(dst),
133*9f8f1d98SUmang Yadav                                         value, count, stream));
134*9f8f1d98SUmang Yadav }
135*9f8f1d98SUmang Yadav 
136a825fb2cSChristian Sigg /// Helper functions for writing mlir example code
137a825fb2cSChristian Sigg 
138a825fb2cSChristian Sigg // Allows to register byte array with the ROCM runtime. Helpful until we have
139a825fb2cSChristian Sigg // transfer functions implemented.
140a825fb2cSChristian Sigg extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
141a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
142a825fb2cSChristian Sigg }
143a825fb2cSChristian Sigg 
144a825fb2cSChristian Sigg // Allows to register a MemRef with the ROCm runtime. Helpful until we have
145a825fb2cSChristian Sigg // transfer functions implemented.
146a825fb2cSChristian Sigg extern "C" void
147a825fb2cSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
148a825fb2cSChristian Sigg                           int64_t elementSizeBytes) {
149a825fb2cSChristian Sigg 
150a825fb2cSChristian Sigg   llvm::SmallVector<int64_t, 4> denseStrides(rank);
151a825fb2cSChristian Sigg   llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
152a825fb2cSChristian Sigg   llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
153a825fb2cSChristian Sigg 
154a825fb2cSChristian Sigg   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
155a825fb2cSChristian Sigg                    std::multiplies<int64_t>());
156a825fb2cSChristian Sigg   auto sizeBytes = denseStrides.front() * elementSizeBytes;
157a825fb2cSChristian Sigg 
158a825fb2cSChristian Sigg   // Only densely packed tensors are currently supported.
159a825fb2cSChristian Sigg   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
160a825fb2cSChristian Sigg               denseStrides.end());
161a825fb2cSChristian Sigg   denseStrides.back() = 1;
162984b800aSserge-sans-paille   assert(strides == llvm::ArrayRef(denseStrides));
163a825fb2cSChristian Sigg 
164a825fb2cSChristian Sigg   auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
165a825fb2cSChristian Sigg   mgpuMemHostRegister(ptr, sizeBytes);
166a825fb2cSChristian Sigg }
167a825fb2cSChristian Sigg 
1688f7c8a6eSmax // Allows to unregister byte array with the ROCM runtime. Helpful until we have
1698f7c8a6eSmax // transfer functions implemented.
1708f7c8a6eSmax extern "C" void mgpuMemHostUnregister(void *ptr) {
1718f7c8a6eSmax   HIP_REPORT_IF_ERROR(hipHostUnregister(ptr));
1728f7c8a6eSmax }
1738f7c8a6eSmax 
1748f7c8a6eSmax // Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
1758f7c8a6eSmax // transfer functions implemented.
1768f7c8a6eSmax extern "C" void
1778f7c8a6eSmax mgpuMemHostUnregisterMemRef(int64_t rank,
1788f7c8a6eSmax                             StridedMemRefType<char, 1> *descriptor,
1798f7c8a6eSmax                             int64_t elementSizeBytes) {
1808f7c8a6eSmax   auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
1818f7c8a6eSmax   mgpuMemHostUnregister(ptr);
1828f7c8a6eSmax }
1838f7c8a6eSmax 
184a825fb2cSChristian Sigg template <typename T>
185a825fb2cSChristian Sigg void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
186a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipSetDevice(0));
187a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(
188a825fb2cSChristian Sigg       hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
189a825fb2cSChristian Sigg }
190a825fb2cSChristian Sigg 
191a825fb2cSChristian Sigg extern "C" StridedMemRefType<float, 1>
192a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
193a825fb2cSChristian Sigg                               int64_t size, int64_t stride) {
194a825fb2cSChristian Sigg   float *devicePtr = nullptr;
195a825fb2cSChristian Sigg   mgpuMemGetDevicePointer(aligned, &devicePtr);
196a825fb2cSChristian Sigg   return {devicePtr, devicePtr, offset, {size}, {stride}};
197a825fb2cSChristian Sigg }
198a825fb2cSChristian Sigg 
199a825fb2cSChristian Sigg extern "C" StridedMemRefType<int32_t, 1>
200a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
201a825fb2cSChristian Sigg                               int64_t offset, int64_t size, int64_t stride) {
202a825fb2cSChristian Sigg   int32_t *devicePtr = nullptr;
203a825fb2cSChristian Sigg   mgpuMemGetDevicePointer(aligned, &devicePtr);
204a825fb2cSChristian Sigg   return {devicePtr, devicePtr, offset, {size}, {stride}};
205a825fb2cSChristian Sigg }
20684718d37SKrzysztof Drewniak 
20784718d37SKrzysztof Drewniak extern "C" void mgpuSetDefaultDevice(int32_t device) {
20884718d37SKrzysztof Drewniak   defaultDevice = device;
20984718d37SKrzysztof Drewniak   HIP_REPORT_IF_ERROR(hipSetDevice(device));
21084718d37SKrzysztof Drewniak }
211