1 //===----RTLs/cuda/src/rtl.cpp - Target RTLs Implementation ------- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RTL NextGen for CUDA machine
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <cassert>
14 #include <cstddef>
15 #include <cuda.h>
16 #include <string>
17 #include <unordered_map>
18 
19 #include "Shared/APITypes.h"
20 #include "Shared/Debug.h"
21 #include "Shared/Environment.h"
22 
23 #include "GlobalHandler.h"
24 #include "OpenMP/OMPT/Callback.h"
25 #include "PluginInterface.h"
26 #include "Utils/ELF.h"
27 
28 #include "llvm/BinaryFormat/ELF.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
31 #include "llvm/Support/Error.h"
32 #include "llvm/Support/FileOutputBuffer.h"
33 #include "llvm/Support/FileSystem.h"
34 #include "llvm/Support/Program.h"
35 
36 namespace llvm {
37 namespace omp {
38 namespace target {
39 namespace plugin {
40 
41 /// Forward declarations for all specialized data structures.
42 struct CUDAKernelTy;
43 struct CUDADeviceTy;
44 struct CUDAPluginTy;
45 
46 #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11000))
/// Stub type and function definitions for the Virtual Memory Management API.
/// They are only needed to compile against CUDA versions older than 11.0,
/// where this API is not available.
50 typedef void *CUmemGenericAllocationHandle;
51 typedef void *CUmemAllocationProp;
52 typedef void *CUmemAccessDesc;
53 typedef void *CUmemAllocationGranularity_flags;
CUresult cuMemAddressReserve(CUdeviceptr *ptr, size_t size, size_t alignment,
                             CUdeviceptr addr, unsigned long long flags) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
CUresult cuMemMap(CUdeviceptr ptr, size_t size, size_t offset,
                  CUmemGenericAllocationHandle handle,
                  unsigned long long flags) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
CUresult cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size,
                     const CUmemAllocationProp *prop,
                     unsigned long long flags) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
CUresult cuMemSetAccess(CUdeviceptr ptr, size_t size,
                        const CUmemAccessDesc *desc, size_t count) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
CUresult
cuMemGetAllocationGranularity(size_t *granularity,
                              const CUmemAllocationProp *prop,
                              CUmemAllocationGranularity_flags option) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
#endif
69 
#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11020))
// Stub definitions of the asynchronous memory management functions. These are
// only needed to compile against older versions of CUDA and fail cleanly if
// they are ever reached.
CUresult cuMemAllocAsync(CUdeviceptr *ptr, size_t, CUstream) {
  *ptr = 0;
  return CUDA_ERROR_NOT_SUPPORTED;
}

CUresult cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream) {
  return CUDA_ERROR_NOT_SUPPORTED;
}
#endif
77 
78 /// Class implementing the CUDA device images properties.
79 struct CUDADeviceImageTy : public DeviceImageTy {
80   /// Create the CUDA image with the id and the target image pointer.
81   CUDADeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
82                     const __tgt_device_image *TgtImage)
83       : DeviceImageTy(ImageId, Device, TgtImage), Module(nullptr) {}
84 
85   /// Load the image as a CUDA module.
86   Error loadModule() {
87     assert(!Module && "Module already loaded");
88 
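    // Note: cuModuleLoadDataEx accepts a cubin, fatbin, or PTX image. A PTX
    // image would be JIT compiled by the driver here, although this plugin
    // normally compiles PTX to a cubin with 'ptxas' beforehand (see
    // doJITPostProcessing below).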
89     CUresult Res = cuModuleLoadDataEx(&Module, getStart(), 0, nullptr, nullptr);
90     if (auto Err = Plugin::check(Res, "Error in cuModuleLoadDataEx: %s"))
91       return Err;
92 
93     return Plugin::success();
94   }
95 
96   /// Unload the CUDA module corresponding to the image.
97   Error unloadModule() {
98     assert(Module && "Module not loaded");
99 
100     CUresult Res = cuModuleUnload(Module);
101     if (auto Err = Plugin::check(Res, "Error in cuModuleUnload: %s"))
102       return Err;
103 
104     Module = nullptr;
105 
106     return Plugin::success();
107   }
108 
109   /// Getter of the CUDA module.
110   CUmodule getModule() const { return Module; }
111 
112 private:
113   /// The CUDA module that loaded the image.
114   CUmodule Module;
115 };
116 
117 /// Class implementing the CUDA kernel functionalities which derives from the
118 /// generic kernel class.
119 struct CUDAKernelTy : public GenericKernelTy {
120   /// Create a CUDA kernel with a name and an execution mode.
121   CUDAKernelTy(const char *Name) : GenericKernelTy(Name), Func(nullptr) {}
122 
123   /// Initialize the CUDA kernel.
124   Error initImpl(GenericDeviceTy &GenericDevice,
125                  DeviceImageTy &Image) override {
126     CUresult Res;
127     CUDADeviceImageTy &CUDAImage = static_cast<CUDADeviceImageTy &>(Image);
128 
129     // Retrieve the function pointer of the kernel.
130     Res = cuModuleGetFunction(&Func, CUDAImage.getModule(), getName());
131     if (auto Err = Plugin::check(Res, "Error in cuModuleGetFunction('%s'): %s",
132                                  getName()))
133       return Err;
134 
135     // Check that the function pointer is valid.
136     if (!Func)
137       return Plugin::error("Invalid function for kernel %s", getName());
138 
139     int MaxThreads;
140     Res = cuFuncGetAttribute(&MaxThreads,
141                              CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, Func);
142     if (auto Err = Plugin::check(Res, "Error in cuFuncGetAttribute: %s"))
143       return Err;
144 
145     // The maximum number of threads cannot exceed the maximum of the kernel.
146     MaxNumThreads = std::min(MaxNumThreads, (uint32_t)MaxThreads);
147 
148     return Plugin::success();
149   }
150 
151   /// Launch the CUDA kernel function.
152   Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
153                    uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
154                    KernelLaunchParamsTy LaunchParams,
155                    AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
156 
157 private:
158   /// The CUDA kernel function to execute.
159   CUfunction Func;
160 };
161 
162 /// Class wrapping a CUDA stream reference. These are the objects handled by the
163 /// Stream Manager for the CUDA plugin.
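/// The manager keeps a pool of streams (initially OMPX_InitialNumStreams) so
/// they can be recycled across operations instead of being created and
/// destroyed for every data transfer or kernel launch.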
164 struct CUDAStreamRef final : public GenericDeviceResourceRef {
165   /// The underlying handle type for streams.
166   using HandleTy = CUstream;
167 
168   /// Create an empty reference to an invalid stream.
169   CUDAStreamRef() : Stream(nullptr) {}
170 
171   /// Create a reference to an existing stream.
172   CUDAStreamRef(HandleTy Stream) : Stream(Stream) {}
173 
174   /// Create a new stream and save the reference. The reference must be empty
175   /// before calling to this function.
176   Error create(GenericDeviceTy &Device) override {
177     if (Stream)
178       return Plugin::error("Creating an existing stream");
179 
180     CUresult Res = cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING);
181     if (auto Err = Plugin::check(Res, "Error in cuStreamCreate: %s"))
182       return Err;
183 
184     return Plugin::success();
185   }
186 
187   /// Destroy the referenced stream and invalidate the reference. The reference
188   /// must be to a valid stream before calling to this function.
189   Error destroy(GenericDeviceTy &Device) override {
190     if (!Stream)
191       return Plugin::error("Destroying an invalid stream");
192 
193     CUresult Res = cuStreamDestroy(Stream);
194     if (auto Err = Plugin::check(Res, "Error in cuStreamDestroy: %s"))
195       return Err;
196 
197     Stream = nullptr;
198     return Plugin::success();
199   }
200 
201   /// Get the underlying CUDA stream.
202   operator HandleTy() const { return Stream; }
203 
204 private:
205   /// The reference to the CUDA stream.
206   HandleTy Stream;
207 };
208 
209 /// Class wrapping a CUDA event reference. These are the objects handled by the
210 /// Event Manager for the CUDA plugin.
211 struct CUDAEventRef final : public GenericDeviceResourceRef {
212   /// The underlying handle type for events.
213   using HandleTy = CUevent;
214 
215   /// Create an empty reference to an invalid event.
216   CUDAEventRef() : Event(nullptr) {}
217 
218   /// Create a reference to an existing event.
219   CUDAEventRef(HandleTy Event) : Event(Event) {}
220 
221   /// Create a new event and save the reference. The reference must be empty
222   /// before calling to this function.
223   Error create(GenericDeviceTy &Device) override {
224     if (Event)
225       return Plugin::error("Creating an existing event");
226 
227     CUresult Res = cuEventCreate(&Event, CU_EVENT_DEFAULT);
228     if (auto Err = Plugin::check(Res, "Error in cuEventCreate: %s"))
229       return Err;
230 
231     return Plugin::success();
232   }
233 
234   /// Destroy the referenced event and invalidate the reference. The reference
235   /// must be to a valid event before calling to this function.
236   Error destroy(GenericDeviceTy &Device) override {
237     if (!Event)
238       return Plugin::error("Destroying an invalid event");
239 
240     CUresult Res = cuEventDestroy(Event);
241     if (auto Err = Plugin::check(Res, "Error in cuEventDestroy: %s"))
242       return Err;
243 
244     Event = nullptr;
245     return Plugin::success();
246   }
247 
248   /// Get the underlying CUevent.
249   operator HandleTy() const { return Event; }
250 
251 private:
252   /// The reference to the CUDA event.
253   HandleTy Event;
254 };
255 
256 /// Class implementing the CUDA device functionalities which derives from the
257 /// generic device class.
258 struct CUDADeviceTy : public GenericDeviceTy {
259   // Create a CUDA device with a device id and the default CUDA grid values.
260   CUDADeviceTy(GenericPluginTy &Plugin, int32_t DeviceId, int32_t NumDevices)
261       : GenericDeviceTy(Plugin, DeviceId, NumDevices, NVPTXGridValues),
262         CUDAStreamManager(*this), CUDAEventManager(*this) {}
263 
264   ~CUDADeviceTy() {}
265 
266   /// Initialize the device, its resources and get its properties.
267   Error initImpl(GenericPluginTy &Plugin) override {
268     CUresult Res = cuDeviceGet(&Device, DeviceId);
269     if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
270       return Err;
271 
272     // Query the current flags of the primary context and set its flags if
273     // it is inactive.
274     unsigned int FormerPrimaryCtxFlags = 0;
275     int FormerPrimaryCtxIsActive = 0;
276     Res = cuDevicePrimaryCtxGetState(Device, &FormerPrimaryCtxFlags,
277                                      &FormerPrimaryCtxIsActive);
278     if (auto Err =
279             Plugin::check(Res, "Error in cuDevicePrimaryCtxGetState: %s"))
280       return Err;
281 
282     if (FormerPrimaryCtxIsActive) {
283       INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
284            "The primary context is active, no change to its flags\n");
285       if ((FormerPrimaryCtxFlags & CU_CTX_SCHED_MASK) !=
286           CU_CTX_SCHED_BLOCKING_SYNC)
287         INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
288              "Warning: The current flags are not CU_CTX_SCHED_BLOCKING_SYNC\n");
289     } else {
290       INFO(OMP_INFOTYPE_PLUGIN_KERNEL, DeviceId,
291            "The primary context is inactive, set its flags to "
292            "CU_CTX_SCHED_BLOCKING_SYNC\n");
293       Res = cuDevicePrimaryCtxSetFlags(Device, CU_CTX_SCHED_BLOCKING_SYNC);
294       if (auto Err =
295               Plugin::check(Res, "Error in cuDevicePrimaryCtxSetFlags: %s"))
296         return Err;
297     }
298 
299     // Retain the per device primary context and save it to use whenever this
300     // device is selected.
301     Res = cuDevicePrimaryCtxRetain(&Context, Device);
302     if (auto Err = Plugin::check(Res, "Error in cuDevicePrimaryCtxRetain: %s"))
303       return Err;
304 
305     if (auto Err = setContext())
306       return Err;
307 
308     // Initialize stream pool.
309     if (auto Err = CUDAStreamManager.init(OMPX_InitialNumStreams))
310       return Err;
311 
312     // Initialize event pool.
313     if (auto Err = CUDAEventManager.init(OMPX_InitialNumEvents))
314       return Err;
315 
316     // Query attributes to determine number of threads/block and blocks/grid.
317     if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
318                                  GridValues.GV_Max_Teams))
319       return Err;
320 
321     if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X,
322                                  GridValues.GV_Max_WG_Size))
323       return Err;
324 
325     if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_WARP_SIZE,
326                                  GridValues.GV_Warp_Size))
327       return Err;
328 
329     if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
330                                  ComputeCapability.Major))
331       return Err;
332 
333     if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
334                                  ComputeCapability.Minor))
335       return Err;
336 
    uint32_t NumMultiprocessors = 0;
    uint32_t MaxThreadsPerSM = 0;
    uint32_t WarpSize = 0;
    if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
                                 NumMultiprocessors))
      return Err;
    if (auto Err =
            getDeviceAttr(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR,
                          MaxThreadsPerSM))
      return Err;
    if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_WARP_SIZE, WarpSize))
      return Err;
    HardwareParallelism = NumMultiprocessors * (MaxThreadsPerSM / WarpSize);
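    // As an illustrative example, a device with 108 SMs, 2048 maximum resident
    // threads per SM, and a warp size of 32 yields 108 * (2048 / 32) = 6912,
    // i.e. the number of warps that can be resident at once.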
350 
351     return Plugin::success();
352   }
353 
354   /// Deinitialize the device and release its resources.
355   Error deinitImpl() override {
356     if (Context) {
357       if (auto Err = setContext())
358         return Err;
359     }
360 
361     // Deinitialize the stream manager.
362     if (auto Err = CUDAStreamManager.deinit())
363       return Err;
364 
365     if (auto Err = CUDAEventManager.deinit())
366       return Err;
367 
368     // Close modules if necessary.
369     if (!LoadedImages.empty()) {
370       assert(Context && "Invalid CUDA context");
371 
372       // Each image has its own module.
373       for (DeviceImageTy *Image : LoadedImages) {
374         CUDADeviceImageTy &CUDAImage = static_cast<CUDADeviceImageTy &>(*Image);
375 
376         // Unload the module of the image.
377         if (auto Err = CUDAImage.unloadModule())
378           return Err;
379       }
380     }
381 
382     if (Context) {
383       CUresult Res = cuDevicePrimaryCtxRelease(Device);
384       if (auto Err =
385               Plugin::check(Res, "Error in cuDevicePrimaryCtxRelease: %s"))
386         return Err;
387     }
388 
389     // Invalidate context and device references.
390     Context = nullptr;
391     Device = CU_DEVICE_INVALID;
392 
393     return Plugin::success();
394   }
395 
396   virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
397                                        DeviceImageTy &Image) override {
    // Check for the presence of global destructors at initialization time.
    // This is required when the image may be deallocated before destructors
    // are run.
400     GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
401     if (Handler.isSymbolInImage(*this, Image, "nvptx$device$fini"))
402       Image.setPendingGlobalDtors();
403 
404     return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/true);
405   }
406 
407   virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
408                                       DeviceImageTy &Image) override {
409     if (Image.hasPendingGlobalDtors())
410       return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
411     return Plugin::success();
412   }
413 
414   Expected<std::unique_ptr<MemoryBuffer>>
415   doJITPostProcessing(std::unique_ptr<MemoryBuffer> MB) const override {
416     // TODO: We should be able to use the 'nvidia-ptxjitcompiler' interface to
417     //       avoid the call to 'ptxas'.
418     SmallString<128> PTXInputFilePath;
419     std::error_code EC = sys::fs::createTemporaryFile("nvptx-pre-link-jit", "s",
420                                                       PTXInputFilePath);
421     if (EC)
422       return Plugin::error("Failed to create temporary file for ptxas");
423 
424     // Write the file's contents to the output file.
425     Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr =
426         FileOutputBuffer::create(PTXInputFilePath, MB->getBuffer().size());
427     if (!OutputOrErr)
428       return OutputOrErr.takeError();
429     std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr);
430     llvm::copy(MB->getBuffer(), Output->getBufferStart());
431     if (Error E = Output->commit())
432       return std::move(E);
433 
434     SmallString<128> PTXOutputFilePath;
435     EC = sys::fs::createTemporaryFile("nvptx-post-link-jit", "cubin",
436                                       PTXOutputFilePath);
437     if (EC)
438       return Plugin::error("Failed to create temporary file for ptxas");
439 
440     // Try to find `ptxas` in the path to compile the PTX to a binary.
441     const auto ErrorOrPath = sys::findProgramByName("ptxas");
442     if (!ErrorOrPath)
443       return Plugin::error("Failed to find 'ptxas' on the PATH.");
444 
445     std::string Arch = getComputeUnitKind();
446     StringRef Args[] = {*ErrorOrPath,
447                         "-m64",
448                         "-O2",
449                         "--gpu-name",
450                         Arch,
451                         "--output-file",
452                         PTXOutputFilePath,
453                         PTXInputFilePath};
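    // For reference, this builds an invocation roughly equivalent to:
    //   ptxas -m64 -O2 --gpu-name <arch> --output-file <out.cubin> <in.s>
    // where <arch> is the value returned by getComputeUnitKind(), e.g. sm_70.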
454 
455     std::string ErrMsg;
456     if (sys::ExecuteAndWait(*ErrorOrPath, Args, std::nullopt, {}, 0, 0,
457                             &ErrMsg))
458       return Plugin::error("Running 'ptxas' failed: %s\n", ErrMsg.c_str());
459 
460     auto BufferOrErr = MemoryBuffer::getFileOrSTDIN(PTXOutputFilePath.data());
461     if (!BufferOrErr)
462       return Plugin::error("Failed to open temporary file for ptxas");
463 
464     // Clean up the temporary files afterwards.
465     if (sys::fs::remove(PTXOutputFilePath))
466       return Plugin::error("Failed to remove temporary file for ptxas");
467     if (sys::fs::remove(PTXInputFilePath))
468       return Plugin::error("Failed to remove temporary file for ptxas");
469 
470     return std::move(*BufferOrErr);
471   }
472 
473   /// Allocate and construct a CUDA kernel.
474   Expected<GenericKernelTy &> constructKernel(const char *Name) override {
475     // Allocate and construct the CUDA kernel.
476     CUDAKernelTy *CUDAKernel = Plugin.allocate<CUDAKernelTy>();
477     if (!CUDAKernel)
478       return Plugin::error("Failed to allocate memory for CUDA kernel");
479 
480     new (CUDAKernel) CUDAKernelTy(Name);
481 
482     return *CUDAKernel;
483   }
484 
485   /// Set the current context to this device's context.
486   Error setContext() override {
487     CUresult Res = cuCtxSetCurrent(Context);
488     return Plugin::check(Res, "Error in cuCtxSetCurrent: %s");
489   }
490 
491   /// NVIDIA returns the product of the SM count and the number of warps that
492   /// fit if the maximum number of threads were scheduled on each SM.
493   uint64_t getHardwareParallelism() const override {
494     return HardwareParallelism;
495   }
496 
497   /// We want to set up the RPC server for host services to the GPU if it is
  /// available.
499   bool shouldSetupRPCServer() const override { return true; }
500 
  /// The RPC interface should have enough space for all available parallelism.
502   uint64_t requestedRPCPortCount() const override {
503     return getHardwareParallelism();
504   }
505 
506   /// Get the stream of the asynchronous info sructure or get a new one.
507   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
508     // Get the stream (if any) from the async info.
509     Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
510     if (!Stream) {
511       // There was no stream; get an idle one.
512       if (auto Err = CUDAStreamManager.getResource(Stream))
513         return Err;
514 
515       // Modify the async info's stream.
516       AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
517     }
518     return Plugin::success();
519   }
520 
521   /// Getters of CUDA references.
522   CUcontext getCUDAContext() const { return Context; }
523   CUdevice getCUDADevice() const { return Device; }
524 
525   /// Load the binary image into the device and allocate an image object.
526   Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
527                                            int32_t ImageId) override {
528     if (auto Err = setContext())
529       return std::move(Err);
530 
531     // Allocate and initialize the image object.
532     CUDADeviceImageTy *CUDAImage = Plugin.allocate<CUDADeviceImageTy>();
533     new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
534 
535     // Load the CUDA module.
536     if (auto Err = CUDAImage->loadModule())
537       return std::move(Err);
538 
539     return CUDAImage;
540   }
541 
542   /// Allocate memory on the device or related to the device.
543   void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
544     if (Size == 0)
545       return nullptr;
546 
547     if (auto Err = setContext()) {
548       REPORT("Failure to alloc memory: %s\n", toString(std::move(Err)).data());
549       return nullptr;
550     }
551 
552     void *MemAlloc = nullptr;
553     CUdeviceptr DevicePtr;
554     CUresult Res;
555 
556     switch (Kind) {
557     case TARGET_ALLOC_DEFAULT:
558     case TARGET_ALLOC_DEVICE:
559       Res = cuMemAlloc(&DevicePtr, Size);
560       MemAlloc = (void *)DevicePtr;
561       break;
562     case TARGET_ALLOC_HOST:
563       Res = cuMemAllocHost(&MemAlloc, Size);
564       break;
565     case TARGET_ALLOC_SHARED:
566       Res = cuMemAllocManaged(&DevicePtr, Size, CU_MEM_ATTACH_GLOBAL);
567       MemAlloc = (void *)DevicePtr;
568       break;
569     case TARGET_ALLOC_DEVICE_NON_BLOCKING: {
570       CUstream Stream;
571       if ((Res = cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING)))
572         break;
573       if ((Res = cuMemAllocAsync(&DevicePtr, Size, Stream)))
574         break;
575       cuStreamSynchronize(Stream);
576       Res = cuStreamDestroy(Stream);
577       MemAlloc = (void *)DevicePtr;
578     }
579     }
580 
581     if (auto Err =
582             Plugin::check(Res, "Error in cuMemAlloc[Host|Managed]: %s")) {
583       REPORT("Failure to alloc memory: %s\n", toString(std::move(Err)).data());
584       return nullptr;
585     }
586     return MemAlloc;
587   }
588 
589   /// Deallocate memory on the device or related to the device.
590   int free(void *TgtPtr, TargetAllocTy Kind) override {
591     if (TgtPtr == nullptr)
592       return OFFLOAD_SUCCESS;
593 
594     if (auto Err = setContext()) {
595       REPORT("Failure to free memory: %s\n", toString(std::move(Err)).data());
596       return OFFLOAD_FAIL;
597     }
598 
599     CUresult Res;
600     switch (Kind) {
601     case TARGET_ALLOC_DEFAULT:
602     case TARGET_ALLOC_DEVICE:
603     case TARGET_ALLOC_SHARED:
604       Res = cuMemFree((CUdeviceptr)TgtPtr);
605       break;
606     case TARGET_ALLOC_HOST:
607       Res = cuMemFreeHost(TgtPtr);
608       break;
609     case TARGET_ALLOC_DEVICE_NON_BLOCKING: {
610       CUstream Stream;
611       if ((Res = cuStreamCreate(&Stream, CU_STREAM_NON_BLOCKING)))
612         break;
613       cuMemFreeAsync(reinterpret_cast<CUdeviceptr>(TgtPtr), Stream);
614       cuStreamSynchronize(Stream);
615       if ((Res = cuStreamDestroy(Stream)))
616         break;
617     }
618     }
619 
620     if (auto Err = Plugin::check(Res, "Error in cuMemFree[Host]: %s")) {
621       REPORT("Failure to free memory: %s\n", toString(std::move(Err)).data());
622       return OFFLOAD_FAIL;
623     }
624     return OFFLOAD_SUCCESS;
625   }
626 
627   /// Synchronize current thread with the pending operations on the async info.
628   Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
629     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
630     CUresult Res;
631     Res = cuStreamSynchronize(Stream);
632 
    // Once the stream is synchronized, return it to the stream pool and reset
    // AsyncInfo. This makes sure the synchronization only covers its own
    // tasks.
636     AsyncInfo.Queue = nullptr;
637     if (auto Err = CUDAStreamManager.returnResource(Stream))
638       return Err;
639 
640     return Plugin::check(Res, "Error in cuStreamSynchronize: %s");
641   }
642 
  /// Whether the plugin supports virtual address (VA) management; this
  /// requires CUDA 11.0 or newer.
644   bool supportVAManagement() const override {
645 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11000))
646     return true;
647 #else
648     return false;
649 #endif
650   }
651 
  /// Allocates \p RSize bytes (rounded up to the page size) and hints the CUDA
  /// driver to map them at \p VAddr. The obtained address is stored in
  /// \p Addr, and on return \p RSize contains the actual mapped size.
655   Error memoryVAMap(void **Addr, void *VAddr, size_t *RSize) override {
656     CUdeviceptr DVAddr = reinterpret_cast<CUdeviceptr>(VAddr);
657     auto IHandle = DeviceMMaps.find(DVAddr);
658     size_t Size = *RSize;
659 
660     if (Size == 0)
661       return Plugin::error("Memory Map Size must be larger than 0");
662 
663     // Check if we have already mapped this address
664     if (IHandle != DeviceMMaps.end())
665       return Plugin::error("Address already memory mapped");
666 
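    // The mapping is built with the CUDA virtual memory management API in four
    // steps: create a physical allocation (cuMemCreate), reserve a virtual
    // address range near the requested address (cuMemAddressReserve), map the
    // allocation into that range (cuMemMap), and enable access for this device
    // (cuMemSetAccess).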
667     CUmemAllocationProp Prop = {};
668     size_t Granularity = 0;
669 
670     size_t Free, Total;
671     CUresult Res = cuMemGetInfo(&Free, &Total);
672     if (auto Err = Plugin::check(Res, "Error in cuMemGetInfo: %s"))
673       return Err;
674 
675     if (Size >= Free) {
676       *Addr = nullptr;
677       return Plugin::error(
678           "Canot map memory size larger than the available device memory");
679     }
680 
    // Currently NVIDIA only supports pinned allocations on device locations.
682     Prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
683     Prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
684 
685     Prop.location.id = DeviceId;
    Res = cuMemGetAllocationGranularity(&Granularity, &Prop,
                                        CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    if (auto Err =
            Plugin::check(Res, "Error in cuMemGetAllocationGranularity: %s"))
      return Err;
691 
692     if (Granularity == 0)
693       return Plugin::error("Wrong device Page size");
694 
695     // Ceil to page size.
696     Size = utils::roundUp(Size, Granularity);
697 
    // Create a handle for our allocation.
699     CUmemGenericAllocationHandle AHandle;
700     Res = cuMemCreate(&AHandle, Size, &Prop, 0);
701     if (auto Err = Plugin::check(Res, "Error in cuMemCreate: %s"))
702       return Err;
703 
704     CUdeviceptr DevPtr = 0;
705     Res = cuMemAddressReserve(&DevPtr, Size, 0, DVAddr, 0);
706     if (auto Err = Plugin::check(Res, "Error in cuMemAddressReserve: %s"))
707       return Err;
708 
709     Res = cuMemMap(DevPtr, Size, 0, AHandle, 0);
710     if (auto Err = Plugin::check(Res, "Error in cuMemMap: %s"))
711       return Err;
712 
713     CUmemAccessDesc ADesc = {};
714     ADesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
715     ADesc.location.id = DeviceId;
716     ADesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
717 
    // Grant the device read/write access to the mapped range.
719     Res = cuMemSetAccess(DevPtr, Size, &ADesc, 1);
720     if (auto Err = Plugin::check(Res, "Error in cuMemSetAccess: %s"))
721       return Err;
722 
723     *Addr = reinterpret_cast<void *>(DevPtr);
724     *RSize = Size;
725     DeviceMMaps.insert({DevPtr, AHandle});
726     return Plugin::success();
727   }
728 
  /// Deallocates device memory and unmaps the virtual address.
730   Error memoryVAUnMap(void *VAddr, size_t Size) override {
731     CUdeviceptr DVAddr = reinterpret_cast<CUdeviceptr>(VAddr);
732     auto IHandle = DeviceMMaps.find(DVAddr);
    // Return an error if the mapping does not exist.
    if (IHandle == DeviceMMaps.end())
      return Plugin::error("Addr is not MemoryMapped");
740 
741     CUmemGenericAllocationHandle &AllocHandle = IHandle->second;
742 
743     CUresult Res = cuMemUnmap(DVAddr, Size);
744     if (auto Err = Plugin::check(Res, "Error in cuMemUnmap: %s"))
745       return Err;
746 
747     Res = cuMemRelease(AllocHandle);
748     if (auto Err = Plugin::check(Res, "Error in cuMemRelease: %s"))
749       return Err;
750 
751     Res = cuMemAddressFree(DVAddr, Size);
752     if (auto Err = Plugin::check(Res, "Error in cuMemAddressFree: %s"))
753       return Err;
754 
755     DeviceMMaps.erase(IHandle);
756     return Plugin::success();
757   }
758 
759   /// Query for the completion of the pending operations on the async info.
760   Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
761     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
762     CUresult Res = cuStreamQuery(Stream);
763 
    // A stream that is not ready yet is not an error; report success and keep
    // the queue so it can be queried again later.
765     if (Res == CUDA_ERROR_NOT_READY)
766       return Plugin::success();
767 
    // Once the stream is synchronized and the operations have completed (or an
    // error occurs), return it to the stream pool and reset AsyncInfo. This
    // makes sure the synchronization only covers its own tasks.
771     AsyncInfo.Queue = nullptr;
772     if (auto Err = CUDAStreamManager.returnResource(Stream))
773       return Err;
774 
775     return Plugin::check(Res, "Error in cuStreamQuery: %s");
776   }
777 
778   Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override {
779     // TODO: Register the buffer as CUDA host memory.
780     return HstPtr;
781   }
782 
783   Error dataUnlockImpl(void *HstPtr) override { return Plugin::success(); }
784 
785   Expected<bool> isPinnedPtrImpl(void *HstPtr, void *&BaseHstPtr,
786                                  void *&BaseDevAccessiblePtr,
787                                  size_t &BaseSize) const override {
788     // TODO: Implement pinning feature for CUDA.
789     return false;
790   }
791 
792   /// Submit data to the device (host to device transfer).
793   Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
794                        AsyncInfoWrapperTy &AsyncInfoWrapper) override {
795     if (auto Err = setContext())
796       return Err;
797 
798     CUstream Stream;
799     if (auto Err = getStream(AsyncInfoWrapper, Stream))
800       return Err;
801 
802     CUresult Res = cuMemcpyHtoDAsync((CUdeviceptr)TgtPtr, HstPtr, Size, Stream);
803     return Plugin::check(Res, "Error in cuMemcpyHtoDAsync: %s");
804   }
805 
806   /// Retrieve data from the device (device to host transfer).
807   Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
808                          AsyncInfoWrapperTy &AsyncInfoWrapper) override {
809     if (auto Err = setContext())
810       return Err;
811 
812     CUstream Stream;
813     if (auto Err = getStream(AsyncInfoWrapper, Stream))
814       return Err;
815 
816     CUresult Res = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
817     return Plugin::check(Res, "Error in cuMemcpyDtoHAsync: %s");
818   }
819 
820   /// Exchange data between two devices directly. We may use peer access if
821   /// the CUDA devices and driver allow them.
822   Error dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstGenericDevice,
823                          void *DstPtr, int64_t Size,
824                          AsyncInfoWrapperTy &AsyncInfoWrapper) override;
825 
826   /// Initialize the async info for interoperability purposes.
827   Error initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) override {
828     if (auto Err = setContext())
829       return Err;
830 
831     CUstream Stream;
832     if (auto Err = getStream(AsyncInfoWrapper, Stream))
833       return Err;
834 
835     return Plugin::success();
836   }
837 
838   /// Initialize the device info for interoperability purposes.
839   Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) override {
840     assert(Context && "Context is null");
841     assert(Device != CU_DEVICE_INVALID && "Invalid CUDA device");
842 
843     if (auto Err = setContext())
844       return Err;
845 
846     if (!DeviceInfo->Context)
847       DeviceInfo->Context = Context;
848 
849     if (!DeviceInfo->Device)
850       DeviceInfo->Device = reinterpret_cast<void *>(Device);
851 
852     return Plugin::success();
853   }
854 
855   /// Create an event.
856   Error createEventImpl(void **EventPtrStorage) override {
857     CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
858     return CUDAEventManager.getResource(*Event);
859   }
860 
861   /// Destroy a previously created event.
862   Error destroyEventImpl(void *EventPtr) override {
863     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
864     return CUDAEventManager.returnResource(Event);
865   }
866 
867   /// Record the event.
868   Error recordEventImpl(void *EventPtr,
869                         AsyncInfoWrapperTy &AsyncInfoWrapper) override {
870     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
871 
872     CUstream Stream;
873     if (auto Err = getStream(AsyncInfoWrapper, Stream))
874       return Err;
875 
876     CUresult Res = cuEventRecord(Event, Stream);
877     return Plugin::check(Res, "Error in cuEventRecord: %s");
878   }
879 
880   /// Make the stream wait on the event.
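  /// Together with recordEventImpl, this implements cross-stream dependencies:
  /// an event recorded on one stream can be waited on by another stream, so
  /// dependent operations can be ordered without blocking the host.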
881   Error waitEventImpl(void *EventPtr,
882                       AsyncInfoWrapperTy &AsyncInfoWrapper) override {
883     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
884 
885     CUstream Stream;
886     if (auto Err = getStream(AsyncInfoWrapper, Stream))
887       return Err;
888 
    // Do not use CU_EVENT_WAIT_DEFAULT here as it is only available from a
    // specific CUDA version onward, where it is defined as 0x0. In previous
    // versions, per the CUDA API documentation, that argument has to be 0x0.
892     CUresult Res = cuStreamWaitEvent(Stream, Event, 0);
893     return Plugin::check(Res, "Error in cuStreamWaitEvent: %s");
894   }
895 
896   /// Synchronize the current thread with the event.
897   Error syncEventImpl(void *EventPtr) override {
898     CUevent Event = reinterpret_cast<CUevent>(EventPtr);
899     CUresult Res = cuEventSynchronize(Event);
900     return Plugin::check(Res, "Error in cuEventSynchronize: %s");
901   }
902 
903   /// Print information about the device.
904   Error obtainInfoImpl(InfoQueueTy &Info) override {
905     char TmpChar[1000];
906     const char *TmpCharPtr;
907     size_t TmpSt;
908     int TmpInt;
909 
910     CUresult Res = cuDriverGetVersion(&TmpInt);
911     if (Res == CUDA_SUCCESS)
912       Info.add("CUDA Driver Version", TmpInt);
913 
914     Info.add("CUDA OpenMP Device Number", DeviceId);
915 
916     Res = cuDeviceGetName(TmpChar, 1000, Device);
917     if (Res == CUDA_SUCCESS)
918       Info.add("Device Name", TmpChar);
919 
920     Res = cuDeviceTotalMem(&TmpSt, Device);
921     if (Res == CUDA_SUCCESS)
922       Info.add("Global Memory Size", TmpSt, "bytes");
923 
924     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, TmpInt);
925     if (Res == CUDA_SUCCESS)
926       Info.add("Number of Multiprocessors", TmpInt);
927 
928     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_GPU_OVERLAP, TmpInt);
929     if (Res == CUDA_SUCCESS)
930       Info.add("Concurrent Copy and Execution", (bool)TmpInt);
931 
932     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY, TmpInt);
933     if (Res == CUDA_SUCCESS)
934       Info.add("Total Constant Memory", TmpInt, "bytes");
935 
936     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
937                            TmpInt);
938     if (Res == CUDA_SUCCESS)
939       Info.add("Max Shared Memory per Block", TmpInt, "bytes");
940 
941     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, TmpInt);
942     if (Res == CUDA_SUCCESS)
943       Info.add("Registers per Block", TmpInt);
944 
945     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_WARP_SIZE, TmpInt);
946     if (Res == CUDA_SUCCESS)
947       Info.add("Warp Size", TmpInt);
948 
949     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, TmpInt);
950     if (Res == CUDA_SUCCESS)
951       Info.add("Maximum Threads per Block", TmpInt);
952 
953     Info.add("Maximum Block Dimensions", "");
954     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, TmpInt);
955     if (Res == CUDA_SUCCESS)
956       Info.add<InfoLevel2>("x", TmpInt);
957     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, TmpInt);
958     if (Res == CUDA_SUCCESS)
959       Info.add<InfoLevel2>("y", TmpInt);
960     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, TmpInt);
961     if (Res == CUDA_SUCCESS)
962       Info.add<InfoLevel2>("z", TmpInt);
963 
964     Info.add("Maximum Grid Dimensions", "");
965     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, TmpInt);
966     if (Res == CUDA_SUCCESS)
967       Info.add<InfoLevel2>("x", TmpInt);
968     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, TmpInt);
969     if (Res == CUDA_SUCCESS)
970       Info.add<InfoLevel2>("y", TmpInt);
971     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, TmpInt);
972     if (Res == CUDA_SUCCESS)
973       Info.add<InfoLevel2>("z", TmpInt);
974 
975     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_PITCH, TmpInt);
976     if (Res == CUDA_SUCCESS)
977       Info.add("Maximum Memory Pitch", TmpInt, "bytes");
978 
979     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT, TmpInt);
980     if (Res == CUDA_SUCCESS)
981       Info.add("Texture Alignment", TmpInt, "bytes");
982 
983     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_CLOCK_RATE, TmpInt);
984     if (Res == CUDA_SUCCESS)
985       Info.add("Clock Rate", TmpInt, "kHz");
986 
987     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT, TmpInt);
988     if (Res == CUDA_SUCCESS)
989       Info.add("Execution Timeout", (bool)TmpInt);
990 
991     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_INTEGRATED, TmpInt);
992     if (Res == CUDA_SUCCESS)
993       Info.add("Integrated Device", (bool)TmpInt);
994 
995     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY, TmpInt);
996     if (Res == CUDA_SUCCESS)
997       Info.add("Can Map Host Memory", (bool)TmpInt);
998 
999     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, TmpInt);
1000     if (Res == CUDA_SUCCESS) {
1001       if (TmpInt == CU_COMPUTEMODE_DEFAULT)
1002         TmpCharPtr = "Default";
1003       else if (TmpInt == CU_COMPUTEMODE_PROHIBITED)
1004         TmpCharPtr = "Prohibited";
1005       else if (TmpInt == CU_COMPUTEMODE_EXCLUSIVE_PROCESS)
1006         TmpCharPtr = "Exclusive process";
1007       else
1008         TmpCharPtr = "Unknown";
1009       Info.add("Compute Mode", TmpCharPtr);
1010     }
1011 
1012     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS, TmpInt);
1013     if (Res == CUDA_SUCCESS)
1014       Info.add("Concurrent Kernels", (bool)TmpInt);
1015 
1016     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_ECC_ENABLED, TmpInt);
1017     if (Res == CUDA_SUCCESS)
1018       Info.add("ECC Enabled", (bool)TmpInt);
1019 
1020     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, TmpInt);
1021     if (Res == CUDA_SUCCESS)
1022       Info.add("Memory Clock Rate", TmpInt, "kHz");
1023 
1024     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, TmpInt);
1025     if (Res == CUDA_SUCCESS)
1026       Info.add("Memory Bus Width", TmpInt, "bits");
1027 
1028     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, TmpInt);
1029     if (Res == CUDA_SUCCESS)
1030       Info.add("L2 Cache Size", TmpInt, "bytes");
1031 
1032     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR,
1033                            TmpInt);
1034     if (Res == CUDA_SUCCESS)
1035       Info.add("Max Threads Per SMP", TmpInt);
1036 
1037     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT, TmpInt);
1038     if (Res == CUDA_SUCCESS)
1039       Info.add("Async Engines", TmpInt);
1040 
1041     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, TmpInt);
1042     if (Res == CUDA_SUCCESS)
1043       Info.add("Unified Addressing", (bool)TmpInt);
1044 
1045     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY, TmpInt);
1046     if (Res == CUDA_SUCCESS)
1047       Info.add("Managed Memory", (bool)TmpInt);
1048 
1049     Res =
1050         getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, TmpInt);
1051     if (Res == CUDA_SUCCESS)
1052       Info.add("Concurrent Managed Memory", (bool)TmpInt);
1053 
1054     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED,
1055                            TmpInt);
1056     if (Res == CUDA_SUCCESS)
1057       Info.add("Preemption Supported", (bool)TmpInt);
1058 
1059     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, TmpInt);
1060     if (Res == CUDA_SUCCESS)
1061       Info.add("Cooperative Launch", (bool)TmpInt);
1062 
1063     Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD, TmpInt);
1064     if (Res == CUDA_SUCCESS)
1065       Info.add("Multi-Device Boars", (bool)TmpInt);
1066 
1067     Info.add("Compute Capabilities", ComputeCapability.str());
1068 
1069     return Plugin::success();
1070   }
1071 
1072   virtual bool shouldSetupDeviceMemoryPool() const override {
1073     /// We use the CUDA malloc for now.
1074     return false;
1075   }
1076 
1077   /// Getters and setters for stack and heap sizes.
1078   Error getDeviceStackSize(uint64_t &Value) override {
1079     return getCtxLimit(CU_LIMIT_STACK_SIZE, Value);
1080   }
1081   Error setDeviceStackSize(uint64_t Value) override {
1082     return setCtxLimit(CU_LIMIT_STACK_SIZE, Value);
1083   }
1084   Error getDeviceHeapSize(uint64_t &Value) override {
1085     return getCtxLimit(CU_LIMIT_MALLOC_HEAP_SIZE, Value);
1086   }
1087   Error setDeviceHeapSize(uint64_t Value) override {
1088     return setCtxLimit(CU_LIMIT_MALLOC_HEAP_SIZE, Value);
1089   }
1090   Error getDeviceMemorySize(uint64_t &Value) override {
1091     CUresult Res = cuDeviceTotalMem(&Value, Device);
1092     return Plugin::check(Res, "Error in getDeviceMemorySize %s");
1093   }
1094 
1095   /// CUDA-specific functions for getting and setting context limits.
1096   Error setCtxLimit(CUlimit Kind, uint64_t Value) {
1097     CUresult Res = cuCtxSetLimit(Kind, Value);
1098     return Plugin::check(Res, "Error in cuCtxSetLimit: %s");
1099   }
1100   Error getCtxLimit(CUlimit Kind, uint64_t &Value) {
1101     CUresult Res = cuCtxGetLimit(&Value, Kind);
1102     return Plugin::check(Res, "Error in cuCtxGetLimit: %s");
1103   }
1104 
1105   /// CUDA-specific function to get device attributes.
1106   Error getDeviceAttr(uint32_t Kind, uint32_t &Value) {
1107     // TODO: Warn if the new value is larger than the old.
1108     CUresult Res =
1109         cuDeviceGetAttribute((int *)&Value, (CUdevice_attribute)Kind, Device);
1110     return Plugin::check(Res, "Error in cuDeviceGetAttribute: %s");
1111   }
1112 
1113   CUresult getDeviceAttrRaw(uint32_t Kind, int &Value) {
1114     return cuDeviceGetAttribute(&Value, (CUdevice_attribute)Kind, Device);
1115   }
1116 
1117   /// See GenericDeviceTy::getComputeUnitKind().
1118   std::string getComputeUnitKind() const override {
1119     return ComputeCapability.str();
1120   }
1121 
1122   /// Returns the clock frequency for the given NVPTX device.
1123   uint64_t getClockFrequency() const override { return 1000000000; }
1124 
1125 private:
1126   using CUDAStreamManagerTy = GenericDeviceResourceManagerTy<CUDAStreamRef>;
1127   using CUDAEventManagerTy = GenericDeviceResourceManagerTy<CUDAEventRef>;
1128 
1129   Error callGlobalCtorDtorCommon(GenericPluginTy &Plugin, DeviceImageTy &Image,
1130                                  bool IsCtor) {
1131     const char *KernelName = IsCtor ? "nvptx$device$init" : "nvptx$device$fini";
1132     // Perform a quick check for the named kernel in the image. The kernel
1133     // should be created by the 'nvptx-lower-ctor-dtor' pass.
1134     GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
1135     if (IsCtor && !Handler.isSymbolInImage(*this, Image, KernelName))
1136       return Plugin::success();
1137 
1138     // The Nvidia backend cannot handle creating the ctor / dtor array
1139     // automatically so we must create it ourselves. The backend will emit
1140     // several globals that contain function pointers we can call. These are
1141     // prefixed with a known name due to Nvidia's lack of section support.
1142     auto ELFObjOrErr = Handler.getELFObjectFile(Image);
1143     if (!ELFObjOrErr)
1144       return ELFObjOrErr.takeError();
1145 
1146     // Search for all symbols that contain a constructor or destructor.
1147     SmallVector<std::pair<StringRef, uint16_t>> Funcs;
1148     for (ELFSymbolRef Sym : (*ELFObjOrErr)->symbols()) {
1149       auto NameOrErr = Sym.getName();
1150       if (!NameOrErr)
1151         return NameOrErr.takeError();
1152 
1153       if (!NameOrErr->starts_with(IsCtor ? "__init_array_object_"
1154                                          : "__fini_array_object_"))
1155         continue;
1156 
1157       uint16_t Priority;
1158       if (NameOrErr->rsplit('_').second.getAsInteger(10, Priority))
1159         return Plugin::error("Invalid priority for constructor or destructor");
1160 
1161       Funcs.emplace_back(*NameOrErr, Priority);
1162     }
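    // For example, a symbol named "__fini_array_object_foo_101" (where "foo"
    // is just an illustrative function name) is given priority 101: everything
    // after the last '_' is parsed as the priority.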
1163 
1164     // Sort the created array to be in priority order.
1165     llvm::sort(Funcs, [=](auto X, auto Y) { return X.second < Y.second; });
1166 
1167     // Allocate a buffer to store all of the known constructor / destructor
1168     // functions in so we can iterate them on the device.
1169     void *Buffer =
1170         allocate(Funcs.size() * sizeof(void *), nullptr, TARGET_ALLOC_DEVICE);
1171     if (!Buffer)
1172       return Plugin::error("Failed to allocate memory for global buffer");
1173 
1174     auto *GlobalPtrStart = reinterpret_cast<uintptr_t *>(Buffer);
1175     auto *GlobalPtrStop = reinterpret_cast<uintptr_t *>(Buffer) + Funcs.size();
1176 
1177     SmallVector<void *> FunctionPtrs(Funcs.size());
1178     std::size_t Idx = 0;
1179     for (auto [Name, Priority] : Funcs) {
1180       GlobalTy FunctionAddr(Name.str(), sizeof(void *), &FunctionPtrs[Idx++]);
1181       if (auto Err = Handler.readGlobalFromDevice(*this, Image, FunctionAddr))
1182         return Err;
1183     }
1184 
1185     // Copy the local buffer to the device.
1186     if (auto Err = dataSubmit(GlobalPtrStart, FunctionPtrs.data(),
1187                               FunctionPtrs.size() * sizeof(void *), nullptr))
1188       return Err;
1189 
1190     // Copy the created buffer to the appropriate symbols so the kernel can
1191     // iterate through them.
1192     GlobalTy StartGlobal(IsCtor ? "__init_array_start" : "__fini_array_start",
1193                          sizeof(void *), &GlobalPtrStart);
1194     if (auto Err = Handler.writeGlobalToDevice(*this, Image, StartGlobal))
1195       return Err;
1196 
1197     GlobalTy StopGlobal(IsCtor ? "__init_array_end" : "__fini_array_end",
1198                         sizeof(void *), &GlobalPtrStop);
1199     if (auto Err = Handler.writeGlobalToDevice(*this, Image, StopGlobal))
1200       return Err;
1201 
1202     CUDAKernelTy CUDAKernel(KernelName);
1203 
1204     if (auto Err = CUDAKernel.init(*this, Image))
1205       return Err;
1206 
1207     AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
1208 
1209     KernelArgsTy KernelArgs = {};
1210     uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
1211     if (auto Err = CUDAKernel.launchImpl(
1212             *this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
1213             KernelLaunchParamsTy{}, AsyncInfoWrapper))
1214       return Err;
1215 
1216     Error Err = Plugin::success();
1217     AsyncInfoWrapper.finalize(Err);
1218 
1219     if (free(Buffer, TARGET_ALLOC_DEVICE) != OFFLOAD_SUCCESS)
1220       return Plugin::error("Failed to free memory for global buffer");
1221 
1222     return Err;
1223   }
1224 
1225   /// Stream manager for CUDA streams.
1226   CUDAStreamManagerTy CUDAStreamManager;
1227 
1228   /// Event manager for CUDA events.
1229   CUDAEventManagerTy CUDAEventManager;
1230 
1231   /// The device's context. This context should be set before performing
1232   /// operations on the device.
1233   CUcontext Context = nullptr;
1234 
1235   /// The CUDA device handler.
1236   CUdevice Device = CU_DEVICE_INVALID;
1237 
1238   /// The memory mapped addresses and their handles
1239   std::unordered_map<CUdeviceptr, CUmemGenericAllocationHandle> DeviceMMaps;
1240 
1241   /// The compute capability of the corresponding CUDA device.
1242   struct ComputeCapabilityTy {
1243     uint32_t Major;
1244     uint32_t Minor;
1245     std::string str() const {
1246       return "sm_" + std::to_string(Major * 10 + Minor);
1247     }
1248   } ComputeCapability;
1249 
1250   /// The maximum number of warps that can be resident on all the SMs
1251   /// simultaneously.
1252   uint32_t HardwareParallelism = 0;
1253 };
1254 
1255 Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
1256                                uint32_t NumThreads[3], uint32_t NumBlocks[3],
1257                                KernelArgsTy &KernelArgs,
1258                                KernelLaunchParamsTy LaunchParams,
1259                                AsyncInfoWrapperTy &AsyncInfoWrapper) const {
1260   CUDADeviceTy &CUDADevice = static_cast<CUDADeviceTy &>(GenericDevice);
1261 
1262   CUstream Stream;
1263   if (auto Err = CUDADevice.getStream(AsyncInfoWrapper, Stream))
1264     return Err;
1265 
1266   uint32_t MaxDynCGroupMem =
1267       std::max(KernelArgs.DynCGroupMem, GenericDevice.getDynamicMemorySize());
1268 
1269   void *Config[] = {CU_LAUNCH_PARAM_BUFFER_POINTER, LaunchParams.Data,
1270                     CU_LAUNCH_PARAM_BUFFER_SIZE,
1271                     reinterpret_cast<void *>(&LaunchParams.Size),
1272                     CU_LAUNCH_PARAM_END};
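  // Passing the arguments via CU_LAUNCH_PARAM_BUFFER_POINTER/SIZE hands the
  // driver a single packed buffer (LaunchParams.Data of LaunchParams.Size
  // bytes) instead of an array of individual argument pointers, which is why
  // the kernelParams argument of cuLaunchKernel below is nullptr.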
1273 
1274   // If we are running an RPC server we want to wake up the server thread
1275   // whenever there is a kernel running and let it sleep otherwise.
1276   if (GenericDevice.getRPCServer())
1277     GenericDevice.Plugin.getRPCServer().Thread->notify();
1278 
1279   CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
1280                                 NumThreads[0], NumThreads[1], NumThreads[2],
1281                                 MaxDynCGroupMem, Stream, nullptr, Config);
1282 
1283   // Register a callback to indicate when the kernel is complete.
1284   if (GenericDevice.getRPCServer())
1285     cuLaunchHostFunc(
1286         Stream,
1287         [](void *Data) {
1288           GenericPluginTy &Plugin = *reinterpret_cast<GenericPluginTy *>(Data);
1289           Plugin.getRPCServer().Thread->finish();
1290         },
1291         &GenericDevice.Plugin);
1292 
1293   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
1294 }
1295 
1296 /// Class implementing the CUDA-specific functionalities of the global handler.
1297 class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy {
1298 public:
1299   /// Get the metadata of a global from the device. The name and size of the
1300   /// global is read from DeviceGlobal and the address of the global is written
1301   /// to DeviceGlobal.
1302   Error getGlobalMetadataFromDevice(GenericDeviceTy &Device,
1303                                     DeviceImageTy &Image,
1304                                     GlobalTy &DeviceGlobal) override {
1305     CUDADeviceImageTy &CUDAImage = static_cast<CUDADeviceImageTy &>(Image);
1306 
1307     const char *GlobalName = DeviceGlobal.getName().data();
1308 
1309     size_t CUSize;
1310     CUdeviceptr CUPtr;
1311     CUresult Res =
1312         cuModuleGetGlobal(&CUPtr, &CUSize, CUDAImage.getModule(), GlobalName);
1313     if (auto Err = Plugin::check(Res, "Error in cuModuleGetGlobal for '%s': %s",
1314                                  GlobalName))
1315       return Err;
1316 
1317     if (CUSize != DeviceGlobal.getSize())
1318       return Plugin::error(
1319           "Failed to load global '%s' due to size mismatch (%zu != %zu)",
1320           GlobalName, CUSize, (size_t)DeviceGlobal.getSize());
1321 
1322     DeviceGlobal.setPtr(reinterpret_cast<void *>(CUPtr));
1323     return Plugin::success();
1324   }
1325 };
1326 
1327 /// Class implementing the CUDA-specific functionalities of the plugin.
1328 struct CUDAPluginTy final : public GenericPluginTy {
1329   /// Create a CUDA plugin.
1330   CUDAPluginTy() : GenericPluginTy(getTripleArch()) {}
1331 
1332   /// This class should not be copied.
1333   CUDAPluginTy(const CUDAPluginTy &) = delete;
1334   CUDAPluginTy(CUDAPluginTy &&) = delete;
1335 
1336   /// Initialize the plugin and return the number of devices.
1337   Expected<int32_t> initImpl() override {
1338     CUresult Res = cuInit(0);
1339     if (Res == CUDA_ERROR_INVALID_HANDLE) {
1340       // Cannot call cuGetErrorString if dlsym failed.
1341       DP("Failed to load CUDA shared library\n");
1342       return 0;
1343     }
1344 
1345     if (Res == CUDA_ERROR_NO_DEVICE) {
1346       // Do not initialize if there are no devices.
1347       DP("There are no devices supporting CUDA.\n");
1348       return 0;
1349     }
1350 
1351     if (auto Err = Plugin::check(Res, "Error in cuInit: %s"))
1352       return std::move(Err);
1353 
1354     // Get the number of devices.
1355     int NumDevices;
1356     Res = cuDeviceGetCount(&NumDevices);
1357     if (auto Err = Plugin::check(Res, "Error in cuDeviceGetCount: %s"))
1358       return std::move(Err);
1359 
1360     // Do not initialize if there are no devices.
1361     if (NumDevices == 0)
1362       DP("There are no devices supporting CUDA.\n");
1363 
1364     return NumDevices;
1365   }
1366 
1367   /// Deinitialize the plugin.
1368   Error deinitImpl() override { return Plugin::success(); }
1369 
1370   /// Creates a CUDA device to use for offloading.
1371   GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
1372                                 int32_t NumDevices) override {
1373     return new CUDADeviceTy(Plugin, DeviceId, NumDevices);
1374   }
1375 
1376   /// Creates a CUDA global handler.
1377   GenericGlobalHandlerTy *createGlobalHandler() override {
1378     return new CUDAGlobalHandlerTy();
1379   }
1380 
1381   /// Get the ELF code for recognizing the compatible image binary.
1382   uint16_t getMagicElfBits() const override { return ELF::EM_CUDA; }
1383 
1384   Triple::ArchType getTripleArch() const override {
1385     // TODO: I think we can drop the support for 32-bit NVPTX devices.
1386     return Triple::nvptx64;
1387   }
1388 
1389   const char *getName() const override { return GETNAME(TARGET_NAME); }
1390 
1391   /// Check whether the image is compatible with a CUDA device.
1392   Expected<bool> isELFCompatible(uint32_t DeviceId,
1393                                  StringRef Image) const override {
1394     auto ElfOrErr =
1395         ELF64LEObjectFile::create(MemoryBufferRef(Image, /*Identifier=*/""),
1396                                   /*InitContent=*/false);
1397     if (!ElfOrErr)
1398       return ElfOrErr.takeError();
1399 
1400     // Get the numeric value for the image's `sm_` value.
1401     auto SM = ElfOrErr->getPlatformFlags() & ELF::EF_CUDA_SM;
1402 
1403     CUdevice Device;
1404     CUresult Res = cuDeviceGet(&Device, DeviceId);
1405     if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
1406       return std::move(Err);
1407 
1408     int32_t Major, Minor;
1409     Res = cuDeviceGetAttribute(
1410         &Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, Device);
1411     if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1412       return std::move(Err);
1413 
1414     Res = cuDeviceGetAttribute(
1415         &Minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, Device);
1416     if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1417       return std::move(Err);
1418 
1419     int32_t ImageMajor = SM / 10;
1420     int32_t ImageMinor = SM % 10;
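    // For example, an image built for sm_80 has SM == 80, which yields
    // ImageMajor == 8 and ImageMinor == 0.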
1421 
    // A cubin generated for a certain compute capability can run on any GPU
    // with the same major revision and the same or a higher minor revision.
1425     return Major == ImageMajor && Minor >= ImageMinor;
1426   }
1427 };
1428 
1429 Error CUDADeviceTy::dataExchangeImpl(const void *SrcPtr,
1430                                      GenericDeviceTy &DstGenericDevice,
1431                                      void *DstPtr, int64_t Size,
1432                                      AsyncInfoWrapperTy &AsyncInfoWrapper) {
1433   if (auto Err = setContext())
1434     return Err;
1435 
1436   CUDADeviceTy &DstDevice = static_cast<CUDADeviceTy &>(DstGenericDevice);
1437 
1438   CUresult Res;
1439   int32_t DstDeviceId = DstDevice.DeviceId;
1440   CUdeviceptr CUSrcPtr = (CUdeviceptr)SrcPtr;
1441   CUdeviceptr CUDstPtr = (CUdeviceptr)DstPtr;
1442 
1443   int CanAccessPeer = 0;
1444   if (DeviceId != DstDeviceId) {
1445     // Make sure the lock is released before performing the copies.
1446     std::lock_guard<std::mutex> Lock(PeerAccessesLock);
1447 
1448     switch (PeerAccesses[DstDeviceId]) {
1449     case PeerAccessState::AVAILABLE:
1450       CanAccessPeer = 1;
1451       break;
1452     case PeerAccessState::UNAVAILABLE:
1453       CanAccessPeer = 0;
1454       break;
1455     case PeerAccessState::PENDING:
1456       // Check whether the source device can access the destination device.
1457       Res = cuDeviceCanAccessPeer(&CanAccessPeer, Device, DstDevice.Device);
1458       if (auto Err = Plugin::check(Res, "Error in cuDeviceCanAccessPeer: %s"))
1459         return Err;
1460 
1461       if (CanAccessPeer) {
1462         Res = cuCtxEnablePeerAccess(DstDevice.Context, 0);
1463         if (Res == CUDA_ERROR_TOO_MANY_PEERS) {
1464           // Resources may be exhausted due to many P2P links.
1465           CanAccessPeer = 0;
1466           DP("Too many P2P so fall back to D2D memcpy");
1467         } else if (auto Err =
1468                        Plugin::check(Res, "Error in cuCtxEnablePeerAccess: %s"))
1469           return Err;
1470       }
1471       PeerAccesses[DstDeviceId] = (CanAccessPeer)
1472                                       ? PeerAccessState::AVAILABLE
1473                                       : PeerAccessState::UNAVAILABLE;
1474     }
1475   }
1476 
1477   CUstream Stream;
1478   if (auto Err = getStream(AsyncInfoWrapper, Stream))
1479     return Err;
1480 
1481   if (CanAccessPeer) {
1482     // TODO: Should we fallback to D2D if peer access fails?
1483     Res = cuMemcpyPeerAsync(CUDstPtr, Context, CUSrcPtr, DstDevice.Context,
1484                             Size, Stream);
1485     return Plugin::check(Res, "Error in cuMemcpyPeerAsync: %s");
1486   }
1487 
1488   // Fallback to D2D copy.
1489   Res = cuMemcpyDtoDAsync(CUDstPtr, CUSrcPtr, Size, Stream);
1490   return Plugin::check(Res, "Error in cuMemcpyDtoDAsync: %s");
1491 }
1492 
1493 template <typename... ArgsTy>
1494 static Error Plugin::check(int32_t Code, const char *ErrFmt, ArgsTy... Args) {
1495   CUresult ResultCode = static_cast<CUresult>(Code);
1496   if (ResultCode == CUDA_SUCCESS)
1497     return Error::success();
1498 
1499   const char *Desc = "Unknown error";
1500   CUresult Ret = cuGetErrorString(ResultCode, &Desc);
1501   if (Ret != CUDA_SUCCESS)
1502     REPORT("Unrecognized " GETNAME(TARGET_NAME) " error code %d\n", Code);
1503 
1504   return createStringError<ArgsTy..., const char *>(inconvertibleErrorCode(),
1505                                                     ErrFmt, Args..., Desc);
1506 }
1507 
1508 } // namespace plugin
1509 } // namespace target
1510 } // namespace omp
1511 } // namespace llvm
1512 
1513 extern "C" {
1514 llvm::omp::target::plugin::GenericPluginTy *createPlugin_cuda() {
1515   return new llvm::omp::target::plugin::CUDAPluginTy();
1516 }
1517 }
1518