xref: /llvm-project/offload/src/PluginManager.cpp (revision 13dcc95dcd4999ff99f2de89d881f1aed5b21709)
1 //===-- PluginManager.cpp - Plugin loading and communication API ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Functionality for handling plugins.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "PluginManager.h"
#include "Shared/Debug.h"
#include "Shared/Profile.h"
#include "device.h"

#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <mutex>
21 
22 using namespace llvm;
23 using namespace llvm::sys;
24 
/// Global plugin manager used by the functions in this file (which access it
/// as PM->...). It starts out null. NOTE(review): presumably it is allocated and
/// assigned during runtime start-up elsewhere; confirm.
PluginManager *PM = nullptr;
26 
// Every plugin exports this method to create an instance of the plugin type.
// Shared/Targets.def is expected to expand PLUGIN_TARGET once per enabled
// target, declaring its `createPlugin_<Name>` factory. NOTE(review): the macro
// is redefined in PluginManager::init() without an #undef here, so
// Targets.def presumably #undefs it after use; confirm.
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"
30 
31 void PluginManager::init() {
32   TIMESCOPE();
33   DP("Loading RTLs...\n");
34 
35   // Attempt to create an instance of each supported plugin.
36 #define PLUGIN_TARGET(Name)                                                    \
37   do {                                                                         \
38     Plugins.emplace_back(                                                      \
39         std::unique_ptr<GenericPluginTy>(createPlugin_##Name()));              \
40   } while (false);
41 #include "Shared/Targets.def"
42 
43   DP("RTLs loaded!\n");
44 }
45 
46 void PluginManager::deinit() {
47   TIMESCOPE();
48   DP("Unloading RTLs...\n");
49 
50   for (auto &Plugin : Plugins) {
51     if (!Plugin->is_initialized())
52       continue;
53 
54     if (auto Err = Plugin->deinit()) {
55       [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
56       DP("Failed to deinit plugin: %s\n", InfoMsg.c_str());
57     }
58     Plugin.release();
59   }
60 
61   DP("RTLs unloaded!\n");
62 }
63 
64 bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {
65   if (Plugin.is_initialized())
66     return true;
67 
68   if (auto Err = Plugin.init()) {
69     [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
70     DP("Failed to init plugin: %s\n", InfoMsg.c_str());
71     return false;
72   }
73 
74   DP("Registered plugin %s with %d visible device(s)\n", Plugin.getName(),
75      Plugin.number_of_devices());
76   return true;
77 }
78 
79 bool PluginManager::initializeDevice(GenericPluginTy &Plugin,
80                                      int32_t DeviceId) {
81   if (Plugin.is_device_initialized(DeviceId)) {
82     auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
83     (*ExclusiveDevicesAccessor)[PM->DeviceIds[std::make_pair(&Plugin,
84                                                              DeviceId)]]
85         ->setHasPendingImages(true);
86     return true;
87   }
88 
89   // Initialize the device information for the RTL we are about to use.
90   auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
91 
92   int32_t UserId = ExclusiveDevicesAccessor->size();
93 
94   // Set the device identifier offset in the plugin.
95 #ifdef OMPT_SUPPORT
96   Plugin.set_device_identifier(UserId, DeviceId);
97 #endif
98 
99   auto Device = std::make_unique<DeviceTy>(&Plugin, UserId, DeviceId);
100   if (auto Err = Device->init()) {
101     [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
102     DP("Failed to init device %d: %s\n", DeviceId, InfoMsg.c_str());
103     return false;
104   }
105 
106   ExclusiveDevicesAccessor->push_back(std::move(Device));
107 
108   // We need to map between the plugin's device identifier and the one
109   // that OpenMP will use.
110   PM->DeviceIds[std::make_pair(&Plugin, DeviceId)] = UserId;
111 
112   return true;
113 }
114 
115 void PluginManager::initializeAllDevices() {
116   for (auto &Plugin : plugins()) {
117     if (!initializePlugin(Plugin))
118       continue;
119 
120     for (int32_t DeviceId = 0; DeviceId < Plugin.number_of_devices();
121          ++DeviceId) {
122       initializeDevice(Plugin, DeviceId);
123     }
124   }
125 }
126 
127 // Returns a pointer to the binary descriptor, upgrading from a legacy format if
128 // necessary.
129 __tgt_bin_desc *PluginManager::upgradeLegacyEntries(__tgt_bin_desc *Desc) {
130   struct LegacyEntryTy {
131     void *Address;
132     char *SymbolName;
133     size_t Size;
134     int32_t Flags;
135     int32_t Data;
136   };
137 
138   if (UpgradedDescriptors.contains(Desc))
139     return &UpgradedDescriptors[Desc];
140 
141   if (Desc->HostEntriesBegin == Desc->HostEntriesEnd ||
142       Desc->HostEntriesBegin->Reserved == 0)
143     return Desc;
144 
145   // The new format mandates that each entry starts with eight bytes of zeroes.
146   // This allows us to detect the old format as this is a null pointer.
147   llvm::SmallVector<llvm::offloading::EntryTy, 0> &NewEntries =
148       LegacyEntries.emplace_back();
149   for (LegacyEntryTy &Entry : llvm::make_range(
150            reinterpret_cast<LegacyEntryTy *>(Desc->HostEntriesBegin),
151            reinterpret_cast<LegacyEntryTy *>(Desc->HostEntriesEnd))) {
152     llvm::offloading::EntryTy &NewEntry = NewEntries.emplace_back();
153 
154     NewEntry.Address = Entry.Address;
155     NewEntry.Flags = Entry.Flags;
156     NewEntry.Data = Entry.Data;
157     NewEntry.Size = Entry.Size;
158     NewEntry.SymbolName = Entry.SymbolName;
159   }
160 
161   // Create a new image struct so we can update the entries list.
162   llvm::SmallVector<__tgt_device_image, 0> &NewImages =
163       LegacyImages.emplace_back();
164   for (int32_t Image = 0; Image < Desc->NumDeviceImages; ++Image)
165     NewImages.emplace_back(
166         __tgt_device_image{Desc->DeviceImages[Image].ImageStart,
167                            Desc->DeviceImages[Image].ImageEnd,
168                            NewEntries.begin(), NewEntries.end()});
169 
170   // Create the new binary descriptor containing the newly created memory.
171   __tgt_bin_desc &NewDesc = UpgradedDescriptors[Desc];
172   NewDesc.DeviceImages = NewImages.begin();
173   NewDesc.NumDeviceImages = Desc->NumDeviceImages;
174   NewDesc.HostEntriesBegin = NewEntries.begin();
175   NewDesc.HostEntriesEnd = NewEntries.end();
176 
177   return &NewDesc;
178 }
179 
180 void PluginManager::registerLib(__tgt_bin_desc *Desc) {
181   PM->RTLsMtx.lock();
182 
183   // Upgrade the entries from the legacy implementation if necessary.
184   Desc = upgradeLegacyEntries(Desc);
185 
186   // Add in all the OpenMP requirements associated with this binary.
187   for (llvm::offloading::EntryTy &Entry :
188        llvm::make_range(Desc->HostEntriesBegin, Desc->HostEntriesEnd))
189     if (Entry.Flags == OMP_REGISTER_REQUIRES)
190       PM->addRequirements(Entry.Data);
191 
192   // Extract the exectuable image and extra information if availible.
193   for (int32_t i = 0; i < Desc->NumDeviceImages; ++i)
194     PM->addDeviceImage(*Desc, Desc->DeviceImages[i]);
195 
196   // Register the images with the RTLs that understand them, if any.
197   for (DeviceImageTy &DI : PM->deviceImages()) {
198     // Obtain the image and information that was previously extracted.
199     __tgt_device_image *Img = &DI.getExecutableImage();
200 
201     GenericPluginTy *FoundRTL = nullptr;
202 
203     // Scan the RTLs that have associated images until we find one that supports
204     // the current image.
205     for (auto &R : plugins()) {
206       if (!R.is_plugin_compatible(Img))
207         continue;
208 
209       if (!initializePlugin(R))
210         continue;
211 
212       if (!R.number_of_devices()) {
213         DP("Skipping plugin %s with no visible devices\n", R.getName());
214         continue;
215       }
216 
217       for (int32_t DeviceId = 0; DeviceId < R.number_of_devices(); ++DeviceId) {
218         if (!R.is_device_compatible(DeviceId, Img))
219           continue;
220 
221         DP("Image " DPxMOD " is compatible with RTL %s device %d!\n",
222            DPxPTR(Img->ImageStart), R.getName(), DeviceId);
223 
224         if (!initializeDevice(R, DeviceId))
225           continue;
226 
227         // Initialize (if necessary) translation table for this library.
228         PM->TrlTblMtx.lock();
229         if (!PM->HostEntriesBeginToTransTable.count(Desc->HostEntriesBegin)) {
230           PM->HostEntriesBeginRegistrationOrder.push_back(
231               Desc->HostEntriesBegin);
232           TranslationTable &TT =
233               (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
234           TT.HostTable.EntriesBegin = Desc->HostEntriesBegin;
235           TT.HostTable.EntriesEnd = Desc->HostEntriesEnd;
236         }
237 
238         // Retrieve translation table for this library.
239         TranslationTable &TT =
240             (PM->HostEntriesBeginToTransTable)[Desc->HostEntriesBegin];
241 
242         DP("Registering image " DPxMOD " with RTL %s!\n",
243            DPxPTR(Img->ImageStart), R.getName());
244 
245         auto UserId = PM->DeviceIds[std::make_pair(&R, DeviceId)];
246         if (TT.TargetsTable.size() < static_cast<size_t>(UserId + 1)) {
247           TT.DeviceTables.resize(UserId + 1, {});
248           TT.TargetsImages.resize(UserId + 1, nullptr);
249           TT.TargetsEntries.resize(UserId + 1, {});
250           TT.TargetsTable.resize(UserId + 1, nullptr);
251         }
252 
253         // Register the image for this target type and invalidate the table.
254         TT.TargetsImages[UserId] = Img;
255         TT.TargetsTable[UserId] = nullptr;
256 
257         PM->UsedImages.insert(Img);
258         FoundRTL = &R;
259 
260         PM->TrlTblMtx.unlock();
261       }
262     }
263     if (!FoundRTL)
264       DP("No RTL found for image " DPxMOD "!\n", DPxPTR(Img->ImageStart));
265   }
266   PM->RTLsMtx.unlock();
267 
268   bool UseAutoZeroCopy = Plugins.size() > 0;
269 
270   auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
271   for (const auto &Device : *ExclusiveDevicesAccessor)
272     UseAutoZeroCopy &= Device->useAutoZeroCopy();
273 
274   // Auto Zero-Copy can only be currently triggered when the system is an
275   // homogeneous APU architecture without attached discrete GPUs.
276   // If all devices suggest to use it, change requirment flags to trigger
277   // zero-copy behavior when mapping memory.
278   if (UseAutoZeroCopy)
279     addRequirements(OMPX_REQ_AUTO_ZERO_COPY);
280 
281   DP("Done registering entries!\n");
282 }
283 
// Temporary forward declaration, old style CTor/DTor handling is going away.
// NOTE(review): `target` is not referenced anywhere in the visible part of this
// file; it may be removable once the old CTor/DTor path is gone.
int target(ident_t *Loc, DeviceTy &Device, void *HostPtr,
           KernelArgsTy &KernelArgs, AsyncInfoTy &AsyncInfo);
287 
288 void PluginManager::unregisterLib(__tgt_bin_desc *Desc) {
289   DP("Unloading target library!\n");
290 
291   Desc = upgradeLegacyEntries(Desc);
292 
293   PM->RTLsMtx.lock();
294   // Find which RTL understands each image, if any.
295   for (DeviceImageTy &DI : PM->deviceImages()) {
296     // Obtain the image and information that was previously extracted.
297     __tgt_device_image *Img = &DI.getExecutableImage();
298 
299     GenericPluginTy *FoundRTL = NULL;
300 
301     // Scan the RTLs that have associated images until we find one that supports
302     // the current image. We only need to scan RTLs that are already being used.
303     for (auto &R : plugins()) {
304       if (R.is_initialized())
305         continue;
306 
307       // Ensure that we do not use any unused images associated with this RTL.
308       if (!UsedImages.contains(Img))
309         continue;
310 
311       FoundRTL = &R;
312 
313       DP("Unregistered image " DPxMOD " from RTL\n", DPxPTR(Img->ImageStart));
314 
315       break;
316     }
317 
318     // if no RTL was found proceed to unregister the next image
319     if (!FoundRTL) {
320       DP("No RTLs in use support the image " DPxMOD "!\n",
321          DPxPTR(Img->ImageStart));
322     }
323   }
324   PM->RTLsMtx.unlock();
325   DP("Done unregistering images!\n");
326 
327   // Remove entries from PM->HostPtrToTableMap
328   PM->TblMapMtx.lock();
329   for (llvm::offloading::EntryTy *Cur = Desc->HostEntriesBegin;
330        Cur < Desc->HostEntriesEnd; ++Cur) {
331     PM->HostPtrToTableMap.erase(Cur->Address);
332   }
333 
334   // Remove translation table for this descriptor.
335   auto TransTable =
336       PM->HostEntriesBeginToTransTable.find(Desc->HostEntriesBegin);
337   if (TransTable != PM->HostEntriesBeginToTransTable.end()) {
338     DP("Removing translation table for descriptor " DPxMOD "\n",
339        DPxPTR(Desc->HostEntriesBegin));
340     PM->HostEntriesBeginToTransTable.erase(TransTable);
341   } else {
342     DP("Translation table for descriptor " DPxMOD " cannot be found, probably "
343        "it has been already removed.\n",
344        DPxPTR(Desc->HostEntriesBegin));
345   }
346 
347   PM->TblMapMtx.unlock();
348 
349   DP("Done unregistering library!\n");
350 }
351 
352 /// Map global data and execute pending ctors
353 static int loadImagesOntoDevice(DeviceTy &Device) {
354   /*
355    * Map global data
356    */
357   int32_t DeviceId = Device.DeviceID;
358   int Rc = OFFLOAD_SUCCESS;
359   {
360     std::lock_guard<decltype(PM->TrlTblMtx)> LG(PM->TrlTblMtx);
361     for (auto *HostEntriesBegin : PM->HostEntriesBeginRegistrationOrder) {
362       TranslationTable *TransTable =
363           &PM->HostEntriesBeginToTransTable[HostEntriesBegin];
364       DP("Trans table %p : %p\n", TransTable->HostTable.EntriesBegin,
365          TransTable->HostTable.EntriesEnd);
366       if (TransTable->HostTable.EntriesBegin ==
367           TransTable->HostTable.EntriesEnd) {
368         // No host entry so no need to proceed
369         continue;
370       }
371 
372       if (TransTable->TargetsTable[DeviceId] != 0) {
373         // Library entries have already been processed
374         continue;
375       }
376 
377       // 1) get image.
378       assert(TransTable->TargetsImages.size() > (size_t)DeviceId &&
379              "Not expecting a device ID outside the table's bounds!");
380       __tgt_device_image *Img = TransTable->TargetsImages[DeviceId];
381       if (!Img) {
382         REPORT("No image loaded for device id %d.\n", DeviceId);
383         Rc = OFFLOAD_FAIL;
384         break;
385       }
386 
387       // 2) Load the image onto the given device.
388       auto BinaryOrErr = Device.loadBinary(Img);
389       if (llvm::Error Err = BinaryOrErr.takeError()) {
390         REPORT("Failed to load image %s\n",
391                llvm::toString(std::move(Err)).c_str());
392         Rc = OFFLOAD_FAIL;
393         break;
394       }
395 
396       // 3) Create the translation table.
397       llvm::SmallVector<llvm::offloading::EntryTy> &DeviceEntries =
398           TransTable->TargetsEntries[DeviceId];
399       for (llvm::offloading::EntryTy &Entry :
400            llvm::make_range(Img->EntriesBegin, Img->EntriesEnd)) {
401         __tgt_device_binary &Binary = *BinaryOrErr;
402 
403         llvm::offloading::EntryTy DeviceEntry = Entry;
404         if (Entry.Size) {
405           if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName,
406                                      &DeviceEntry.Address) != OFFLOAD_SUCCESS)
407             REPORT("Failed to load symbol %s\n", Entry.SymbolName);
408 
409           // If unified memory is active, the corresponding global is a device
410           // reference to the host global. We need to initialize the pointer on
411           // the device to point to the memory on the host.
412           if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
413               (PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
414             if (Device.RTL->data_submit(DeviceId, DeviceEntry.Address,
415                                         Entry.Address,
416                                         Entry.Size) != OFFLOAD_SUCCESS)
417               REPORT("Failed to write symbol for USM %s\n", Entry.SymbolName);
418           }
419         } else if (Entry.Address) {
420           if (Device.RTL->get_function(Binary, Entry.SymbolName,
421                                        &DeviceEntry.Address) != OFFLOAD_SUCCESS)
422             REPORT("Failed to load kernel %s\n", Entry.SymbolName);
423         }
424         DP("Entry point " DPxMOD " maps to%s %s (" DPxMOD ")\n",
425            DPxPTR(Entry.Address), (Entry.Size) ? " global" : "",
426            Entry.SymbolName, DPxPTR(DeviceEntry.Address));
427 
428         DeviceEntries.emplace_back(DeviceEntry);
429       }
430 
431       // Set the storage for the table and get a pointer to it.
432       __tgt_target_table DeviceTable{&DeviceEntries[0],
433                                      &DeviceEntries[0] + DeviceEntries.size()};
434       TransTable->DeviceTables[DeviceId] = DeviceTable;
435       __tgt_target_table *TargetTable = TransTable->TargetsTable[DeviceId] =
436           &TransTable->DeviceTables[DeviceId];
437 
438       // 4) Verify whether the two table sizes match.
439       size_t Hsize =
440           TransTable->HostTable.EntriesEnd - TransTable->HostTable.EntriesBegin;
441       size_t Tsize = TargetTable->EntriesEnd - TargetTable->EntriesBegin;
442 
443       // Invalid image for these host entries!
444       if (Hsize != Tsize) {
445         REPORT(
446             "Host and Target tables mismatch for device id %d [%zx != %zx].\n",
447             DeviceId, Hsize, Tsize);
448         TransTable->TargetsImages[DeviceId] = 0;
449         TransTable->TargetsTable[DeviceId] = 0;
450         Rc = OFFLOAD_FAIL;
451         break;
452       }
453 
454       MappingInfoTy::HDTTMapAccessorTy HDTTMap =
455           Device.getMappingInfo().HostDataToTargetMap.getExclusiveAccessor();
456 
457       __tgt_target_table *HostTable = &TransTable->HostTable;
458       for (llvm::offloading::EntryTy *
459                CurrDeviceEntry = TargetTable->EntriesBegin,
460               *CurrHostEntry = HostTable->EntriesBegin,
461               *EntryDeviceEnd = TargetTable->EntriesEnd;
462            CurrDeviceEntry != EntryDeviceEnd;
463            CurrDeviceEntry++, CurrHostEntry++) {
464         if (CurrDeviceEntry->Size == 0)
465           continue;
466 
467         assert(CurrDeviceEntry->Size == CurrHostEntry->Size &&
468                "data size mismatch");
469 
470         // Fortran may use multiple weak declarations for the same symbol,
471         // therefore we must allow for multiple weak symbols to be loaded from
472         // the fat binary. Treat these mappings as any other "regular"
473         // mapping. Add entry to map.
474         if (Device.getMappingInfo().getTgtPtrBegin(
475                 HDTTMap, CurrHostEntry->Address, CurrHostEntry->Size))
476           continue;
477 
478         void *CurrDeviceEntryAddr = CurrDeviceEntry->Address;
479 
480         // For indirect mapping, follow the indirection and map the actual
481         // target.
482         if (CurrDeviceEntry->Flags & OMP_DECLARE_TARGET_INDIRECT) {
483           AsyncInfoTy AsyncInfo(Device);
484           void *DevPtr;
485           Device.retrieveData(&DevPtr, CurrDeviceEntryAddr, sizeof(void *),
486                               AsyncInfo, /*Entry=*/nullptr, &HDTTMap);
487           if (AsyncInfo.synchronize() != OFFLOAD_SUCCESS)
488             return OFFLOAD_FAIL;
489           CurrDeviceEntryAddr = DevPtr;
490         }
491 
492         DP("Add mapping from host " DPxMOD " to device " DPxMOD " with size %zu"
493            ", name \"%s\"\n",
494            DPxPTR(CurrHostEntry->Address), DPxPTR(CurrDeviceEntry->Address),
495            CurrDeviceEntry->Size, CurrDeviceEntry->SymbolName);
496         HDTTMap->emplace(new HostDataToTargetTy(
497             (uintptr_t)CurrHostEntry->Address /*HstPtrBase*/,
498             (uintptr_t)CurrHostEntry->Address /*HstPtrBegin*/,
499             (uintptr_t)CurrHostEntry->Address +
500                 CurrHostEntry->Size /*HstPtrEnd*/,
501             (uintptr_t)CurrDeviceEntryAddr /*TgtAllocBegin*/,
502             (uintptr_t)CurrDeviceEntryAddr /*TgtPtrBegin*/,
503             false /*UseHoldRefCount*/, CurrHostEntry->SymbolName,
504             true /*IsRefCountINF*/));
505 
506         // Notify about the new mapping.
507         if (Device.notifyDataMapped(CurrHostEntry->Address,
508                                     CurrHostEntry->Size))
509           return OFFLOAD_FAIL;
510       }
511     }
512     Device.setHasPendingImages(false);
513   }
514 
515   if (Rc != OFFLOAD_SUCCESS)
516     return Rc;
517 
518   static Int32Envar DumpOffloadEntries =
519       Int32Envar("OMPTARGET_DUMP_OFFLOAD_ENTRIES", -1);
520   if (DumpOffloadEntries.get() == DeviceId)
521     Device.dumpOffloadEntries();
522 
523   return OFFLOAD_SUCCESS;
524 }
525 
526 Expected<DeviceTy &> PluginManager::getDevice(uint32_t DeviceNo) {
527   DeviceTy *DevicePtr;
528   {
529     auto ExclusiveDevicesAccessor = getExclusiveDevicesAccessor();
530     if (DeviceNo >= ExclusiveDevicesAccessor->size())
531       return createStringError(
532           inconvertibleErrorCode(),
533           "Device number '%i' out of range, only %i devices available",
534           DeviceNo, ExclusiveDevicesAccessor->size());
535 
536     DevicePtr = &*(*ExclusiveDevicesAccessor)[DeviceNo];
537   }
538 
539   // Check whether global data has been mapped for this device
540   if (DevicePtr->hasPendingImages())
541     if (loadImagesOntoDevice(*DevicePtr) != OFFLOAD_SUCCESS)
542       return createStringError(inconvertibleErrorCode(),
543                                "Failed to load images on device '%i'",
544                                DeviceNo);
545   return *DevicePtr;
546 }
547