xref: /llvm-project/offload/liboffload/src/OffloadImpl.cpp (revision fd3907ccb583df99e9c19d2fe84e4e7c52d75de9)
1 //===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This contains the definitions of the new LLVM/Offload API entry points. See
10 // new-api/API/README.md for more information.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OffloadImpl.hpp"
15 #include "Helpers.hpp"
16 #include "PluginManager.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include <OffloadAPI.h>
19 
20 #include <mutex>
21 
22 using namespace llvm;
23 using namespace llvm::omp::target::plugin;
24 
25 // Handle type definitions. Ideally these would be 1:1 with the plugins
26 struct ol_device_handle_t_ {
27   int DeviceNum;
28   GenericDeviceTy &Device;
29   ol_platform_handle_t Platform;
30 };
31 
32 struct ol_platform_handle_t_ {
33   std::unique_ptr<GenericPluginTy> Plugin;
34   std::vector<ol_device_handle_t_> Devices;
35 };
36 
37 using PlatformVecT = SmallVector<ol_platform_handle_t_, 4>;
38 PlatformVecT &Platforms() {
39   static PlatformVecT Platforms;
40   return Platforms;
41 }
42 
// TODO: Some plugins expect to be linked into libomptarget, which defines these
// symbols to implement OMPT callbacks. The least invasive workaround is to
// define them in libLLVMOffload as false/null so they are never used. In the
// future it would be better to allow the plugins to implement callbacks
// without pulling in details from libomptarget.
#ifdef OMPT_SUPPORT
namespace llvm::omp::target::ompt {
bool Initialized = false;
ompt_get_callback_t lookupCallbackByCode = nullptr;
ompt_function_lookup_t lookupCallbackByName = nullptr;
} // namespace llvm::omp::target::ompt
#endif
57 
58 // Every plugin exports this method to create an instance of the plugin type.
59 #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
60 #include "Shared/Targets.def"
61 
62 void initPlugins() {
63   // Attempt to create an instance of each supported plugin.
64 #define PLUGIN_TARGET(Name)                                                    \
65   do {                                                                         \
66     Platforms().emplace_back(ol_platform_handle_t_{                            \
67         std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), {}});         \
68   } while (false);
69 #include "Shared/Targets.def"
70 
71   // Preemptively initialize all devices in the plugin so we can just return
72   // them from deviceGet
73   for (auto &Platform : Platforms()) {
74     auto Err = Platform.Plugin->init();
75     [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
76     for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
77          DevNum++) {
78       if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
79         Platform.Devices.emplace_back(ol_device_handle_t_{
80             DevNum, Platform.Plugin->getDevice(DevNum), &Platform});
81       }
82     }
83   }
84 
85   offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE");
86 }
87 
88 // TODO: We can properly reference count here and manage the resources in a more
89 // clever way
90 ol_impl_result_t olInit_impl() {
91   static std::once_flag InitFlag;
92   std::call_once(InitFlag, initPlugins);
93 
94   return OL_SUCCESS;
95 }
96 ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; }
97 
98 ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) {
99   *NumPlatforms = Platforms().size();
100   return OL_SUCCESS;
101 }
102 
103 ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries,
104                                     ol_platform_handle_t *PlatformsOut) {
105   if (NumEntries > Platforms().size()) {
106     return {OL_ERRC_INVALID_SIZE,
107             std::string{formatv("{0} platform(s) available but {1} requested.",
108                                 Platforms().size(), NumEntries)}};
109   }
110 
111   for (uint32_t PlatformIndex = 0; PlatformIndex < NumEntries;
112        PlatformIndex++) {
113     PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex];
114   }
115 
116   return OL_SUCCESS;
117 }
118 
119 ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
120                                              ol_platform_info_t PropName,
121                                              size_t PropSize, void *PropValue,
122                                              size_t *PropSizeRet) {
123   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
124 
125   switch (PropName) {
126   case OL_PLATFORM_INFO_NAME:
127     return ReturnValue(Platform->Plugin->getName());
128   case OL_PLATFORM_INFO_VENDOR_NAME:
129     // TODO: Implement this
130     return ReturnValue("Unknown platform vendor");
131   case OL_PLATFORM_INFO_VERSION: {
132     return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR,
133                                OL_VERSION_MINOR, OL_VERSION_PATCH)
134                            .str()
135                            .c_str());
136   }
137   case OL_PLATFORM_INFO_BACKEND: {
138     auto PluginName = Platform->Plugin->getName();
139     if (PluginName == StringRef("CUDA")) {
140       return ReturnValue(OL_PLATFORM_BACKEND_CUDA);
141     } else if (PluginName == StringRef("AMDGPU")) {
142       return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU);
143     } else {
144       return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN);
145     }
146   }
147   default:
148     return OL_ERRC_INVALID_ENUMERATION;
149   }
150 
151   return OL_SUCCESS;
152 }
153 
154 ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform,
155                                         ol_platform_info_t PropName,
156                                         size_t PropSize, void *PropValue) {
157   return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue,
158                                      nullptr);
159 }
160 
161 ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
162                                             ol_platform_info_t PropName,
163                                             size_t *PropSizeRet) {
164   return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr,
165                                      PropSizeRet);
166 }
167 
168 ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform,
169                                        uint32_t *pNumDevices) {
170   *pNumDevices = static_cast<uint32_t>(Platform->Devices.size());
171 
172   return OL_SUCCESS;
173 }
174 
175 ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform,
176                                   uint32_t NumEntries,
177                                   ol_device_handle_t *Devices) {
178   if (NumEntries > Platform->Devices.size()) {
179     return OL_ERRC_INVALID_SIZE;
180   }
181 
182   for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) {
183     Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]);
184   }
185 
186   return OL_SUCCESS;
187 }
188 
189 ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device,
190                                            ol_device_info_t PropName,
191                                            size_t PropSize, void *PropValue,
192                                            size_t *PropSizeRet) {
193 
194   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
195 
196   InfoQueueTy DevInfo;
197   if (auto Err = Device->Device.obtainInfoImpl(DevInfo))
198     return OL_ERRC_OUT_OF_RESOURCES;
199 
200   // Find the info if it exists under any of the given names
201   auto GetInfo = [&DevInfo](std::vector<std::string> Names) {
202     for (auto Name : Names) {
203       auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
204         return Info.Key == Name;
205       };
206       auto Item = std::find_if(DevInfo.getQueue().begin(),
207                                DevInfo.getQueue().end(), InfoKeyMatches);
208 
209       if (Item != std::end(DevInfo.getQueue())) {
210         return Item->Value;
211       }
212     }
213 
214     return std::string("");
215   };
216 
217   switch (PropName) {
218   case OL_DEVICE_INFO_PLATFORM:
219     return ReturnValue(Device->Platform);
220   case OL_DEVICE_INFO_TYPE:
221     return ReturnValue(OL_DEVICE_TYPE_GPU);
222   case OL_DEVICE_INFO_NAME:
223     return ReturnValue(GetInfo({"Device Name"}).c_str());
224   case OL_DEVICE_INFO_VENDOR:
225     return ReturnValue(GetInfo({"Vendor Name"}).c_str());
226   case OL_DEVICE_INFO_DRIVER_VERSION:
227     return ReturnValue(
228         GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
229   default:
230     return OL_ERRC_INVALID_ENUMERATION;
231   }
232 
233   return OL_SUCCESS;
234 }
235 
236 ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device,
237                                       ol_device_info_t PropName,
238                                       size_t PropSize, void *PropValue) {
239   return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
240                                    nullptr);
241 }
242 
243 ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device,
244                                           ol_device_info_t PropName,
245                                           size_t *PropSizeRet) {
246   return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
247 }
248