xref: /llvm-project/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp (revision 7fa19e6f4b87623b0ca1a23bf6b6293c1b5e5799)
1*7fa19e6fSNishant Patel //===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
2*7fa19e6fSNishant Patel //
3*7fa19e6fSNishant Patel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*7fa19e6fSNishant Patel // See https://llvm.org/LICENSE.txt for license information.
5*7fa19e6fSNishant Patel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*7fa19e6fSNishant Patel //
7*7fa19e6fSNishant Patel //===----------------------------------------------------------------------===//
8*7fa19e6fSNishant Patel //
9*7fa19e6fSNishant Patel // Implements wrappers around the sycl runtime library with C linkage
10*7fa19e6fSNishant Patel //
11*7fa19e6fSNishant Patel //===----------------------------------------------------------------------===//
12*7fa19e6fSNishant Patel 
13*7fa19e6fSNishant Patel #include <CL/sycl.hpp>
14*7fa19e6fSNishant Patel #include <level_zero/ze_api.h>
15*7fa19e6fSNishant Patel #include <sycl/ext/oneapi/backend/level_zero.hpp>
16*7fa19e6fSNishant Patel 
17*7fa19e6fSNishant Patel #ifdef _WIN32
18*7fa19e6fSNishant Patel #define SYCL_RUNTIME_EXPORT __declspec(dllexport)
19*7fa19e6fSNishant Patel #else
20*7fa19e6fSNishant Patel #define SYCL_RUNTIME_EXPORT
21*7fa19e6fSNishant Patel #endif // _WIN32
22*7fa19e6fSNishant Patel 
23*7fa19e6fSNishant Patel namespace {
24*7fa19e6fSNishant Patel 
/// Invokes `func` and forwards its result. Any exception escaping `func` is
/// reported on stdout and the process is terminated with abort(), so that no
/// C++ exception ever propagates through the C-linkage entry points below.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &error) {
    fprintf(stdout, "An exception was thrown: %s\n", error.what());
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
  }
  // Only reached from one of the handlers above: flush the diagnostic before
  // terminating so it is not lost in a buffered stdout.
  fflush(stdout);
  abort();
}
39*7fa19e6fSNishant Patel 
// Evaluates the Level Zero call `call` exactly once; if it does not return
// ZE_RESULT_SUCCESS, prints the numeric error code on stdout and aborts.
// Wrapped in do { } while (0) so the macro behaves as a single statement and
// must be followed by a semicolon (safe inside unbraced if/else).
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    const ze_result_t l0SafeCallStatus = (call);                               \
    if (l0SafeCallStatus != ZE_RESULT_SUCCESS) {                               \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0SafeCallStatus));    \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)
49*7fa19e6fSNishant Patel 
50*7fa19e6fSNishant Patel } // namespace
51*7fa19e6fSNishant Patel 
getDefaultDevice()52*7fa19e6fSNishant Patel static sycl::device getDefaultDevice() {
53*7fa19e6fSNishant Patel   static sycl::device syclDevice;
54*7fa19e6fSNishant Patel   static bool isDeviceInitialised = false;
55*7fa19e6fSNishant Patel   if (!isDeviceInitialised) {
56*7fa19e6fSNishant Patel     auto platformList = sycl::platform::get_platforms();
57*7fa19e6fSNishant Patel     for (const auto &platform : platformList) {
58*7fa19e6fSNishant Patel       auto platformName = platform.get_info<sycl::info::platform::name>();
59*7fa19e6fSNishant Patel       bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
60*7fa19e6fSNishant Patel       if (!isLevelZero)
61*7fa19e6fSNishant Patel         continue;
62*7fa19e6fSNishant Patel 
63*7fa19e6fSNishant Patel       syclDevice = platform.get_devices()[0];
64*7fa19e6fSNishant Patel       isDeviceInitialised = true;
65*7fa19e6fSNishant Patel       return syclDevice;
66*7fa19e6fSNishant Patel     }
67*7fa19e6fSNishant Patel     throw std::runtime_error("getDefaultDevice failed");
68*7fa19e6fSNishant Patel   } else
69*7fa19e6fSNishant Patel     return syclDevice;
70*7fa19e6fSNishant Patel }
71*7fa19e6fSNishant Patel 
getDefaultContext()72*7fa19e6fSNishant Patel static sycl::context getDefaultContext() {
73*7fa19e6fSNishant Patel   static sycl::context syclContext{getDefaultDevice()};
74*7fa19e6fSNishant Patel   return syclContext;
75*7fa19e6fSNishant Patel }
76*7fa19e6fSNishant Patel 
allocDeviceMemory(sycl::queue * queue,size_t size,bool isShared)77*7fa19e6fSNishant Patel static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
78*7fa19e6fSNishant Patel   void *memPtr = nullptr;
79*7fa19e6fSNishant Patel   if (isShared) {
80*7fa19e6fSNishant Patel     memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
81*7fa19e6fSNishant Patel                                         getDefaultContext());
82*7fa19e6fSNishant Patel   } else {
83*7fa19e6fSNishant Patel     memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
84*7fa19e6fSNishant Patel                                         getDefaultContext());
85*7fa19e6fSNishant Patel   }
86*7fa19e6fSNishant Patel   if (memPtr == nullptr) {
87*7fa19e6fSNishant Patel     throw std::runtime_error("mem allocation failed!");
88*7fa19e6fSNishant Patel   }
89*7fa19e6fSNishant Patel   return memPtr;
90*7fa19e6fSNishant Patel }
91*7fa19e6fSNishant Patel 
deallocDeviceMemory(sycl::queue * queue,void * ptr)92*7fa19e6fSNishant Patel static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
93*7fa19e6fSNishant Patel   sycl::free(ptr, *queue);
94*7fa19e6fSNishant Patel }
95*7fa19e6fSNishant Patel 
loadModule(const void * data,size_t dataSize)96*7fa19e6fSNishant Patel static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
97*7fa19e6fSNishant Patel   assert(data);
98*7fa19e6fSNishant Patel   ze_module_handle_t zeModule;
99*7fa19e6fSNishant Patel   ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
100*7fa19e6fSNishant Patel                            nullptr,
101*7fa19e6fSNishant Patel                            ZE_MODULE_FORMAT_IL_SPIRV,
102*7fa19e6fSNishant Patel                            dataSize,
103*7fa19e6fSNishant Patel                            (const uint8_t *)data,
104*7fa19e6fSNishant Patel                            nullptr,
105*7fa19e6fSNishant Patel                            nullptr};
106*7fa19e6fSNishant Patel   auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
107*7fa19e6fSNishant Patel       getDefaultDevice());
108*7fa19e6fSNishant Patel   auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
109*7fa19e6fSNishant Patel       getDefaultContext());
110*7fa19e6fSNishant Patel   L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
111*7fa19e6fSNishant Patel   return zeModule;
112*7fa19e6fSNishant Patel }
113*7fa19e6fSNishant Patel 
getKernel(ze_module_handle_t zeModule,const char * name)114*7fa19e6fSNishant Patel static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
115*7fa19e6fSNishant Patel   assert(zeModule);
116*7fa19e6fSNishant Patel   assert(name);
117*7fa19e6fSNishant Patel   ze_kernel_handle_t zeKernel;
118*7fa19e6fSNishant Patel   ze_kernel_desc_t desc = {};
119*7fa19e6fSNishant Patel   desc.pKernelName = name;
120*7fa19e6fSNishant Patel 
121*7fa19e6fSNishant Patel   L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
122*7fa19e6fSNishant Patel   sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
123*7fa19e6fSNishant Patel       sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
124*7fa19e6fSNishant Patel                                sycl::bundle_state::executable>(
125*7fa19e6fSNishant Patel           {zeModule}, getDefaultContext());
126*7fa19e6fSNishant Patel 
127*7fa19e6fSNishant Patel   auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
128*7fa19e6fSNishant Patel       {kernelBundle, zeKernel}, getDefaultContext());
129*7fa19e6fSNishant Patel   return new sycl::kernel(kernel);
130*7fa19e6fSNishant Patel }
131*7fa19e6fSNishant Patel 
launchKernel(sycl::queue * queue,sycl::kernel * kernel,size_t gridX,size_t gridY,size_t gridZ,size_t blockX,size_t blockY,size_t blockZ,size_t sharedMemBytes,void ** params,size_t paramsCount)132*7fa19e6fSNishant Patel static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
133*7fa19e6fSNishant Patel                          size_t gridY, size_t gridZ, size_t blockX,
134*7fa19e6fSNishant Patel                          size_t blockY, size_t blockZ, size_t sharedMemBytes,
135*7fa19e6fSNishant Patel                          void **params, size_t paramsCount) {
136*7fa19e6fSNishant Patel   auto syclGlobalRange =
137*7fa19e6fSNishant Patel       sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
138*7fa19e6fSNishant Patel   auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
139*7fa19e6fSNishant Patel   sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
140*7fa19e6fSNishant Patel 
141*7fa19e6fSNishant Patel   queue->submit([&](sycl::handler &cgh) {
142*7fa19e6fSNishant Patel     for (size_t i = 0; i < paramsCount; i++) {
143*7fa19e6fSNishant Patel       cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
144*7fa19e6fSNishant Patel     }
145*7fa19e6fSNishant Patel     cgh.parallel_for(syclNdRange, *kernel);
146*7fa19e6fSNishant Patel   });
147*7fa19e6fSNishant Patel }
148*7fa19e6fSNishant Patel 
149*7fa19e6fSNishant Patel // Wrappers
150*7fa19e6fSNishant Patel 
mgpuStreamCreate()151*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
152*7fa19e6fSNishant Patel 
153*7fa19e6fSNishant Patel   return catchAll([&]() {
154*7fa19e6fSNishant Patel     sycl::queue *queue =
155*7fa19e6fSNishant Patel         new sycl::queue(getDefaultContext(), getDefaultDevice());
156*7fa19e6fSNishant Patel     return queue;
157*7fa19e6fSNishant Patel   });
158*7fa19e6fSNishant Patel }
159*7fa19e6fSNishant Patel 
mgpuStreamDestroy(sycl::queue * queue)160*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
161*7fa19e6fSNishant Patel   catchAll([&]() { delete queue; });
162*7fa19e6fSNishant Patel }
163*7fa19e6fSNishant Patel 
164*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void *
mgpuMemAlloc(uint64_t size,sycl::queue * queue,bool isShared)165*7fa19e6fSNishant Patel mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
166*7fa19e6fSNishant Patel   return catchAll([&]() {
167*7fa19e6fSNishant Patel     return allocDeviceMemory(queue, static_cast<size_t>(size), true);
168*7fa19e6fSNishant Patel   });
169*7fa19e6fSNishant Patel }
170*7fa19e6fSNishant Patel 
mgpuMemFree(void * ptr,sycl::queue * queue)171*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
172*7fa19e6fSNishant Patel   catchAll([&]() {
173*7fa19e6fSNishant Patel     if (ptr) {
174*7fa19e6fSNishant Patel       deallocDeviceMemory(queue, ptr);
175*7fa19e6fSNishant Patel     }
176*7fa19e6fSNishant Patel   });
177*7fa19e6fSNishant Patel }
178*7fa19e6fSNishant Patel 
179*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
mgpuModuleLoad(const void * data,size_t gpuBlobSize)180*7fa19e6fSNishant Patel mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
181*7fa19e6fSNishant Patel   return catchAll([&]() { return loadModule(data, gpuBlobSize); });
182*7fa19e6fSNishant Patel }
183*7fa19e6fSNishant Patel 
184*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
mgpuModuleGetFunction(ze_module_handle_t module,const char * name)185*7fa19e6fSNishant Patel mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
186*7fa19e6fSNishant Patel   return catchAll([&]() { return getKernel(module, name); });
187*7fa19e6fSNishant Patel }
188*7fa19e6fSNishant Patel 
189*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void
mgpuLaunchKernel(sycl::kernel * kernel,size_t gridX,size_t gridY,size_t gridZ,size_t blockX,size_t blockY,size_t blockZ,size_t sharedMemBytes,sycl::queue * queue,void ** params,void **,size_t paramsCount)190*7fa19e6fSNishant Patel mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
191*7fa19e6fSNishant Patel                  size_t blockX, size_t blockY, size_t blockZ,
192*7fa19e6fSNishant Patel                  size_t sharedMemBytes, sycl::queue *queue, void **params,
193*7fa19e6fSNishant Patel                  void ** /*extra*/, size_t paramsCount) {
194*7fa19e6fSNishant Patel   return catchAll([&]() {
195*7fa19e6fSNishant Patel     launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
196*7fa19e6fSNishant Patel                  sharedMemBytes, params, paramsCount);
197*7fa19e6fSNishant Patel   });
198*7fa19e6fSNishant Patel }
199*7fa19e6fSNishant Patel 
mgpuStreamSynchronize(sycl::queue * queue)200*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
201*7fa19e6fSNishant Patel 
202*7fa19e6fSNishant Patel   catchAll([&]() { queue->wait(); });
203*7fa19e6fSNishant Patel }
204*7fa19e6fSNishant Patel 
205*7fa19e6fSNishant Patel extern "C" SYCL_RUNTIME_EXPORT void
mgpuModuleUnload(ze_module_handle_t module)206*7fa19e6fSNishant Patel mgpuModuleUnload(ze_module_handle_t module) {
207*7fa19e6fSNishant Patel 
208*7fa19e6fSNishant Patel   catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
209*7fa19e6fSNishant Patel }
210