xref: /openbsd-src/gnu/llvm/clang/tools/clang-linker-wrapper/OffloadWrapper.cpp (revision 12c855180aad702bbcca06e0398d774beeafb155)
1*12c85518Srobert //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2*12c85518Srobert //
3*12c85518Srobert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*12c85518Srobert // See https://llvm.org/LICENSE.txt for license information.
5*12c85518Srobert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*12c85518Srobert //
7*12c85518Srobert //===----------------------------------------------------------------------===//
8*12c85518Srobert 
9*12c85518Srobert #include "OffloadWrapper.h"
10*12c85518Srobert #include "llvm/ADT/ArrayRef.h"
11*12c85518Srobert #include "llvm/ADT/Triple.h"
12*12c85518Srobert #include "llvm/IR/Constants.h"
13*12c85518Srobert #include "llvm/IR/GlobalVariable.h"
14*12c85518Srobert #include "llvm/IR/IRBuilder.h"
15*12c85518Srobert #include "llvm/IR/LLVMContext.h"
16*12c85518Srobert #include "llvm/IR/Module.h"
17*12c85518Srobert #include "llvm/Object/OffloadBinary.h"
18*12c85518Srobert #include "llvm/Support/Error.h"
19*12c85518Srobert #include "llvm/Transforms/Utils/ModuleUtils.h"
20*12c85518Srobert 
21*12c85518Srobert using namespace llvm;
22*12c85518Srobert 
23*12c85518Srobert namespace {
/// Magic value placed in the first field of the fatbinary wrapper emitted for
/// CUDA images.
constexpr unsigned CudaFatMagic = 0x466243b1u;
/// Magic value placed in the first field of the fatbinary wrapper emitted for
/// HIP images (the bytes spell "HIPF" in ASCII).
constexpr unsigned HIPFatMagic = 0x48495046u;
27*12c85518Srobert 
/// Values stored in the `flags` field of an offloading entry.
/// Copied from clang/CGCudaRuntime.h.
enum OffloadEntryKindFlag : uint32_t {
  /// Global entry: denotes a kernel when the entry's size field is zero and a
  /// variable otherwise.
  OffloadGlobalEntry = 0,
  /// Managed global variable.
  OffloadGlobalManagedEntry = 1,
  /// Surface variable.
  OffloadGlobalSurfaceEntry = 2,
  /// Texture variable.
  OffloadGlobalTextureEntry = 3,
};
40*12c85518Srobert 
getSizeTTy(Module & M)41*12c85518Srobert IntegerType *getSizeTTy(Module &M) {
42*12c85518Srobert   LLVMContext &C = M.getContext();
43*12c85518Srobert   switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
44*12c85518Srobert   case 4u:
45*12c85518Srobert     return Type::getInt32Ty(C);
46*12c85518Srobert   case 8u:
47*12c85518Srobert     return Type::getInt64Ty(C);
48*12c85518Srobert   }
49*12c85518Srobert   llvm_unreachable("unsupported pointer type size");
50*12c85518Srobert }
51*12c85518Srobert 
52*12c85518Srobert // struct __tgt_offload_entry {
53*12c85518Srobert //   void *addr;
54*12c85518Srobert //   char *name;
55*12c85518Srobert //   size_t size;
56*12c85518Srobert //   int32_t flags;
57*12c85518Srobert //   int32_t reserved;
58*12c85518Srobert // };
getEntryTy(Module & M)59*12c85518Srobert StructType *getEntryTy(Module &M) {
60*12c85518Srobert   LLVMContext &C = M.getContext();
61*12c85518Srobert   StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
62*12c85518Srobert   if (!EntryTy)
63*12c85518Srobert     EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
64*12c85518Srobert                                  Type::getInt8PtrTy(C), getSizeTTy(M),
65*12c85518Srobert                                  Type::getInt32Ty(C), Type::getInt32Ty(C));
66*12c85518Srobert   return EntryTy;
67*12c85518Srobert }
68*12c85518Srobert 
getEntryPtrTy(Module & M)69*12c85518Srobert PointerType *getEntryPtrTy(Module &M) {
70*12c85518Srobert   return PointerType::getUnqual(getEntryTy(M));
71*12c85518Srobert }
72*12c85518Srobert 
73*12c85518Srobert // struct __tgt_device_image {
74*12c85518Srobert //   void *ImageStart;
75*12c85518Srobert //   void *ImageEnd;
76*12c85518Srobert //   __tgt_offload_entry *EntriesBegin;
77*12c85518Srobert //   __tgt_offload_entry *EntriesEnd;
78*12c85518Srobert // };
getDeviceImageTy(Module & M)79*12c85518Srobert StructType *getDeviceImageTy(Module &M) {
80*12c85518Srobert   LLVMContext &C = M.getContext();
81*12c85518Srobert   StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
82*12c85518Srobert   if (!ImageTy)
83*12c85518Srobert     ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
84*12c85518Srobert                                  Type::getInt8PtrTy(C), getEntryPtrTy(M),
85*12c85518Srobert                                  getEntryPtrTy(M));
86*12c85518Srobert   return ImageTy;
87*12c85518Srobert }
88*12c85518Srobert 
getDeviceImagePtrTy(Module & M)89*12c85518Srobert PointerType *getDeviceImagePtrTy(Module &M) {
90*12c85518Srobert   return PointerType::getUnqual(getDeviceImageTy(M));
91*12c85518Srobert }
92*12c85518Srobert 
93*12c85518Srobert // struct __tgt_bin_desc {
94*12c85518Srobert //   int32_t NumDeviceImages;
95*12c85518Srobert //   __tgt_device_image *DeviceImages;
96*12c85518Srobert //   __tgt_offload_entry *HostEntriesBegin;
97*12c85518Srobert //   __tgt_offload_entry *HostEntriesEnd;
98*12c85518Srobert // };
getBinDescTy(Module & M)99*12c85518Srobert StructType *getBinDescTy(Module &M) {
100*12c85518Srobert   LLVMContext &C = M.getContext();
101*12c85518Srobert   StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
102*12c85518Srobert   if (!DescTy)
103*12c85518Srobert     DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
104*12c85518Srobert                                 getDeviceImagePtrTy(M), getEntryPtrTy(M),
105*12c85518Srobert                                 getEntryPtrTy(M));
106*12c85518Srobert   return DescTy;
107*12c85518Srobert }
108*12c85518Srobert 
getBinDescPtrTy(Module & M)109*12c85518Srobert PointerType *getBinDescPtrTy(Module &M) {
110*12c85518Srobert   return PointerType::getUnqual(getBinDescTy(M));
111*12c85518Srobert }
112*12c85518Srobert 
113*12c85518Srobert /// Creates binary descriptor for the given device images. Binary descriptor
114*12c85518Srobert /// is an object that is passed to the offloading runtime at program startup
115*12c85518Srobert /// and it describes all device images available in the executable or shared
116*12c85518Srobert /// library. It is defined as follows
117*12c85518Srobert ///
118*12c85518Srobert /// __attribute__((visibility("hidden")))
119*12c85518Srobert /// extern __tgt_offload_entry *__start_omp_offloading_entries;
120*12c85518Srobert /// __attribute__((visibility("hidden")))
121*12c85518Srobert /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
122*12c85518Srobert ///
123*12c85518Srobert /// static const char Image0[] = { <Bufs.front() contents> };
124*12c85518Srobert ///  ...
125*12c85518Srobert /// static const char ImageN[] = { <Bufs.back() contents> };
126*12c85518Srobert ///
127*12c85518Srobert /// static const __tgt_device_image Images[] = {
128*12c85518Srobert ///   {
129*12c85518Srobert ///     Image0,                            /*ImageStart*/
130*12c85518Srobert ///     Image0 + sizeof(Image0),           /*ImageEnd*/
131*12c85518Srobert ///     __start_omp_offloading_entries,    /*EntriesBegin*/
132*12c85518Srobert ///     __stop_omp_offloading_entries      /*EntriesEnd*/
133*12c85518Srobert ///   },
134*12c85518Srobert ///   ...
135*12c85518Srobert ///   {
136*12c85518Srobert ///     ImageN,                            /*ImageStart*/
137*12c85518Srobert ///     ImageN + sizeof(ImageN),           /*ImageEnd*/
138*12c85518Srobert ///     __start_omp_offloading_entries,    /*EntriesBegin*/
139*12c85518Srobert ///     __stop_omp_offloading_entries      /*EntriesEnd*/
140*12c85518Srobert ///   }
141*12c85518Srobert /// };
142*12c85518Srobert ///
143*12c85518Srobert /// static const __tgt_bin_desc BinDesc = {
144*12c85518Srobert ///   sizeof(Images) / sizeof(Images[0]),  /*NumDeviceImages*/
145*12c85518Srobert ///   Images,                              /*DeviceImages*/
146*12c85518Srobert ///   __start_omp_offloading_entries,      /*HostEntriesBegin*/
147*12c85518Srobert ///   __stop_omp_offloading_entries        /*HostEntriesEnd*/
148*12c85518Srobert /// };
149*12c85518Srobert ///
150*12c85518Srobert /// Global variable that represents BinDesc is returned.
createBinDesc(Module & M,ArrayRef<ArrayRef<char>> Bufs)151*12c85518Srobert GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
152*12c85518Srobert   LLVMContext &C = M.getContext();
153*12c85518Srobert   // Create external begin/end symbols for the offload entries table.
154*12c85518Srobert   auto *EntriesB = new GlobalVariable(
155*12c85518Srobert       M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
156*12c85518Srobert       /*Initializer*/ nullptr, "__start_omp_offloading_entries");
157*12c85518Srobert   EntriesB->setVisibility(GlobalValue::HiddenVisibility);
158*12c85518Srobert   auto *EntriesE = new GlobalVariable(
159*12c85518Srobert       M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
160*12c85518Srobert       /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
161*12c85518Srobert   EntriesE->setVisibility(GlobalValue::HiddenVisibility);
162*12c85518Srobert 
163*12c85518Srobert   // We assume that external begin/end symbols that we have created above will
164*12c85518Srobert   // be defined by the linker. But linker will do that only if linker inputs
165*12c85518Srobert   // have section with "omp_offloading_entries" name which is not guaranteed.
166*12c85518Srobert   // So, we just create dummy zero sized object in the offload entries section
167*12c85518Srobert   // to force linker to define those symbols.
168*12c85518Srobert   auto *DummyInit =
169*12c85518Srobert       ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
170*12c85518Srobert   auto *DummyEntry = new GlobalVariable(
171*12c85518Srobert       M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
172*12c85518Srobert       "__dummy.omp_offloading.entry");
173*12c85518Srobert   DummyEntry->setSection("omp_offloading_entries");
174*12c85518Srobert   DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
175*12c85518Srobert 
176*12c85518Srobert   auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
177*12c85518Srobert   Constant *ZeroZero[] = {Zero, Zero};
178*12c85518Srobert 
179*12c85518Srobert   // Create initializer for the images array.
180*12c85518Srobert   SmallVector<Constant *, 4u> ImagesInits;
181*12c85518Srobert   ImagesInits.reserve(Bufs.size());
182*12c85518Srobert   for (ArrayRef<char> Buf : Bufs) {
183*12c85518Srobert     auto *Data = ConstantDataArray::get(C, Buf);
184*12c85518Srobert     auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
185*12c85518Srobert                                      GlobalVariable::InternalLinkage, Data,
186*12c85518Srobert                                      ".omp_offloading.device_image");
187*12c85518Srobert     Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
188*12c85518Srobert     Image->setSection(".llvm.offloading");
189*12c85518Srobert     Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
190*12c85518Srobert 
191*12c85518Srobert     auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
192*12c85518Srobert     Constant *ZeroSize[] = {Zero, Size};
193*12c85518Srobert 
194*12c85518Srobert     auto *ImageB =
195*12c85518Srobert         ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
196*12c85518Srobert     auto *ImageE =
197*12c85518Srobert         ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
198*12c85518Srobert 
199*12c85518Srobert     ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
200*12c85518Srobert                                               ImageE, EntriesB, EntriesE));
201*12c85518Srobert   }
202*12c85518Srobert 
203*12c85518Srobert   // Then create images array.
204*12c85518Srobert   auto *ImagesData = ConstantArray::get(
205*12c85518Srobert       ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
206*12c85518Srobert 
207*12c85518Srobert   auto *Images =
208*12c85518Srobert       new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
209*12c85518Srobert                          GlobalValue::InternalLinkage, ImagesData,
210*12c85518Srobert                          ".omp_offloading.device_images");
211*12c85518Srobert   Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
212*12c85518Srobert 
213*12c85518Srobert   auto *ImagesB =
214*12c85518Srobert       ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
215*12c85518Srobert 
216*12c85518Srobert   // And finally create the binary descriptor object.
217*12c85518Srobert   auto *DescInit = ConstantStruct::get(
218*12c85518Srobert       getBinDescTy(M),
219*12c85518Srobert       ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
220*12c85518Srobert       EntriesB, EntriesE);
221*12c85518Srobert 
222*12c85518Srobert   return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
223*12c85518Srobert                             GlobalValue::InternalLinkage, DescInit,
224*12c85518Srobert                             ".omp_offloading.descriptor");
225*12c85518Srobert }
226*12c85518Srobert 
createRegisterFunction(Module & M,GlobalVariable * BinDesc)227*12c85518Srobert void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
228*12c85518Srobert   LLVMContext &C = M.getContext();
229*12c85518Srobert   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
230*12c85518Srobert   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
231*12c85518Srobert                                 ".omp_offloading.descriptor_reg", &M);
232*12c85518Srobert   Func->setSection(".text.startup");
233*12c85518Srobert 
234*12c85518Srobert   // Get __tgt_register_lib function declaration.
235*12c85518Srobert   auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
236*12c85518Srobert                                       /*isVarArg*/ false);
237*12c85518Srobert   FunctionCallee RegFuncC =
238*12c85518Srobert       M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
239*12c85518Srobert 
240*12c85518Srobert   // Construct function body
241*12c85518Srobert   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
242*12c85518Srobert   Builder.CreateCall(RegFuncC, BinDesc);
243*12c85518Srobert   Builder.CreateRetVoid();
244*12c85518Srobert 
245*12c85518Srobert   // Add this function to constructors.
246*12c85518Srobert   // Set priority to 1 so that __tgt_register_lib is executed AFTER
247*12c85518Srobert   // __tgt_register_requires (we want to know what requirements have been
248*12c85518Srobert   // asked for before we load a libomptarget plugin so that by the time the
249*12c85518Srobert   // plugin is loaded it can report how many devices there are which can
250*12c85518Srobert   // satisfy these requirements).
251*12c85518Srobert   appendToGlobalCtors(M, Func, /*Priority*/ 1);
252*12c85518Srobert }
253*12c85518Srobert 
createUnregisterFunction(Module & M,GlobalVariable * BinDesc)254*12c85518Srobert void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
255*12c85518Srobert   LLVMContext &C = M.getContext();
256*12c85518Srobert   auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
257*12c85518Srobert   auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
258*12c85518Srobert                                 ".omp_offloading.descriptor_unreg", &M);
259*12c85518Srobert   Func->setSection(".text.startup");
260*12c85518Srobert 
261*12c85518Srobert   // Get __tgt_unregister_lib function declaration.
262*12c85518Srobert   auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
263*12c85518Srobert                                         /*isVarArg*/ false);
264*12c85518Srobert   FunctionCallee UnRegFuncC =
265*12c85518Srobert       M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
266*12c85518Srobert 
267*12c85518Srobert   // Construct function body
268*12c85518Srobert   IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
269*12c85518Srobert   Builder.CreateCall(UnRegFuncC, BinDesc);
270*12c85518Srobert   Builder.CreateRetVoid();
271*12c85518Srobert 
272*12c85518Srobert   // Add this function to global destructors.
273*12c85518Srobert   // Match priority of __tgt_register_lib
274*12c85518Srobert   appendToGlobalDtors(M, Func, /*Priority*/ 1);
275*12c85518Srobert }
276*12c85518Srobert 
277*12c85518Srobert // struct fatbin_wrapper {
278*12c85518Srobert //  int32_t magic;
279*12c85518Srobert //  int32_t version;
280*12c85518Srobert //  void *image;
281*12c85518Srobert //  void *reserved;
282*12c85518Srobert //};
getFatbinWrapperTy(Module & M)283*12c85518Srobert StructType *getFatbinWrapperTy(Module &M) {
284*12c85518Srobert   LLVMContext &C = M.getContext();
285*12c85518Srobert   StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
286*12c85518Srobert   if (!FatbinTy)
287*12c85518Srobert     FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
288*12c85518Srobert                                   Type::getInt32Ty(C), Type::getInt8PtrTy(C),
289*12c85518Srobert                                   Type::getInt8PtrTy(C));
290*12c85518Srobert   return FatbinTy;
291*12c85518Srobert }
292*12c85518Srobert 
293*12c85518Srobert /// Embed the image \p Image into the module \p M so it can be found by the
294*12c85518Srobert /// runtime.
createFatbinDesc(Module & M,ArrayRef<char> Image,bool IsHIP)295*12c85518Srobert GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
296*12c85518Srobert   LLVMContext &C = M.getContext();
297*12c85518Srobert   llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
298*12c85518Srobert   llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
299*12c85518Srobert 
300*12c85518Srobert   // Create the global string containing the fatbinary.
301*12c85518Srobert   StringRef FatbinConstantSection =
302*12c85518Srobert       IsHIP ? ".hip_fatbin"
303*12c85518Srobert             : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
304*12c85518Srobert   auto *Data = ConstantDataArray::get(C, Image);
305*12c85518Srobert   auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
306*12c85518Srobert                                     GlobalVariable::InternalLinkage, Data,
307*12c85518Srobert                                     ".fatbin_image");
308*12c85518Srobert   Fatbin->setSection(FatbinConstantSection);
309*12c85518Srobert 
310*12c85518Srobert   // Create the fatbinary wrapper
311*12c85518Srobert   StringRef FatbinWrapperSection = IsHIP               ? ".hipFatBinSegment"
312*12c85518Srobert                                    : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
313*12c85518Srobert                                                        : ".nvFatBinSegment";
314*12c85518Srobert   Constant *FatbinWrapper[] = {
315*12c85518Srobert       ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
316*12c85518Srobert       ConstantInt::get(Type::getInt32Ty(C), 1),
317*12c85518Srobert       ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
318*12c85518Srobert       ConstantPointerNull::get(Type::getInt8PtrTy(C))};
319*12c85518Srobert 
320*12c85518Srobert   Constant *FatbinInitializer =
321*12c85518Srobert       ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
322*12c85518Srobert 
323*12c85518Srobert   auto *FatbinDesc =
324*12c85518Srobert       new GlobalVariable(M, getFatbinWrapperTy(M),
325*12c85518Srobert                          /*isConstant*/ true, GlobalValue::InternalLinkage,
326*12c85518Srobert                          FatbinInitializer, ".fatbin_wrapper");
327*12c85518Srobert   FatbinDesc->setSection(FatbinWrapperSection);
328*12c85518Srobert   FatbinDesc->setAlignment(Align(8));
329*12c85518Srobert 
330*12c85518Srobert   // We create a dummy entry to ensure the linker will define the begin / end
331*12c85518Srobert   // symbols. The CUDA runtime should ignore the null address if we attempt to
332*12c85518Srobert   // register it.
333*12c85518Srobert   auto *DummyInit =
334*12c85518Srobert       ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
335*12c85518Srobert   auto *DummyEntry = new GlobalVariable(
336*12c85518Srobert       M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
337*12c85518Srobert       IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
338*12c85518Srobert   DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
339*12c85518Srobert   DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
340*12c85518Srobert                                : "cuda_offloading_entries");
341*12c85518Srobert 
342*12c85518Srobert   return FatbinDesc;
343*12c85518Srobert }
344*12c85518Srobert 
345*12c85518Srobert /// Create the register globals function. We will iterate all of the offloading
346*12c85518Srobert /// entries stored at the begin / end symbols and register them according to
347*12c85518Srobert /// their type. This creates the following function in IR:
348*12c85518Srobert ///
349*12c85518Srobert /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
350*12c85518Srobert /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
351*12c85518Srobert ///
352*12c85518Srobert /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
353*12c85518Srobert ///                                    void *, void *, void *, void *, int *);
354*12c85518Srobert /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
355*12c85518Srobert ///                               int64_t, int32_t, int32_t);
356*12c85518Srobert ///
357*12c85518Srobert /// void __cudaRegisterTest(void **fatbinHandle) {
358*12c85518Srobert ///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
359*12c85518Srobert ///        entry != &__stop_cuda_offloading_entries; ++entry) {
360*12c85518Srobert ///     if (!entry->size)
361*12c85518Srobert ///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
362*12c85518Srobert ///                              entry->name, -1, 0, 0, 0, 0, 0);
363*12c85518Srobert ///     else
364*12c85518Srobert ///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
365*12c85518Srobert ///                         0, entry->size, 0, 0);
366*12c85518Srobert ///   }
367*12c85518Srobert /// }
createRegisterGlobalsFunction(Module & M,bool IsHIP)368*12c85518Srobert Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
369*12c85518Srobert   LLVMContext &C = M.getContext();
370*12c85518Srobert   // Get the __cudaRegisterFunction function declaration.
371*12c85518Srobert   auto *RegFuncTy = FunctionType::get(
372*12c85518Srobert       Type::getInt32Ty(C),
373*12c85518Srobert       {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
374*12c85518Srobert        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
375*12c85518Srobert        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
376*12c85518Srobert        Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
377*12c85518Srobert       /*isVarArg*/ false);
378*12c85518Srobert   FunctionCallee RegFunc = M.getOrInsertFunction(
379*12c85518Srobert       IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
380*12c85518Srobert 
381*12c85518Srobert   // Get the __cudaRegisterVar function declaration.
382*12c85518Srobert   auto *RegVarTy = FunctionType::get(
383*12c85518Srobert       Type::getVoidTy(C),
384*12c85518Srobert       {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
385*12c85518Srobert        Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
386*12c85518Srobert        getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
387*12c85518Srobert       /*isVarArg*/ false);
388*12c85518Srobert   FunctionCallee RegVar = M.getOrInsertFunction(
389*12c85518Srobert       IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
390*12c85518Srobert 
391*12c85518Srobert   // Create the references to the start / stop symbols defined by the linker.
392*12c85518Srobert   auto *EntriesB =
393*12c85518Srobert       new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
394*12c85518Srobert                          /*isConstant*/ true, GlobalValue::ExternalLinkage,
395*12c85518Srobert                          /*Initializer*/ nullptr,
396*12c85518Srobert                          IsHIP ? "__start_hip_offloading_entries"
397*12c85518Srobert                                : "__start_cuda_offloading_entries");
398*12c85518Srobert   EntriesB->setVisibility(GlobalValue::HiddenVisibility);
399*12c85518Srobert   auto *EntriesE =
400*12c85518Srobert       new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
401*12c85518Srobert                          /*isConstant*/ true, GlobalValue::ExternalLinkage,
402*12c85518Srobert                          /*Initializer*/ nullptr,
403*12c85518Srobert                          IsHIP ? "__stop_hip_offloading_entries"
404*12c85518Srobert                                : "__stop_cuda_offloading_entries");
405*12c85518Srobert   EntriesE->setVisibility(GlobalValue::HiddenVisibility);
406*12c85518Srobert 
407*12c85518Srobert   auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
408*12c85518Srobert                                          Type::getInt8PtrTy(C)->getPointerTo(),
409*12c85518Srobert                                          /*isVarArg*/ false);
410*12c85518Srobert   auto *RegGlobalsFn =
411*12c85518Srobert       Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
412*12c85518Srobert                        IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
413*12c85518Srobert   RegGlobalsFn->setSection(".text.startup");
414*12c85518Srobert 
415*12c85518Srobert   // Create the loop to register all the entries.
416*12c85518Srobert   IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
417*12c85518Srobert   auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
418*12c85518Srobert   auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
419*12c85518Srobert   auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
420*12c85518Srobert   auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
421*12c85518Srobert   auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
422*12c85518Srobert   auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
423*12c85518Srobert   auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
424*12c85518Srobert   auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
425*12c85518Srobert   auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
426*12c85518Srobert 
427*12c85518Srobert   auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
428*12c85518Srobert   Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
429*12c85518Srobert   Builder.SetInsertPoint(EntryBB);
430*12c85518Srobert   auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
431*12c85518Srobert   auto *AddrPtr =
432*12c85518Srobert       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
433*12c85518Srobert                                 {ConstantInt::get(getSizeTTy(M), 0),
434*12c85518Srobert                                  ConstantInt::get(Type::getInt32Ty(C), 0)});
435*12c85518Srobert   auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
436*12c85518Srobert   auto *NamePtr =
437*12c85518Srobert       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
438*12c85518Srobert                                 {ConstantInt::get(getSizeTTy(M), 0),
439*12c85518Srobert                                  ConstantInt::get(Type::getInt32Ty(C), 1)});
440*12c85518Srobert   auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
441*12c85518Srobert   auto *SizePtr =
442*12c85518Srobert       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
443*12c85518Srobert                                 {ConstantInt::get(getSizeTTy(M), 0),
444*12c85518Srobert                                  ConstantInt::get(Type::getInt32Ty(C), 2)});
445*12c85518Srobert   auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
446*12c85518Srobert   auto *FlagsPtr =
447*12c85518Srobert       Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
448*12c85518Srobert                                 {ConstantInt::get(getSizeTTy(M), 0),
449*12c85518Srobert                                  ConstantInt::get(Type::getInt32Ty(C), 3)});
450*12c85518Srobert   auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
451*12c85518Srobert   auto *FnCond =
452*12c85518Srobert       Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
453*12c85518Srobert   Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
454*12c85518Srobert 
455*12c85518Srobert   // Create kernel registration code.
456*12c85518Srobert   Builder.SetInsertPoint(IfThenBB);
457*12c85518Srobert   Builder.CreateCall(RegFunc,
458*12c85518Srobert                      {RegGlobalsFn->arg_begin(), Addr, Name, Name,
459*12c85518Srobert                       ConstantInt::get(Type::getInt32Ty(C), -1),
460*12c85518Srobert                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
461*12c85518Srobert                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
462*12c85518Srobert                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
463*12c85518Srobert                       ConstantPointerNull::get(Type::getInt8PtrTy(C)),
464*12c85518Srobert                       ConstantPointerNull::get(Type::getInt32PtrTy(C))});
465*12c85518Srobert   Builder.CreateBr(IfEndBB);
466*12c85518Srobert   Builder.SetInsertPoint(IfElseBB);
467*12c85518Srobert 
468*12c85518Srobert   auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
469*12c85518Srobert   // Create global variable registration code.
470*12c85518Srobert   Builder.SetInsertPoint(SwGlobalBB);
471*12c85518Srobert   Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
472*12c85518Srobert                               ConstantInt::get(Type::getInt32Ty(C), 0), Size,
473*12c85518Srobert                               ConstantInt::get(Type::getInt32Ty(C), 0),
474*12c85518Srobert                               ConstantInt::get(Type::getInt32Ty(C), 0)});
475*12c85518Srobert   Builder.CreateBr(IfEndBB);
476*12c85518Srobert   Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);
477*12c85518Srobert 
478*12c85518Srobert   // Create managed variable registration code.
479*12c85518Srobert   Builder.SetInsertPoint(SwManagedBB);
480*12c85518Srobert   Builder.CreateBr(IfEndBB);
481*12c85518Srobert   Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);
482*12c85518Srobert 
483*12c85518Srobert   // Create surface variable registration code.
484*12c85518Srobert   Builder.SetInsertPoint(SwSurfaceBB);
485*12c85518Srobert   Builder.CreateBr(IfEndBB);
486*12c85518Srobert   Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);
487*12c85518Srobert 
488*12c85518Srobert   // Create texture variable registration code.
489*12c85518Srobert   Builder.SetInsertPoint(SwTextureBB);
490*12c85518Srobert   Builder.CreateBr(IfEndBB);
491*12c85518Srobert   Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);
492*12c85518Srobert 
493*12c85518Srobert   Builder.SetInsertPoint(IfEndBB);
494*12c85518Srobert   auto *NewEntry = Builder.CreateInBoundsGEP(
495*12c85518Srobert       getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
496*12c85518Srobert   auto *Cmp = Builder.CreateICmpEQ(
497*12c85518Srobert       NewEntry,
498*12c85518Srobert       ConstantExpr::getInBoundsGetElementPtr(
499*12c85518Srobert           ArrayType::get(getEntryTy(M), 0), EntriesE,
500*12c85518Srobert           ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
501*12c85518Srobert                                 ConstantInt::get(getSizeTTy(M), 0)})));
502*12c85518Srobert   Entry->addIncoming(
503*12c85518Srobert       ConstantExpr::getInBoundsGetElementPtr(
504*12c85518Srobert           ArrayType::get(getEntryTy(M), 0), EntriesB,
505*12c85518Srobert           ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
506*12c85518Srobert                                 ConstantInt::get(getSizeTTy(M), 0)})),
507*12c85518Srobert       &RegGlobalsFn->getEntryBlock());
508*12c85518Srobert   Entry->addIncoming(NewEntry, IfEndBB);
509*12c85518Srobert   Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
510*12c85518Srobert   Builder.SetInsertPoint(ExitBB);
511*12c85518Srobert   Builder.CreateRetVoid();
512*12c85518Srobert 
513*12c85518Srobert   return RegGlobalsFn;
514*12c85518Srobert }
515*12c85518Srobert 
516*12c85518Srobert // Create the constructor and destructor to register the fatbinary with the CUDA
517*12c85518Srobert // runtime.
createRegisterFatbinFunction(Module & M,GlobalVariable * FatbinDesc,bool IsHIP)518*12c85518Srobert void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
519*12c85518Srobert                                   bool IsHIP) {
520*12c85518Srobert   LLVMContext &C = M.getContext();
521*12c85518Srobert   auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
522*12c85518Srobert   auto *CtorFunc =
523*12c85518Srobert       Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
524*12c85518Srobert                        IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
525*12c85518Srobert   CtorFunc->setSection(".text.startup");
526*12c85518Srobert 
527*12c85518Srobert   auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
528*12c85518Srobert   auto *DtorFunc =
529*12c85518Srobert       Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
530*12c85518Srobert                        IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
531*12c85518Srobert   DtorFunc->setSection(".text.startup");
532*12c85518Srobert 
533*12c85518Srobert   // Get the __cudaRegisterFatBinary function declaration.
534*12c85518Srobert   auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
535*12c85518Srobert                                      Type::getInt8PtrTy(C),
536*12c85518Srobert                                      /*isVarArg*/ false);
537*12c85518Srobert   FunctionCallee RegFatbin = M.getOrInsertFunction(
538*12c85518Srobert       IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
539*12c85518Srobert   // Get the __cudaRegisterFatBinaryEnd function declaration.
540*12c85518Srobert   auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
541*12c85518Srobert                                         Type::getInt8PtrTy(C)->getPointerTo(),
542*12c85518Srobert                                         /*isVarArg*/ false);
543*12c85518Srobert   FunctionCallee RegFatbinEnd =
544*12c85518Srobert       M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
545*12c85518Srobert   // Get the __cudaUnregisterFatBinary function declaration.
546*12c85518Srobert   auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
547*12c85518Srobert                                        Type::getInt8PtrTy(C)->getPointerTo(),
548*12c85518Srobert                                        /*isVarArg*/ false);
549*12c85518Srobert   FunctionCallee UnregFatbin = M.getOrInsertFunction(
550*12c85518Srobert       IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
551*12c85518Srobert       UnregFatTy);
552*12c85518Srobert 
553*12c85518Srobert   auto *AtExitTy =
554*12c85518Srobert       FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
555*12c85518Srobert                         /*isVarArg*/ false);
556*12c85518Srobert   FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
557*12c85518Srobert 
558*12c85518Srobert   auto *BinaryHandleGlobal = new llvm::GlobalVariable(
559*12c85518Srobert       M, Type::getInt8PtrTy(C)->getPointerTo(), false,
560*12c85518Srobert       llvm::GlobalValue::InternalLinkage,
561*12c85518Srobert       llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
562*12c85518Srobert       IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
563*12c85518Srobert 
564*12c85518Srobert   // Create the constructor to register this image with the runtime.
565*12c85518Srobert   IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
566*12c85518Srobert   CallInst *Handle = CtorBuilder.CreateCall(
567*12c85518Srobert       RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
568*12c85518Srobert                      FatbinDesc, Type::getInt8PtrTy(C)));
569*12c85518Srobert   CtorBuilder.CreateAlignedStore(
570*12c85518Srobert       Handle, BinaryHandleGlobal,
571*12c85518Srobert       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
572*12c85518Srobert   CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
573*12c85518Srobert   if (!IsHIP)
574*12c85518Srobert     CtorBuilder.CreateCall(RegFatbinEnd, Handle);
575*12c85518Srobert   CtorBuilder.CreateCall(AtExit, DtorFunc);
576*12c85518Srobert   CtorBuilder.CreateRetVoid();
577*12c85518Srobert 
578*12c85518Srobert   // Create the destructor to unregister the image with the runtime. We cannot
579*12c85518Srobert   // use a standard global destructor after CUDA 9.2 so this must be called by
580*12c85518Srobert   // `atexit()` intead.
581*12c85518Srobert   IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
582*12c85518Srobert   LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
583*12c85518Srobert       Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
584*12c85518Srobert       Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
585*12c85518Srobert   DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
586*12c85518Srobert   DtorBuilder.CreateRetVoid();
587*12c85518Srobert 
588*12c85518Srobert   // Add this function to constructors.
589*12c85518Srobert   appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
590*12c85518Srobert }
591*12c85518Srobert 
592*12c85518Srobert } // namespace
593*12c85518Srobert 
wrapOpenMPBinaries(Module & M,ArrayRef<ArrayRef<char>> Images)594*12c85518Srobert Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
595*12c85518Srobert   GlobalVariable *Desc = createBinDesc(M, Images);
596*12c85518Srobert   if (!Desc)
597*12c85518Srobert     return createStringError(inconvertibleErrorCode(),
598*12c85518Srobert                              "No binary descriptors created.");
599*12c85518Srobert   createRegisterFunction(M, Desc);
600*12c85518Srobert   createUnregisterFunction(M, Desc);
601*12c85518Srobert   return Error::success();
602*12c85518Srobert }
603*12c85518Srobert 
wrapCudaBinary(Module & M,ArrayRef<char> Image)604*12c85518Srobert Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
605*12c85518Srobert   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
606*12c85518Srobert   if (!Desc)
607*12c85518Srobert     return createStringError(inconvertibleErrorCode(),
608*12c85518Srobert                              "No fatinbary section created.");
609*12c85518Srobert 
610*12c85518Srobert   createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
611*12c85518Srobert   return Error::success();
612*12c85518Srobert }
613*12c85518Srobert 
wrapHIPBinary(Module & M,ArrayRef<char> Image)614*12c85518Srobert Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
615*12c85518Srobert   GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
616*12c85518Srobert   if (!Desc)
617*12c85518Srobert     return createStringError(inconvertibleErrorCode(),
618*12c85518Srobert                              "No fatinbary section created.");
619*12c85518Srobert 
620*12c85518Srobert   createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
621*12c85518Srobert   return Error::success();
622*12c85518Srobert }
623