xref: /llvm-project/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp (revision 7fa19e6f4b87623b0ca1a23bf6b6293c1b5e5799)
1 //===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements wrappers around the sycl runtime library with C linkage
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <CL/sycl.hpp>
14 #include <level_zero/ze_api.h>
15 #include <sycl/ext/oneapi/backend/level_zero.hpp>
16 
17 #ifdef _WIN32
18 #define SYCL_RUNTIME_EXPORT __declspec(dllexport)
19 #else
20 #define SYCL_RUNTIME_EXPORT
21 #endif // _WIN32
22 
23 namespace {
24 
/// Flushes stdout so the diagnostic printed just before is not lost, then
/// terminates the process.
[[noreturn]] inline void flushAndAbort() {
  fflush(stdout);
  abort();
}

/// Invokes `func` and returns whatever it returns. Any exception escaping
/// `func` is reported on stdout and the process is aborted, so no exception
/// propagates to the caller.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &error) {
    fprintf(stdout, "An exception was thrown: %s\n", error.what());
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
  }
  // Only reachable from one of the handlers above.
  flushAndAbort();
}
39 
// Evaluates a Level Zero API call exactly once; if the result is not
// ZE_RESULT_SUCCESS, prints the numeric error code to stdout and aborts.
// Wrapped in do { } while (0) so the macro behaves as a single statement
// (safe inside an unbraced if/else); the caller supplies the trailing ';'.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    const ze_result_t l0SafeCallStatus = (call);                               \
    if (l0SafeCallStatus != ZE_RESULT_SUCCESS) {                               \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0SafeCallStatus));    \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)
49 
50 } // namespace
51 
getDefaultDevice()52 static sycl::device getDefaultDevice() {
53   static sycl::device syclDevice;
54   static bool isDeviceInitialised = false;
55   if (!isDeviceInitialised) {
56     auto platformList = sycl::platform::get_platforms();
57     for (const auto &platform : platformList) {
58       auto platformName = platform.get_info<sycl::info::platform::name>();
59       bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
60       if (!isLevelZero)
61         continue;
62 
63       syclDevice = platform.get_devices()[0];
64       isDeviceInitialised = true;
65       return syclDevice;
66     }
67     throw std::runtime_error("getDefaultDevice failed");
68   } else
69     return syclDevice;
70 }
71 
getDefaultContext()72 static sycl::context getDefaultContext() {
73   static sycl::context syclContext{getDefaultDevice()};
74   return syclContext;
75 }
76 
allocDeviceMemory(sycl::queue * queue,size_t size,bool isShared)77 static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
78   void *memPtr = nullptr;
79   if (isShared) {
80     memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
81                                         getDefaultContext());
82   } else {
83     memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
84                                         getDefaultContext());
85   }
86   if (memPtr == nullptr) {
87     throw std::runtime_error("mem allocation failed!");
88   }
89   return memPtr;
90 }
91 
deallocDeviceMemory(sycl::queue * queue,void * ptr)92 static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
93   sycl::free(ptr, *queue);
94 }
95 
loadModule(const void * data,size_t dataSize)96 static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
97   assert(data);
98   ze_module_handle_t zeModule;
99   ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
100                            nullptr,
101                            ZE_MODULE_FORMAT_IL_SPIRV,
102                            dataSize,
103                            (const uint8_t *)data,
104                            nullptr,
105                            nullptr};
106   auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
107       getDefaultDevice());
108   auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
109       getDefaultContext());
110   L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
111   return zeModule;
112 }
113 
getKernel(ze_module_handle_t zeModule,const char * name)114 static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
115   assert(zeModule);
116   assert(name);
117   ze_kernel_handle_t zeKernel;
118   ze_kernel_desc_t desc = {};
119   desc.pKernelName = name;
120 
121   L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
122   sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
123       sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
124                                sycl::bundle_state::executable>(
125           {zeModule}, getDefaultContext());
126 
127   auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
128       {kernelBundle, zeKernel}, getDefaultContext());
129   return new sycl::kernel(kernel);
130 }
131 
launchKernel(sycl::queue * queue,sycl::kernel * kernel,size_t gridX,size_t gridY,size_t gridZ,size_t blockX,size_t blockY,size_t blockZ,size_t sharedMemBytes,void ** params,size_t paramsCount)132 static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
133                          size_t gridY, size_t gridZ, size_t blockX,
134                          size_t blockY, size_t blockZ, size_t sharedMemBytes,
135                          void **params, size_t paramsCount) {
136   auto syclGlobalRange =
137       sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
138   auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
139   sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
140 
141   queue->submit([&](sycl::handler &cgh) {
142     for (size_t i = 0; i < paramsCount; i++) {
143       cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
144     }
145     cgh.parallel_for(syclNdRange, *kernel);
146   });
147 }
148 
149 // Wrappers
150 
mgpuStreamCreate()151 extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
152 
153   return catchAll([&]() {
154     sycl::queue *queue =
155         new sycl::queue(getDefaultContext(), getDefaultDevice());
156     return queue;
157   });
158 }
159 
mgpuStreamDestroy(sycl::queue * queue)160 extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
161   catchAll([&]() { delete queue; });
162 }
163 
164 extern "C" SYCL_RUNTIME_EXPORT void *
mgpuMemAlloc(uint64_t size,sycl::queue * queue,bool isShared)165 mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
166   return catchAll([&]() {
167     return allocDeviceMemory(queue, static_cast<size_t>(size), true);
168   });
169 }
170 
mgpuMemFree(void * ptr,sycl::queue * queue)171 extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
172   catchAll([&]() {
173     if (ptr) {
174       deallocDeviceMemory(queue, ptr);
175     }
176   });
177 }
178 
179 extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
mgpuModuleLoad(const void * data,size_t gpuBlobSize)180 mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
181   return catchAll([&]() { return loadModule(data, gpuBlobSize); });
182 }
183 
184 extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
mgpuModuleGetFunction(ze_module_handle_t module,const char * name)185 mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
186   return catchAll([&]() { return getKernel(module, name); });
187 }
188 
189 extern "C" SYCL_RUNTIME_EXPORT void
mgpuLaunchKernel(sycl::kernel * kernel,size_t gridX,size_t gridY,size_t gridZ,size_t blockX,size_t blockY,size_t blockZ,size_t sharedMemBytes,sycl::queue * queue,void ** params,void **,size_t paramsCount)190 mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
191                  size_t blockX, size_t blockY, size_t blockZ,
192                  size_t sharedMemBytes, sycl::queue *queue, void **params,
193                  void ** /*extra*/, size_t paramsCount) {
194   return catchAll([&]() {
195     launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
196                  sharedMemBytes, params, paramsCount);
197   });
198 }
199 
mgpuStreamSynchronize(sycl::queue * queue)200 extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
201 
202   catchAll([&]() { queue->wait(); });
203 }
204 
205 extern "C" SYCL_RUNTIME_EXPORT void
mgpuModuleUnload(ze_module_handle_t module)206 mgpuModuleUnload(ze_module_handle_t module) {
207 
208   catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
209 }
210