1*fd3907ccSCallum Fare //===- ol_impl.cpp - Implementation of the new LLVM/Offload API ------===// 2*fd3907ccSCallum Fare // 3*fd3907ccSCallum Fare // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*fd3907ccSCallum Fare // See https://llvm.org/LICENSE.txt for license information. 5*fd3907ccSCallum Fare // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*fd3907ccSCallum Fare // 7*fd3907ccSCallum Fare //===----------------------------------------------------------------------===// 8*fd3907ccSCallum Fare // 9*fd3907ccSCallum Fare // This contains the definitions of the new LLVM/Offload API entry points. See 10*fd3907ccSCallum Fare // new-api/API/README.md for more information. 11*fd3907ccSCallum Fare // 12*fd3907ccSCallum Fare //===----------------------------------------------------------------------===// 13*fd3907ccSCallum Fare 14*fd3907ccSCallum Fare #include "OffloadImpl.hpp" 15*fd3907ccSCallum Fare #include "Helpers.hpp" 16*fd3907ccSCallum Fare #include "PluginManager.h" 17*fd3907ccSCallum Fare #include "llvm/Support/FormatVariadic.h" 18*fd3907ccSCallum Fare #include <OffloadAPI.h> 19*fd3907ccSCallum Fare 20*fd3907ccSCallum Fare #include <mutex> 21*fd3907ccSCallum Fare 22*fd3907ccSCallum Fare using namespace llvm; 23*fd3907ccSCallum Fare using namespace llvm::omp::target::plugin; 24*fd3907ccSCallum Fare 25*fd3907ccSCallum Fare // Handle type definitions. Ideally these would be 1:1 with the plugins 26*fd3907ccSCallum Fare struct ol_device_handle_t_ { 27*fd3907ccSCallum Fare int DeviceNum; 28*fd3907ccSCallum Fare GenericDeviceTy &Device; 29*fd3907ccSCallum Fare ol_platform_handle_t Platform; 30*fd3907ccSCallum Fare }; 31*fd3907ccSCallum Fare 32*fd3907ccSCallum Fare struct ol_platform_handle_t_ { 33*fd3907ccSCallum Fare std::unique_ptr<GenericPluginTy> Plugin; 34*fd3907ccSCallum Fare std::vector<ol_device_handle_t_> Devices; 35*fd3907ccSCallum Fare }; 36*fd3907ccSCallum Fare 37*fd3907ccSCallum Fare using PlatformVecT = SmallVector<ol_platform_handle_t_, 4>; 38*fd3907ccSCallum Fare PlatformVecT &Platforms() { 39*fd3907ccSCallum Fare static PlatformVecT Platforms; 40*fd3907ccSCallum Fare return Platforms; 41*fd3907ccSCallum Fare } 42*fd3907ccSCallum Fare 43*fd3907ccSCallum Fare // TODO: Some plugins expect to be linked into libomptarget which defines these 44*fd3907ccSCallum Fare // symbols to implement ompt callbacks. The least invasive workaround here is to 45*fd3907ccSCallum Fare // define them in libLLVMOffload as false/null so they are never used. In future 46*fd3907ccSCallum Fare // it would be better to allow the plugins to implement callbacks without 47*fd3907ccSCallum Fare // pulling in details from libomptarget. 48*fd3907ccSCallum Fare #ifdef OMPT_SUPPORT 49*fd3907ccSCallum Fare namespace llvm::omp::target { 50*fd3907ccSCallum Fare namespace ompt { 51*fd3907ccSCallum Fare bool Initialized = false; 52*fd3907ccSCallum Fare ompt_get_callback_t lookupCallbackByCode = nullptr; 53*fd3907ccSCallum Fare ompt_function_lookup_t lookupCallbackByName = nullptr; 54*fd3907ccSCallum Fare } // namespace ompt 55*fd3907ccSCallum Fare } // namespace llvm::omp::target 56*fd3907ccSCallum Fare #endif 57*fd3907ccSCallum Fare 58*fd3907ccSCallum Fare // Every plugin exports this method to create an instance of the plugin type. 59*fd3907ccSCallum Fare #define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name(); 60*fd3907ccSCallum Fare #include "Shared/Targets.def" 61*fd3907ccSCallum Fare 62*fd3907ccSCallum Fare void initPlugins() { 63*fd3907ccSCallum Fare // Attempt to create an instance of each supported plugin. 64*fd3907ccSCallum Fare #define PLUGIN_TARGET(Name) \ 65*fd3907ccSCallum Fare do { \ 66*fd3907ccSCallum Fare Platforms().emplace_back(ol_platform_handle_t_{ \ 67*fd3907ccSCallum Fare std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), {}}); \ 68*fd3907ccSCallum Fare } while (false); 69*fd3907ccSCallum Fare #include "Shared/Targets.def" 70*fd3907ccSCallum Fare 71*fd3907ccSCallum Fare // Preemptively initialize all devices in the plugin so we can just return 72*fd3907ccSCallum Fare // them from deviceGet 73*fd3907ccSCallum Fare for (auto &Platform : Platforms()) { 74*fd3907ccSCallum Fare auto Err = Platform.Plugin->init(); 75*fd3907ccSCallum Fare [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); 76*fd3907ccSCallum Fare for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); 77*fd3907ccSCallum Fare DevNum++) { 78*fd3907ccSCallum Fare if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { 79*fd3907ccSCallum Fare Platform.Devices.emplace_back(ol_device_handle_t_{ 80*fd3907ccSCallum Fare DevNum, Platform.Plugin->getDevice(DevNum), &Platform}); 81*fd3907ccSCallum Fare } 82*fd3907ccSCallum Fare } 83*fd3907ccSCallum Fare } 84*fd3907ccSCallum Fare 85*fd3907ccSCallum Fare offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE"); 86*fd3907ccSCallum Fare } 87*fd3907ccSCallum Fare 88*fd3907ccSCallum Fare // TODO: We can properly reference count here and manage the resources in a more 89*fd3907ccSCallum Fare // clever way 90*fd3907ccSCallum Fare ol_impl_result_t olInit_impl() { 91*fd3907ccSCallum Fare static std::once_flag InitFlag; 92*fd3907ccSCallum Fare std::call_once(InitFlag, initPlugins); 93*fd3907ccSCallum Fare 94*fd3907ccSCallum Fare return OL_SUCCESS; 95*fd3907ccSCallum Fare } 96*fd3907ccSCallum Fare ol_impl_result_t olShutDown_impl() { return OL_SUCCESS; } 97*fd3907ccSCallum Fare 98*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformCount_impl(uint32_t *NumPlatforms) { 99*fd3907ccSCallum Fare *NumPlatforms = Platforms().size(); 100*fd3907ccSCallum Fare return OL_SUCCESS; 101*fd3907ccSCallum Fare } 102*fd3907ccSCallum Fare 103*fd3907ccSCallum Fare ol_impl_result_t olGetPlatform_impl(uint32_t NumEntries, 104*fd3907ccSCallum Fare ol_platform_handle_t *PlatformsOut) { 105*fd3907ccSCallum Fare if (NumEntries > Platforms().size()) { 106*fd3907ccSCallum Fare return {OL_ERRC_INVALID_SIZE, 107*fd3907ccSCallum Fare std::string{formatv("{0} platform(s) available but {1} requested.", 108*fd3907ccSCallum Fare Platforms().size(), NumEntries)}}; 109*fd3907ccSCallum Fare } 110*fd3907ccSCallum Fare 111*fd3907ccSCallum Fare for (uint32_t PlatformIndex = 0; PlatformIndex < NumEntries; 112*fd3907ccSCallum Fare PlatformIndex++) { 113*fd3907ccSCallum Fare PlatformsOut[PlatformIndex] = &(Platforms())[PlatformIndex]; 114*fd3907ccSCallum Fare } 115*fd3907ccSCallum Fare 116*fd3907ccSCallum Fare return OL_SUCCESS; 117*fd3907ccSCallum Fare } 118*fd3907ccSCallum Fare 119*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, 120*fd3907ccSCallum Fare ol_platform_info_t PropName, 121*fd3907ccSCallum Fare size_t PropSize, void *PropValue, 122*fd3907ccSCallum Fare size_t *PropSizeRet) { 123*fd3907ccSCallum Fare ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); 124*fd3907ccSCallum Fare 125*fd3907ccSCallum Fare switch (PropName) { 126*fd3907ccSCallum Fare case OL_PLATFORM_INFO_NAME: 127*fd3907ccSCallum Fare return ReturnValue(Platform->Plugin->getName()); 128*fd3907ccSCallum Fare case OL_PLATFORM_INFO_VENDOR_NAME: 129*fd3907ccSCallum Fare // TODO: Implement this 130*fd3907ccSCallum Fare return ReturnValue("Unknown platform vendor"); 131*fd3907ccSCallum Fare case OL_PLATFORM_INFO_VERSION: { 132*fd3907ccSCallum Fare return ReturnValue(formatv("v{0}.{1}.{2}", OL_VERSION_MAJOR, 133*fd3907ccSCallum Fare OL_VERSION_MINOR, OL_VERSION_PATCH) 134*fd3907ccSCallum Fare .str() 135*fd3907ccSCallum Fare .c_str()); 136*fd3907ccSCallum Fare } 137*fd3907ccSCallum Fare case OL_PLATFORM_INFO_BACKEND: { 138*fd3907ccSCallum Fare auto PluginName = Platform->Plugin->getName(); 139*fd3907ccSCallum Fare if (PluginName == StringRef("CUDA")) { 140*fd3907ccSCallum Fare return ReturnValue(OL_PLATFORM_BACKEND_CUDA); 141*fd3907ccSCallum Fare } else if (PluginName == StringRef("AMDGPU")) { 142*fd3907ccSCallum Fare return ReturnValue(OL_PLATFORM_BACKEND_AMDGPU); 143*fd3907ccSCallum Fare } else { 144*fd3907ccSCallum Fare return ReturnValue(OL_PLATFORM_BACKEND_UNKNOWN); 145*fd3907ccSCallum Fare } 146*fd3907ccSCallum Fare } 147*fd3907ccSCallum Fare default: 148*fd3907ccSCallum Fare return OL_ERRC_INVALID_ENUMERATION; 149*fd3907ccSCallum Fare } 150*fd3907ccSCallum Fare 151*fd3907ccSCallum Fare return OL_SUCCESS; 152*fd3907ccSCallum Fare } 153*fd3907ccSCallum Fare 154*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfo_impl(ol_platform_handle_t Platform, 155*fd3907ccSCallum Fare ol_platform_info_t PropName, 156*fd3907ccSCallum Fare size_t PropSize, void *PropValue) { 157*fd3907ccSCallum Fare return olGetPlatformInfoImplDetail(Platform, PropName, PropSize, PropValue, 158*fd3907ccSCallum Fare nullptr); 159*fd3907ccSCallum Fare } 160*fd3907ccSCallum Fare 161*fd3907ccSCallum Fare ol_impl_result_t olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, 162*fd3907ccSCallum Fare ol_platform_info_t PropName, 163*fd3907ccSCallum Fare size_t *PropSizeRet) { 164*fd3907ccSCallum Fare return olGetPlatformInfoImplDetail(Platform, PropName, 0, nullptr, 165*fd3907ccSCallum Fare PropSizeRet); 166*fd3907ccSCallum Fare } 167*fd3907ccSCallum Fare 168*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceCount_impl(ol_platform_handle_t Platform, 169*fd3907ccSCallum Fare uint32_t *pNumDevices) { 170*fd3907ccSCallum Fare *pNumDevices = static_cast<uint32_t>(Platform->Devices.size()); 171*fd3907ccSCallum Fare 172*fd3907ccSCallum Fare return OL_SUCCESS; 173*fd3907ccSCallum Fare } 174*fd3907ccSCallum Fare 175*fd3907ccSCallum Fare ol_impl_result_t olGetDevice_impl(ol_platform_handle_t Platform, 176*fd3907ccSCallum Fare uint32_t NumEntries, 177*fd3907ccSCallum Fare ol_device_handle_t *Devices) { 178*fd3907ccSCallum Fare if (NumEntries > Platform->Devices.size()) { 179*fd3907ccSCallum Fare return OL_ERRC_INVALID_SIZE; 180*fd3907ccSCallum Fare } 181*fd3907ccSCallum Fare 182*fd3907ccSCallum Fare for (uint32_t DeviceIndex = 0; DeviceIndex < NumEntries; DeviceIndex++) { 183*fd3907ccSCallum Fare Devices[DeviceIndex] = &(Platform->Devices[DeviceIndex]); 184*fd3907ccSCallum Fare } 185*fd3907ccSCallum Fare 186*fd3907ccSCallum Fare return OL_SUCCESS; 187*fd3907ccSCallum Fare } 188*fd3907ccSCallum Fare 189*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfoImplDetail(ol_device_handle_t Device, 190*fd3907ccSCallum Fare ol_device_info_t PropName, 191*fd3907ccSCallum Fare size_t PropSize, void *PropValue, 192*fd3907ccSCallum Fare size_t *PropSizeRet) { 193*fd3907ccSCallum Fare 194*fd3907ccSCallum Fare ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); 195*fd3907ccSCallum Fare 196*fd3907ccSCallum Fare InfoQueueTy DevInfo; 197*fd3907ccSCallum Fare if (auto Err = Device->Device.obtainInfoImpl(DevInfo)) 198*fd3907ccSCallum Fare return OL_ERRC_OUT_OF_RESOURCES; 199*fd3907ccSCallum Fare 200*fd3907ccSCallum Fare // Find the info if it exists under any of the given names 201*fd3907ccSCallum Fare auto GetInfo = [&DevInfo](std::vector<std::string> Names) { 202*fd3907ccSCallum Fare for (auto Name : Names) { 203*fd3907ccSCallum Fare auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) { 204*fd3907ccSCallum Fare return Info.Key == Name; 205*fd3907ccSCallum Fare }; 206*fd3907ccSCallum Fare auto Item = std::find_if(DevInfo.getQueue().begin(), 207*fd3907ccSCallum Fare DevInfo.getQueue().end(), InfoKeyMatches); 208*fd3907ccSCallum Fare 209*fd3907ccSCallum Fare if (Item != std::end(DevInfo.getQueue())) { 210*fd3907ccSCallum Fare return Item->Value; 211*fd3907ccSCallum Fare } 212*fd3907ccSCallum Fare } 213*fd3907ccSCallum Fare 214*fd3907ccSCallum Fare return std::string(""); 215*fd3907ccSCallum Fare }; 216*fd3907ccSCallum Fare 217*fd3907ccSCallum Fare switch (PropName) { 218*fd3907ccSCallum Fare case OL_DEVICE_INFO_PLATFORM: 219*fd3907ccSCallum Fare return ReturnValue(Device->Platform); 220*fd3907ccSCallum Fare case OL_DEVICE_INFO_TYPE: 221*fd3907ccSCallum Fare return ReturnValue(OL_DEVICE_TYPE_GPU); 222*fd3907ccSCallum Fare case OL_DEVICE_INFO_NAME: 223*fd3907ccSCallum Fare return ReturnValue(GetInfo({"Device Name"}).c_str()); 224*fd3907ccSCallum Fare case OL_DEVICE_INFO_VENDOR: 225*fd3907ccSCallum Fare return ReturnValue(GetInfo({"Vendor Name"}).c_str()); 226*fd3907ccSCallum Fare case OL_DEVICE_INFO_DRIVER_VERSION: 227*fd3907ccSCallum Fare return ReturnValue( 228*fd3907ccSCallum Fare GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str()); 229*fd3907ccSCallum Fare default: 230*fd3907ccSCallum Fare return OL_ERRC_INVALID_ENUMERATION; 231*fd3907ccSCallum Fare } 232*fd3907ccSCallum Fare 233*fd3907ccSCallum Fare return OL_SUCCESS; 234*fd3907ccSCallum Fare } 235*fd3907ccSCallum Fare 236*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfo_impl(ol_device_handle_t Device, 237*fd3907ccSCallum Fare ol_device_info_t PropName, 238*fd3907ccSCallum Fare size_t PropSize, void *PropValue) { 239*fd3907ccSCallum Fare return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, 240*fd3907ccSCallum Fare nullptr); 241*fd3907ccSCallum Fare } 242*fd3907ccSCallum Fare 243*fd3907ccSCallum Fare ol_impl_result_t olGetDeviceInfoSize_impl(ol_device_handle_t Device, 244*fd3907ccSCallum Fare ol_device_info_t PropName, 245*fd3907ccSCallum Fare size_t *PropSizeRet) { 246*fd3907ccSCallum Fare return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); 247*fd3907ccSCallum Fare } 248