xref: /openbsd-src/gnu/llvm/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp (revision 12c855180aad702bbcca06e0398d774beeafb155)
1 //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OffloadWrapper.h"
10 #include "llvm/ADT/ArrayRef.h"
11 #include "llvm/ADT/Triple.h"
12 #include "llvm/IR/Constants.h"
13 #include "llvm/IR/GlobalVariable.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/Object/OffloadBinary.h"
18 #include "llvm/Support/Error.h"
19 #include "llvm/Transforms/Utils/ModuleUtils.h"
20 
21 using namespace llvm;
22 
23 namespace {
/// Value stored in the `magic` field of the fatbinary wrapper struct when
/// wrapping a CUDA image.
constexpr unsigned CudaFatMagic = 0x466243B1;
/// Value stored in the `magic` field of the fatbinary wrapper struct when
/// wrapping a HIP image; the bytes spell "HIPF" in ASCII.
constexpr unsigned HIPFatMagic = 0x48495046;
27 
/// Copied from clang/CGCudaRuntime.h.
///
/// Values carried in the `flags` field of an offload entry; they select how
/// an entry with a non-zero size is registered.
enum OffloadEntryKindFlag : uint32_t {
  /// A plain global entry. This indicates the presence of a kernel if the
  /// entry's size field is zero and of a variable otherwise.
  OffloadGlobalEntry = 0,
  /// The entry is a managed global variable.
  OffloadGlobalManagedEntry = 1,
  /// The entry is a surface variable.
  OffloadGlobalSurfaceEntry = 2,
  /// The entry is a texture variable.
  OffloadGlobalTextureEntry = 3,
};
40 
getSizeTTy(Module & M)41 IntegerType *getSizeTTy(Module &M) {
42   LLVMContext &C = M.getContext();
43   switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
44   case 4u:
45     return Type::getInt32Ty(C);
46   case 8u:
47     return Type::getInt64Ty(C);
48   }
49   llvm_unreachable("unsupported pointer type size");
50 }
51 
52 // struct __tgt_offload_entry {
53 //   void *addr;
54 //   char *name;
55 //   size_t size;
56 //   int32_t flags;
57 //   int32_t reserved;
58 // };
getEntryTy(Module & M)59 StructType *getEntryTy(Module &M) {
60   LLVMContext &C = M.getContext();
61   StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
62   if (!EntryTy)
63     EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
64                                  Type::getInt8PtrTy(C), getSizeTTy(M),
65                                  Type::getInt32Ty(C), Type::getInt32Ty(C));
66   return EntryTy;
67 }
68 
getEntryPtrTy(Module & M)69 PointerType *getEntryPtrTy(Module &M) {
70   return PointerType::getUnqual(getEntryTy(M));
71 }
72 
73 // struct __tgt_device_image {
74 //   void *ImageStart;
75 //   void *ImageEnd;
76 //   __tgt_offload_entry *EntriesBegin;
77 //   __tgt_offload_entry *EntriesEnd;
78 // };
getDeviceImageTy(Module & M)79 StructType *getDeviceImageTy(Module &M) {
80   LLVMContext &C = M.getContext();
81   StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
82   if (!ImageTy)
83     ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
84                                  Type::getInt8PtrTy(C), getEntryPtrTy(M),
85                                  getEntryPtrTy(M));
86   return ImageTy;
87 }
88 
getDeviceImagePtrTy(Module & M)89 PointerType *getDeviceImagePtrTy(Module &M) {
90   return PointerType::getUnqual(getDeviceImageTy(M));
91 }
92 
93 // struct __tgt_bin_desc {
94 //   int32_t NumDeviceImages;
95 //   __tgt_device_image *DeviceImages;
96 //   __tgt_offload_entry *HostEntriesBegin;
97 //   __tgt_offload_entry *HostEntriesEnd;
98 // };
getBinDescTy(Module & M)99 StructType *getBinDescTy(Module &M) {
100   LLVMContext &C = M.getContext();
101   StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
102   if (!DescTy)
103     DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
104                                 getDeviceImagePtrTy(M), getEntryPtrTy(M),
105                                 getEntryPtrTy(M));
106   return DescTy;
107 }
108 
getBinDescPtrTy(Module & M)109 PointerType *getBinDescPtrTy(Module &M) {
110   return PointerType::getUnqual(getBinDescTy(M));
111 }
112 
113 /// Creates binary descriptor for the given device images. Binary descriptor
114 /// is an object that is passed to the offloading runtime at program startup
115 /// and it describes all device images available in the executable or shared
116 /// library. It is defined as follows
117 ///
118 /// __attribute__((visibility("hidden")))
119 /// extern __tgt_offload_entry *__start_omp_offloading_entries;
120 /// __attribute__((visibility("hidden")))
121 /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
122 ///
123 /// static const char Image0[] = { <Bufs.front() contents> };
124 ///  ...
125 /// static const char ImageN[] = { <Bufs.back() contents> };
126 ///
127 /// static const __tgt_device_image Images[] = {
128 ///   {
129 ///     Image0,                            /*ImageStart*/
130 ///     Image0 + sizeof(Image0),           /*ImageEnd*/
131 ///     __start_omp_offloading_entries,    /*EntriesBegin*/
132 ///     __stop_omp_offloading_entries      /*EntriesEnd*/
133 ///   },
134 ///   ...
135 ///   {
136 ///     ImageN,                            /*ImageStart*/
137 ///     ImageN + sizeof(ImageN),           /*ImageEnd*/
138 ///     __start_omp_offloading_entries,    /*EntriesBegin*/
139 ///     __stop_omp_offloading_entries      /*EntriesEnd*/
140 ///   }
141 /// };
142 ///
143 /// static const __tgt_bin_desc BinDesc = {
144 ///   sizeof(Images) / sizeof(Images[0]),  /*NumDeviceImages*/
145 ///   Images,                              /*DeviceImages*/
146 ///   __start_omp_offloading_entries,      /*HostEntriesBegin*/
147 ///   __stop_omp_offloading_entries        /*HostEntriesEnd*/
148 /// };
149 ///
150 /// Global variable that represents BinDesc is returned.
createBinDesc(Module & M,ArrayRef<ArrayRef<char>> Bufs)151 GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
152   LLVMContext &C = M.getContext();
153   // Create external begin/end symbols for the offload entries table.
154   auto *EntriesB = new GlobalVariable(
155       M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
156       /*Initializer*/ nullptr, "__start_omp_offloading_entries");
157   EntriesB->setVisibility(GlobalValue::HiddenVisibility);
158   auto *EntriesE = new GlobalVariable(
159       M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
160       /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
161   EntriesE->setVisibility(GlobalValue::HiddenVisibility);
162 
163   // We assume that external begin/end symbols that we have created above will
164   // be defined by the linker. But linker will do that only if linker inputs
165   // have section with "omp_offloading_entries" name which is not guaranteed.
166   // So, we just create dummy zero sized object in the offload entries section
167   // to force linker to define those symbols.
168   auto *DummyInit =
169       ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
170   auto *DummyEntry = new GlobalVariable(
171       M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
172       "__dummy.omp_offloading.entry");
173   DummyEntry->setSection("omp_offloading_entries");
174   DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
175 
176   auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
177   Constant *ZeroZero[] = {Zero, Zero};
178 
179   // Create initializer for the images array.
180   SmallVector<Constant *, 4u> ImagesInits;
181   ImagesInits.reserve(Bufs.size());
182   for (ArrayRef<char> Buf : Bufs) {
183     auto *Data = ConstantDataArray::get(C, Buf);
184     auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
185                                      GlobalVariable::InternalLinkage, Data,
186                                      ".omp_offloading.device_image");
187     Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
188     Image->setSection(".llvm.offloading");
189     Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
190 
191     auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
192     Constant *ZeroSize[] = {Zero, Size};
193 
194     auto *ImageB =
195         ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
196     auto *ImageE =
197         ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
198 
199     ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
200                                               ImageE, EntriesB, EntriesE));
201   }
202 
203   // Then create images array.
204   auto *ImagesData = ConstantArray::get(
205       ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
206 
207   auto *Images =
208       new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
209                          GlobalValue::InternalLinkage, ImagesData,
210                          ".omp_offloading.device_images");
211   Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
212 
213   auto *ImagesB =
214       ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
215 
216   // And finally create the binary descriptor object.
217   auto *DescInit = ConstantStruct::get(
218       getBinDescTy(M),
219       ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
220       EntriesB, EntriesE);
221 
222   return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
223                             GlobalValue::InternalLinkage, DescInit,
224                             ".omp_offloading.descriptor");
225 }
226 
createRegisterFunction(Module & M,GlobalVariable * BinDesc)227 void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
228   LLVMContext &C = M.getContext();
229   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
230   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
231                                 ".omp_offloading.descriptor_reg", &M);
232   Func->setSection(".text.startup");
233 
234   // Get __tgt_register_lib function declaration.
235   auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
236                                       /*isVarArg*/ false);
237   FunctionCallee RegFuncC =
238       M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
239 
240   // Construct function body
241   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
242   Builder.CreateCall(RegFuncC, BinDesc);
243   Builder.CreateRetVoid();
244 
245   // Add this function to constructors.
246   // Set priority to 1 so that __tgt_register_lib is executed AFTER
247   // __tgt_register_requires (we want to know what requirements have been
248   // asked for before we load a libomptarget plugin so that by the time the
249   // plugin is loaded it can report how many devices there are which can
250   // satisfy these requirements).
251   appendToGlobalCtors(M, Func, /*Priority*/ 1);
252 }
253 
createUnregisterFunction(Module & M,GlobalVariable * BinDesc)254 void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
255   LLVMContext &C = M.getContext();
256   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
257   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
258                                 ".omp_offloading.descriptor_unreg", &M);
259   Func->setSection(".text.startup");
260 
261   // Get __tgt_unregister_lib function declaration.
262   auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
263                                         /*isVarArg*/ false);
264   FunctionCallee UnRegFuncC =
265       M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
266 
267   // Construct function body
268   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
269   Builder.CreateCall(UnRegFuncC, BinDesc);
270   Builder.CreateRetVoid();
271 
272   // Add this function to global destructors.
273   // Match priority of __tgt_register_lib
274   appendToGlobalDtors(M, Func, /*Priority*/ 1);
275 }
276 
277 // struct fatbin_wrapper {
278 //  int32_t magic;
279 //  int32_t version;
280 //  void *image;
281 //  void *reserved;
282 //};
getFatbinWrapperTy(Module & M)283 StructType *getFatbinWrapperTy(Module &M) {
284   LLVMContext &C = M.getContext();
285   StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
286   if (!FatbinTy)
287     FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
288                                   Type::getInt32Ty(C), Type::getInt8PtrTy(C),
289                                   Type::getInt8PtrTy(C));
290   return FatbinTy;
291 }
292 
293 /// Embed the image \p Image into the module \p M so it can be found by the
294 /// runtime.
createFatbinDesc(Module & M,ArrayRef<char> Image,bool IsHIP)295 GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
296   LLVMContext &C = M.getContext();
297   llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
298   llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
299 
300   // Create the global string containing the fatbinary.
301   StringRef FatbinConstantSection =
302       IsHIP ? ".hip_fatbin"
303             : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
304   auto *Data = ConstantDataArray::get(C, Image);
305   auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
306                                     GlobalVariable::InternalLinkage, Data,
307                                     ".fatbin_image");
308   Fatbin->setSection(FatbinConstantSection);
309 
310   // Create the fatbinary wrapper
311   StringRef FatbinWrapperSection = IsHIP               ? ".hipFatBinSegment"
312                                    : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
313                                                        : ".nvFatBinSegment";
314   Constant *FatbinWrapper[] = {
315       ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
316       ConstantInt::get(Type::getInt32Ty(C), 1),
317       ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
318       ConstantPointerNull::get(Type::getInt8PtrTy(C))};
319 
320   Constant *FatbinInitializer =
321       ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
322 
323   auto *FatbinDesc =
324       new GlobalVariable(M, getFatbinWrapperTy(M),
325                          /*isConstant*/ true, GlobalValue::InternalLinkage,
326                          FatbinInitializer, ".fatbin_wrapper");
327   FatbinDesc->setSection(FatbinWrapperSection);
328   FatbinDesc->setAlignment(Align(8));
329 
330   // We create a dummy entry to ensure the linker will define the begin / end
331   // symbols. The CUDA runtime should ignore the null address if we attempt to
332   // register it.
333   auto *DummyInit =
334       ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
335   auto *DummyEntry = new GlobalVariable(
336       M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
337       IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
338   DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
339   DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
340                                : "cuda_offloading_entries");
341 
342   return FatbinDesc;
343 }
344 
345 /// Create the register globals function. We will iterate all of the offloading
346 /// entries stored at the begin / end symbols and register them according to
347 /// their type. This creates the following function in IR:
348 ///
349 /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
350 /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
351 ///
352 /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
353 ///                                    void *, void *, void *, void *, int *);
354 /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
355 ///                               int64_t, int32_t, int32_t);
356 ///
357 /// void __cudaRegisterTest(void **fatbinHandle) {
358 ///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
359 ///        entry != &__stop_cuda_offloading_entries; ++entry) {
360 ///     if (!entry->size)
361 ///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
362 ///                              entry->name, -1, 0, 0, 0, 0, 0);
363 ///     else
364 ///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
365 ///                         0, entry->size, 0, 0);
366 ///   }
367 /// }
createRegisterGlobalsFunction(Module & M,bool IsHIP)368 Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
369   LLVMContext &C = M.getContext();
370   // Get the __cudaRegisterFunction function declaration.
371   auto *RegFuncTy = FunctionType::get(
372       Type::getInt32Ty(C),
373       {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
374        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
375        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
376        Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
377       /*isVarArg*/ false);
378   FunctionCallee RegFunc = M.getOrInsertFunction(
379       IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
380 
381   // Get the __cudaRegisterVar function declaration.
382   auto *RegVarTy = FunctionType::get(
383       Type::getVoidTy(C),
384       {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
385        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
386        getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
387       /*isVarArg*/ false);
388   FunctionCallee RegVar = M.getOrInsertFunction(
389       IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
390 
391   // Create the references to the start / stop symbols defined by the linker.
392   auto *EntriesB =
393       new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
394                          /*isConstant*/ true, GlobalValue::ExternalLinkage,
395                          /*Initializer*/ nullptr,
396                          IsHIP ? "__start_hip_offloading_entries"
397                                : "__start_cuda_offloading_entries");
398   EntriesB->setVisibility(GlobalValue::HiddenVisibility);
399   auto *EntriesE =
400       new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
401                          /*isConstant*/ true, GlobalValue::ExternalLinkage,
402                          /*Initializer*/ nullptr,
403                          IsHIP ? "__stop_hip_offloading_entries"
404                                : "__stop_cuda_offloading_entries");
405   EntriesE->setVisibility(GlobalValue::HiddenVisibility);
406 
407   auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
408                                          Type::getInt8PtrTy(C)->getPointerTo(),
409                                          /*isVarArg*/ false);
410   auto *RegGlobalsFn =
411       Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
412                        IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
413   RegGlobalsFn->setSection(".text.startup");
414 
415   // Create the loop to register all the entries.
416   IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
417   auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
418   auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
419   auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
420   auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
421   auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
422   auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
423   auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
424   auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
425   auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
426 
427   auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
428   Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
429   Builder.SetInsertPoint(EntryBB);
430   auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
431   auto *AddrPtr =
432       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
433                                 {ConstantInt::get(getSizeTTy(M), 0),
434                                  ConstantInt::get(Type::getInt32Ty(C), 0)});
435   auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
436   auto *NamePtr =
437       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
438                                 {ConstantInt::get(getSizeTTy(M), 0),
439                                  ConstantInt::get(Type::getInt32Ty(C), 1)});
440   auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
441   auto *SizePtr =
442       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
443                                 {ConstantInt::get(getSizeTTy(M), 0),
444                                  ConstantInt::get(Type::getInt32Ty(C), 2)});
445   auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
446   auto *FlagsPtr =
447       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
448                                 {ConstantInt::get(getSizeTTy(M), 0),
449                                  ConstantInt::get(Type::getInt32Ty(C), 3)});
450   auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
451   auto *FnCond =
452       Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
453   Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
454 
455   // Create kernel registration code.
456   Builder.SetInsertPoint(IfThenBB);
457   Builder.CreateCall(RegFunc,
458                      {RegGlobalsFn->arg_begin(), Addr, Name, Name,
459                       ConstantInt::get(Type::getInt32Ty(C), -1),
460                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
461                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
462                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
463                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
464                       ConstantPointerNull::get(Type::getInt32PtrTy(C))});
465   Builder.CreateBr(IfEndBB);
466   Builder.SetInsertPoint(IfElseBB);
467 
468   auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
469   // Create global variable registration code.
470   Builder.SetInsertPoint(SwGlobalBB);
471   Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
472                               ConstantInt::get(Type::getInt32Ty(C), 0), Size,
473                               ConstantInt::get(Type::getInt32Ty(C), 0),
474                               ConstantInt::get(Type::getInt32Ty(C), 0)});
475   Builder.CreateBr(IfEndBB);
476   Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);
477 
478   // Create managed variable registration code.
479   Builder.SetInsertPoint(SwManagedBB);
480   Builder.CreateBr(IfEndBB);
481   Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);
482 
483   // Create surface variable registration code.
484   Builder.SetInsertPoint(SwSurfaceBB);
485   Builder.CreateBr(IfEndBB);
486   Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);
487 
488   // Create texture variable registration code.
489   Builder.SetInsertPoint(SwTextureBB);
490   Builder.CreateBr(IfEndBB);
491   Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);
492 
493   Builder.SetInsertPoint(IfEndBB);
494   auto *NewEntry = Builder.CreateInBoundsGEP(
495       getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
496   auto *Cmp = Builder.CreateICmpEQ(
497       NewEntry,
498       ConstantExpr::getInBoundsGetElementPtr(
499           ArrayType::get(getEntryTy(M), 0), EntriesE,
500           ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
501                                 ConstantInt::get(getSizeTTy(M), 0)})));
502   Entry->addIncoming(
503       ConstantExpr::getInBoundsGetElementPtr(
504           ArrayType::get(getEntryTy(M), 0), EntriesB,
505           ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
506                                 ConstantInt::get(getSizeTTy(M), 0)})),
507       &RegGlobalsFn->getEntryBlock());
508   Entry->addIncoming(NewEntry, IfEndBB);
509   Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
510   Builder.SetInsertPoint(ExitBB);
511   Builder.CreateRetVoid();
512 
513   return RegGlobalsFn;
514 }
515 
516 // Create the constructor and destructor to register the fatbinary with the CUDA
517 // runtime.
createRegisterFatbinFunction(Module & M,GlobalVariable * FatbinDesc,bool IsHIP)518 void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
519                                   bool IsHIP) {
520   LLVMContext &C = M.getContext();
521   auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
522   auto *CtorFunc =
523       Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
524                        IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
525   CtorFunc->setSection(".text.startup");
526 
527   auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
528   auto *DtorFunc =
529       Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
530                        IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
531   DtorFunc->setSection(".text.startup");
532 
533   // Get the __cudaRegisterFatBinary function declaration.
534   auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
535                                      Type::getInt8PtrTy(C),
536                                      /*isVarArg*/ false);
537   FunctionCallee RegFatbin = M.getOrInsertFunction(
538       IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
539   // Get the __cudaRegisterFatBinaryEnd function declaration.
540   auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
541                                         Type::getInt8PtrTy(C)->getPointerTo(),
542                                         /*isVarArg*/ false);
543   FunctionCallee RegFatbinEnd =
544       M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
545   // Get the __cudaUnregisterFatBinary function declaration.
546   auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
547                                        Type::getInt8PtrTy(C)->getPointerTo(),
548                                        /*isVarArg*/ false);
549   FunctionCallee UnregFatbin = M.getOrInsertFunction(
550       IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
551       UnregFatTy);
552 
553   auto *AtExitTy =
554       FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
555                         /*isVarArg*/ false);
556   FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
557 
558   auto *BinaryHandleGlobal = new llvm::GlobalVariable(
559       M, Type::getInt8PtrTy(C)->getPointerTo(), false,
560       llvm::GlobalValue::InternalLinkage,
561       llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
562       IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
563 
564   // Create the constructor to register this image with the runtime.
565   IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
566   CallInst *Handle = CtorBuilder.CreateCall(
567       RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
568                      FatbinDesc, Type::getInt8PtrTy(C)));
569   CtorBuilder.CreateAlignedStore(
570       Handle, BinaryHandleGlobal,
571       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
572   CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
573   if (!IsHIP)
574     CtorBuilder.CreateCall(RegFatbinEnd, Handle);
575   CtorBuilder.CreateCall(AtExit, DtorFunc);
576   CtorBuilder.CreateRetVoid();
577 
578   // Create the destructor to unregister the image with the runtime. We cannot
579   // use a standard global destructor after CUDA 9.2 so this must be called by
580   // `atexit()` intead.
581   IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
582   LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
583       Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
584       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
585   DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
586   DtorBuilder.CreateRetVoid();
587 
588   // Add this function to constructors.
589   appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
590 }
591 
592 } // namespace
593 
wrapOpenMPBinaries(Module & M,ArrayRef<ArrayRef<char>> Images)594 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
595   GlobalVariable *Desc = createBinDesc(M, Images);
596   if (!Desc)
597     return createStringError(inconvertibleErrorCode(),
598                              "No binary descriptors created.");
599   createRegisterFunction(M, Desc);
600   createUnregisterFunction(M, Desc);
601   return Error::success();
602 }
603 
wrapCudaBinary(Module & M,ArrayRef<char> Image)604 Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
605   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
606   if (!Desc)
607     return createStringError(inconvertibleErrorCode(),
608                              "No fatinbary section created.");
609 
610   createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
611   return Error::success();
612 }
613 
wrapHIPBinary(Module & M,ArrayRef<char> Image)614 Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
615   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
616   if (!Desc)
617     return createStringError(inconvertibleErrorCode(),
618                              "No fatinbary section created.");
619 
620   createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
621   return Error::success();
622 }
623