xref: /llvm-project/offload/liboffload/src/OffloadImpl.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
//===- OffloadImpl.cpp - Implementation of the new LLVM/Offload API -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This contains the definitions of the new LLVM/Offload API entry points. See
// new-api/API/README.md for more information.
//
//===----------------------------------------------------------------------===//
13*fd3907ccSCallum Fare 
14*fd3907ccSCallum Fare #include "OffloadImpl.hpp"
15*fd3907ccSCallum Fare #include "Helpers.hpp"
16*fd3907ccSCallum Fare #include "PluginManager.h"
17*fd3907ccSCallum Fare #include "llvm/Support/FormatVariadic.h"
18*fd3907ccSCallum Fare #include <OffloadAPI.h>
19*fd3907ccSCallum Fare 
20*fd3907ccSCallum Fare #include <mutex>
21*fd3907ccSCallum Fare 
22*fd3907ccSCallum Fare using namespace llvm;
23*fd3907ccSCallum Fare using namespace llvm::omp::target::plugin;
24*fd3907ccSCallum Fare 
// Handle type definitions. Ideally these would be 1:1 with the plugins
//
// Device handle: one entry per plugin device that initPlugins() managed to
// initialize. Entries are stored in ol_platform_handle_t_::Devices and their
// addresses are handed out to API users by olGetDevice_impl.
struct ol_device_handle_t_ {
  // Index of the device within its plugin, as passed to init_device() and
  // getDevice().
  int DeviceNum;
  // The plugin's device object that getDevice(DeviceNum) returned.
  GenericDeviceTy &Device;
  // Back-pointer to the platform whose Devices list holds this entry.
  ol_platform_handle_t Platform;
};
31*fd3907ccSCallum Fare 
// Platform handle: wraps one plugin instance together with the devices that
// were successfully initialized on it. Instances are stored in the vector
// returned by Platforms(); olGetPlatform_impl and the device handles keep raw
// pointers to them.
struct ol_platform_handle_t_ {
  // Plugin instance created by the plugin's createPlugin_<Name>() factory.
  std::unique_ptr<GenericPluginTy> Plugin;
  // Handles for the devices of this plugin for which init_device() succeeded.
  std::vector<ol_device_handle_t_> Devices;
};
36*fd3907ccSCallum Fare 
37*fd3907ccSCallum Fare using PlatformVecT = SmallVector<ol_platform_handle_t_, 4>;
38*fd3907ccSCallum Fare PlatformVecT &Platforms() {
39*fd3907ccSCallum Fare   static PlatformVecT Platforms;
40*fd3907ccSCallum Fare   return Platforms;
41*fd3907ccSCallum Fare }
42*fd3907ccSCallum Fare 
// TODO: Some plugins expect to be linked into libomptarget, which defines
// these symbols to implement OMPT callbacks. The least invasive workaround is
// to define them here in libLLVMOffload as false/null so they are never used.
// In future it would be better to let the plugins implement callbacks without
// pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target::ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
57*fd3907ccSCallum Fare 
// Every plugin exports this method to create an instance of the plugin type.
// Including Targets.def with PLUGIN_TARGET defined this way declares the
// extern "C" factory createPlugin_<Name>() for each target listed there.
// NOTE(review): Targets.def presumably #undefs PLUGIN_TARGET at its end, since
// the macro is defined again inside initPlugins() without an #undef - confirm.
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
61*fd3907ccSCallum Fare 
62*fd3907ccSCallum Fare void initPlugins() {
63*fd3907ccSCallum Fare   // Attempt to create an instance of each supported plugin.
64*fd3907ccSCallum Fare #define PLUGIN_TARGET(Name)                                                    \
65*fd3907ccSCallum Fare   do {                                                                         \
66*fd3907ccSCallum Fare     Platforms().emplace_back(ol_platform_handle_t_{                            \
67*fd3907ccSCallum Fare         std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), {}});         \
68*fd3907ccSCallum Fare   } while (false);
69*fd3907ccSCallum Fare #include "Shared/Targets.def"
70*fd3907ccSCallum Fare 
71*fd3907ccSCallum Fare   // Preemptively initialize all devices in the plugin so we can just return
72*fd3907ccSCallum Fare   // them from deviceGet
73*fd3907ccSCallum Fare   for (auto &Platform : Platforms()) {
74*fd3907ccSCallum Fare     auto Err = Platform.Plugin->init();
75*fd3907ccSCallum Fare     [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
76*fd3907ccSCallum Fare     for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
77*fd3907ccSCallum Fare          DevNum++) {
78*fd3907ccSCallum Fare       if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
79*fd3907ccSCallum Fare         Platform.Devices.emplace_back(ol_device_handle_t_{
80*fd3907ccSCallum Fare             DevNum, Platform.Plugin->getDevice(DevNum), &Platform});
81*fd3907ccSCallum Fare       }
82*fd3907ccSCallum Fare     }
83*fd3907ccSCallum Fare   }
84*fd3907ccSCallum Fare 
85*fd3907ccSCallum Fare   offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
86*fd3907ccSCallum Fare }
87*fd3907ccSCallum Fare 
88*fd3907ccSCallum Fare // TODO: We can properly reference count here and manage the resources in a more
89*fd3907ccSCallum Fare // clever way
90*fd3907ccSCallum Fare ol_impl_result_t olInit_impl() {
91*fd3907ccSCallum Fare   static std::once_flag InitFlag;
92*fd3907ccSCallum Fare   std::call_once(InitFlag, initPlugins);
93*fd3907ccSCallum Fare 
94*fd3907ccSCallum Fare   return OL_SUCCESS;
95*fd3907ccSCallum Fare }
96*fd3907ccSCallum Fare ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; }
97*fd3907ccSCallum Fare 
98*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) {
99*fd3907ccSCallum Fare   *NumPlatforms = Platforms().size();
100*fd3907ccSCallum Fare   return OL_SUCCESS;
101*fd3907ccSCallum Fare }
102*fd3907ccSCallum Fare 
103*fd3907ccSCallum Fare ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries,
104*fd3907ccSCallum Fare                                     ol_platform_handle_t *PlatformsOut) {
105*fd3907ccSCallum Fare   if (NumEntries > Platforms().size()) {
106*fd3907ccSCallum Fare     return {OL_ERRC_INVALID_SIZE,
107*fd3907ccSCallum Fare             std::string{formatv("{0} platform(s) available but {1} requested.",
108*fd3907ccSCallum Fare                                 Platforms().size(), NumEntries)}};
109*fd3907ccSCallum Fare   }
110*fd3907ccSCallum Fare 
111*fd3907ccSCallum Fare   for (uint32_t PlatformIndex = 0; PlatformIndex < NumEntries;
112*fd3907ccSCallum Fare        PlatformIndex++) {
113*fd3907ccSCallum Fare     PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex];
114*fd3907ccSCallum Fare   }
115*fd3907ccSCallum Fare 
116*fd3907ccSCallum Fare   return OL_SUCCESS;
117*fd3907ccSCallum Fare }
118*fd3907ccSCallum Fare 
119*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
120*fd3907ccSCallum Fare                                              ol_platform_info_t PropName,
121*fd3907ccSCallum Fare                                              size_t PropSize, void *PropValue,
122*fd3907ccSCallum Fare                                              size_t *PropSizeRet) {
123*fd3907ccSCallum Fare   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
124*fd3907ccSCallum Fare 
125*fd3907ccSCallum Fare   switch (PropName) {
126*fd3907ccSCallum Fare   case OL_PLATFORM_INFO_NAME:
127*fd3907ccSCallum Fare     return ReturnValue(Platform->Plugin->getName());
128*fd3907ccSCallum Fare   case OL_PLATFORM_INFO_VENDOR_NAME:
129*fd3907ccSCallum Fare     // TODO: Implement this
130*fd3907ccSCallum Fare     return ReturnValue("Unknown platform vendor");
131*fd3907ccSCallum Fare   case OL_PLATFORM_INFO_VERSION: {
132*fd3907ccSCallum Fare     return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
133*fd3907ccSCallum Fare                                OL_VERSION_MINOR, OL_VERSION_PATCH)
134*fd3907ccSCallum Fare                            .str()
135*fd3907ccSCallum Fare                            .c_str());
136*fd3907ccSCallum Fare   }
137*fd3907ccSCallum Fare   case OL_PLATFORM_INFO_BACKEND: {
138*fd3907ccSCallum Fare     auto PluginName = Platform->Plugin->getName();
139*fd3907ccSCallum Fare     if (PluginName == StringRef("CUDA")) {
140*fd3907ccSCallum Fare       return ReturnValue(OL_PLATFORM_BACKEND_CUDA);
141*fd3907ccSCallum Fare     } else if (PluginName == StringRef("AMDGPU")) {
142*fd3907ccSCallum Fare       return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU);
143*fd3907ccSCallum Fare     } else {
144*fd3907ccSCallum Fare       return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN);
145*fd3907ccSCallum Fare     }
146*fd3907ccSCallum Fare   }
147*fd3907ccSCallum Fare   default:
148*fd3907ccSCallum Fare     return OL_ERRC_INVALID_ENUMERATION;
149*fd3907ccSCallum Fare   }
150*fd3907ccSCallum Fare 
151*fd3907ccSCallum Fare   return OL_SUCCESS;
152*fd3907ccSCallum Fare }
153*fd3907ccSCallum Fare 
154*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform,
155*fd3907ccSCallum Fare                                         ol_platform_info_t PropName,
156*fd3907ccSCallum Fare                                         size_t PropSize, void *PropValue) {
157*fd3907ccSCallum Fare   return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
158*fd3907ccSCallum Fare                                      nullptr);
159*fd3907ccSCallum Fare }
160*fd3907ccSCallum Fare 
161*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
162*fd3907ccSCallum Fare                                             ol_platform_info_t PropName,
163*fd3907ccSCallum Fare                                             size_t *PropSizeRet) {
164*fd3907ccSCallum Fare   return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
165*fd3907ccSCallum Fare                                      PropSizeRet);
166*fd3907ccSCallum Fare }
167*fd3907ccSCallum Fare 
168*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform,
169*fd3907ccSCallum Fare                                        uint32_t *pNumDevices) {
170*fd3907ccSCallum Fare   *pNumDevices = static_cast<uint32_t>(Platform->Devices.size());
171*fd3907ccSCallum Fare 
172*fd3907ccSCallum Fare   return OL_SUCCESS;
173*fd3907ccSCallum Fare }
174*fd3907ccSCallum Fare 
175*fd3907ccSCallum Fare ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform,
176*fd3907ccSCallum Fare                                   uint32_t NumEntries,
177*fd3907ccSCallum Fare                                   ol_device_handle_t *Devices) {
178*fd3907ccSCallum Fare   if (NumEntries > Platform->Devices.size()) {
179*fd3907ccSCallum Fare     return OL_ERRC_INVALID_SIZE;
180*fd3907ccSCallum Fare   }
181*fd3907ccSCallum Fare 
182*fd3907ccSCallum Fare   for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) {
183*fd3907ccSCallum Fare     Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]);
184*fd3907ccSCallum Fare   }
185*fd3907ccSCallum Fare 
186*fd3907ccSCallum Fare   return OL_SUCCESS;
187*fd3907ccSCallum Fare }
188*fd3907ccSCallum Fare 
189*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
190*fd3907ccSCallum Fare                                            ol_device_info_t PropName,
191*fd3907ccSCallum Fare                                            size_t PropSize, void *PropValue,
192*fd3907ccSCallum Fare                                            size_t *PropSizeRet) {
193*fd3907ccSCallum Fare 
194*fd3907ccSCallum Fare   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
195*fd3907ccSCallum Fare 
196*fd3907ccSCallum Fare   InfoQueueTy DevInfo;
197*fd3907ccSCallum Fare   if (auto Err = Device->Device.obtainInfoImpl(DevInfo))
198*fd3907ccSCallum Fare     return OL_ERRC_OUT_OF_RESOURCES;
199*fd3907ccSCallum Fare 
200*fd3907ccSCallum Fare   // Find the info if it exists under any of the given names
201*fd3907ccSCallum Fare   auto GetInfo = [&DevInfo](std::vector<std::string> Names) {
202*fd3907ccSCallum Fare     for (auto Name : Names) {
203*fd3907ccSCallum Fare       auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
204*fd3907ccSCallum Fare         return Info.Key == Name;
205*fd3907ccSCallum Fare       };
206*fd3907ccSCallum Fare       auto Item = std::find_if(DevInfo.getQueue().begin(),
207*fd3907ccSCallum Fare                                DevInfo.getQueue().end(), InfoKeyMatches);
208*fd3907ccSCallum Fare 
209*fd3907ccSCallum Fare       if (Item != std::end(DevInfo.getQueue())) {
210*fd3907ccSCallum Fare         return Item->Value;
211*fd3907ccSCallum Fare       }
212*fd3907ccSCallum Fare     }
213*fd3907ccSCallum Fare 
214*fd3907ccSCallum Fare     return std::string("");
215*fd3907ccSCallum Fare   };
216*fd3907ccSCallum Fare 
217*fd3907ccSCallum Fare   switch (PropName) {
218*fd3907ccSCallum Fare   case OL_DEVICE_INFO_PLATFORM:
219*fd3907ccSCallum Fare     return ReturnValue(Device->Platform);
220*fd3907ccSCallum Fare   case OL_DEVICE_INFO_TYPE:
221*fd3907ccSCallum Fare     return ReturnValue(OL_DEVICE_TYPE_GPU);
222*fd3907ccSCallum Fare   case OL_DEVICE_INFO_NAME:
223*fd3907ccSCallum Fare     return ReturnValue(GetInfo({"Device Name"}).c_str());
224*fd3907ccSCallum Fare   case OL_DEVICE_INFO_VENDOR:
225*fd3907ccSCallum Fare     return ReturnValue(GetInfo({"Vendor Name"}).c_str());
226*fd3907ccSCallum Fare   case OL_DEVICE_INFO_DRIVER_VERSION:
227*fd3907ccSCallum Fare     return ReturnValue(
228*fd3907ccSCallum Fare         GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
229*fd3907ccSCallum Fare   default:
230*fd3907ccSCallum Fare     return OL_ERRC_INVALID_ENUMERATION;
231*fd3907ccSCallum Fare   }
232*fd3907ccSCallum Fare 
233*fd3907ccSCallum Fare   return OL_SUCCESS;
234*fd3907ccSCallum Fare }
235*fd3907ccSCallum Fare 
236*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
237*fd3907ccSCallum Fare                                       ol_device_info_t PropName,
238*fd3907ccSCallum Fare                                       size_t PropSize, void *PropValue) {
239*fd3907ccSCallum Fare   return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
240*fd3907ccSCallum Fare                                    nullptr);
241*fd3907ccSCallum Fare }
242*fd3907ccSCallum Fare 
243*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
244*fd3907ccSCallum Fare                                           ol_device_info_t PropName,
245*fd3907ccSCallum Fare                                           size_t *PropSizeRet) {
246*fd3907ccSCallum Fare   return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
247*fd3907ccSCallum Fare }
248