1 //===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This contains the definitions of the new LLVM/Offload API entry points. See 10 // new-api/API/README.md for more information. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "OffloadImpl.hpp" 15 #include "Helpers.hpp" 16 #include "PluginManager.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include <OffloadAPI.h> 19 20 #include <mutex> 21 22 using namespace llvm; 23 using namespace llvm::omp::target::plugin; 24 25 // Handle type definitions. Ideally these would be 1:1 with the plugins 26 struct ol_device_handle_t_ { 27 int DeviceNum; 28 GenericDeviceTy &Device; 29 ol_platform_handle_t Platform; 30 }; 31 32 struct ol_platform_handle_t_ { 33 std::unique_ptr<GenericPluginTy> Plugin; 34 std::vector<ol_device_handle_t_> Devices; 35 }; 36 37 using PlatformVecT = SmallVector<ol_platform_handle_t_, 4>; 38 PlatformVecT &Platforms() { 39 static PlatformVecT Platforms; 40 return Platforms; 41 } 42 43 // TODO: Some plugins expect to be linked into libomptarget which defines these 44 // symbols to implement ompt callbacks. The least invasive workaround here is to 45 // define them in libLLVMOffload as false/null so they are never used. In future 46 // it would be better to allow the plugins to implement callbacks without 47 // pulling in details from libomptarget. 48 #ifdef OMPT_SUPPORT 49 namespace llvm::omp::target { 50 namespace ompt { 51 bool Initialized = false; 52 ompt_get_callback_t lookupCallbackByCode = nullptr; 53 ompt_function_lookup_t lookupCallbackByName = nullptr; 54 } // namespace ompt 55 } // namespace llvm::omp::target 56 #endif 57 58 // Every plugin exports this method to create an instance of the plugin type. 59 #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); 60 #include "Shared/Targets.def" 61 62 void initPlugins() { 63 // Attempt to create an instance of each supported plugin. 64 #define PLUGIN_TARGET(Name) \ 65 do { \ 66 Platforms().emplace_back(ol_platform_handle_t_{ \ 67 std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), {}}); \ 68 } while (false); 69 #include "Shared/Targets.def" 70 71 // Preemptively initialize all devices in the plugin so we can just return 72 // them from deviceGet 73 for (auto &Platform : Platforms()) { 74 auto Err = Platform.Plugin->init(); 75 [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); 76 for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); 77 DevNum++) { 78 if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { 79 Platform.Devices.emplace_back(ol_device_handle_t_{ 80 DevNum, Platform.Plugin->getDevice(DevNum), &Platform}); 81 } 82 } 83 } 84 85 offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE"); 86 } 87 88 // TODO: We can properly reference count here and manage the resources in a more 89 // clever way 90 ol_impl_result_t olInit_impl() { 91 static std::once_flag InitFlag; 92 std::call_once(InitFlag, initPlugins); 93 94 return OL_SUCCESS; 95 } 96 ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; } 97 98 ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) { 99 *NumPlatforms = Platforms().size(); 100 return OL_SUCCESS; 101 } 102 103 ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries, 104 ol_platform_handle_t *PlatformsOut) { 105 if (NumEntries > Platforms().size()) { 106 return {OL_ERRC_INVALID_SIZE, 107 std::string{formatv("{0} platform(s) available but {1} requested.", 108 Platforms().size(), NumEntries)}}; 109 } 110 111 for (uint32_t PlatformIndex = 0; PlatformIndex < NumEntries; 112 PlatformIndex++) { 113 PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex]; 114 } 115 116 return OL_SUCCESS; 117 } 118 119 ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, 120 ol_platform_info_t PropName, 121 size_t PropSize, void *PropValue, 122 size_t *PropSizeRet) { 123 ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); 124 125 switch (PropName) { 126 case OL_PLATFORM_INFO_NAME: 127 return ReturnValue(Platform->Plugin->getName()); 128 case OL_PLATFORM_INFO_VENDOR_NAME: 129 // TODO: Implement this 130 return ReturnValue("Unknown platform vendor"); 131 case OL_PLATFORM_INFO_VERSION: { 132 return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, 133 OL_VERSION_MINOR, OL_VERSION_PATCH) 134 .str() 135 .c_str()); 136 } 137 case OL_PLATFORM_INFO_BACKEND: { 138 auto PluginName = Platform->Plugin->getName(); 139 if (PluginName == StringRef("CUDA")) { 140 return ReturnValue(OL_PLATFORM_BACKEND_CUDA); 141 } else if (PluginName == StringRef("AMDGPU")) { 142 return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU); 143 } else { 144 return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN); 145 } 146 } 147 default: 148 return OL_ERRC_INVALID_ENUMERATION; 149 } 150 151 return OL_SUCCESS; 152 } 153 154 ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform, 155 ol_platform_info_t PropName, 156 size_t PropSize, void *PropValue) { 157 return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue, 158 nullptr); 159 } 160 161 ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, 162 ol_platform_info_t PropName, 163 size_t *PropSizeRet) { 164 return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr, 165 PropSizeRet); 166 } 167 168 ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform, 169 uint32_t *pNumDevices) { 170 *pNumDevices = static_cast<uint32_t>(Platform->Devices.size()); 171 172 return OL_SUCCESS; 173 } 174 175 ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform, 176 uint32_t NumEntries, 177 ol_device_handle_t *Devices) { 178 if (NumEntries > Platform->Devices.size()) { 179 return OL_ERRC_INVALID_SIZE; 180 } 181 182 for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) { 183 Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]); 184 } 185 186 return OL_SUCCESS; 187 } 188 189 ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device, 190 ol_device_info_t PropName, 191 size_t PropSize, void *PropValue, 192 size_t *PropSizeRet) { 193 194 ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); 195 196 InfoQueueTy DevInfo; 197 if (auto Err = Device->Device.obtainInfoImpl(DevInfo)) 198 return OL_ERRC_OUT_OF_RESOURCES; 199 200 // Find the info if it exists under any of the given names 201 auto GetInfo = [&DevInfo](std::vector<std::string> Names) { 202 for (auto Name : Names) { 203 auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) { 204 return Info.Key == Name; 205 }; 206 auto Item = std::find_if(DevInfo.getQueue().begin(), 207 DevInfo.getQueue().end(), InfoKeyMatches); 208 209 if (Item != std::end(DevInfo.getQueue())) { 210 return Item->Value; 211 } 212 } 213 214 return std::string(""); 215 }; 216 217 switch (PropName) { 218 case OL_DEVICE_INFO_PLATFORM: 219 return ReturnValue(Device->Platform); 220 case OL_DEVICE_INFO_TYPE: 221 return ReturnValue(OL_DEVICE_TYPE_GPU); 222 case OL_DEVICE_INFO_NAME: 223 return ReturnValue(GetInfo({"Device Name"}).c_str()); 224 case OL_DEVICE_INFO_VENDOR: 225 return ReturnValue(GetInfo({"Vendor Name"}).c_str()); 226 case OL_DEVICE_INFO_DRIVER_VERSION: 227 return ReturnValue( 228 GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str()); 229 default: 230 return OL_ERRC_INVALID_ENUMERATION; 231 } 232 233 return OL_SUCCESS; 234 } 235 236 ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device, 237 ol_device_info_t PropName, 238 size_t PropSize, void *PropValue) { 239 return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, 240 nullptr); 241 } 242 243 ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device, 244 ol_device_info_t PropName, 245 size_t *PropSizeRet) { 246 return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); 247 } 248