1 //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "OffloadWrapper.h" 10 #include "llvm/ADT/ArrayRef.h" 11 #include "llvm/ADT/Triple.h" 12 #include "llvm/IR/Constants.h" 13 #include "llvm/IR/GlobalVariable.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/Object/OffloadBinary.h" 18 #include "llvm/Support/Error.h" 19 #include "llvm/Transforms/Utils/ModuleUtils.h" 20 21 using namespace llvm; 22 23 namespace { 24 /// Magic number that begins the section containing the CUDA fatbinary. 25 constexpr unsigned CudaFatMagic = 0x466243b1; 26 constexpr unsigned HIPFatMagic = 0x48495046; 27 28 /// Copied from clang/CGCudaRuntime.h. 29 enum OffloadEntryKindFlag : uint32_t { 30 /// Mark the entry as a global entry. This indicates the presense of a 31 /// kernel if the size size field is zero and a variable otherwise. 32 OffloadGlobalEntry = 0x0, 33 /// Mark the entry as a managed global variable. 34 OffloadGlobalManagedEntry = 0x1, 35 /// Mark the entry as a surface variable. 36 OffloadGlobalSurfaceEntry = 0x2, 37 /// Mark the entry as a texture variable. 38 OffloadGlobalTextureEntry = 0x3, 39 }; 40 41 IntegerType *getSizeTTy(Module &M) { 42 LLVMContext &C = M.getContext(); 43 switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) { 44 case 4u: 45 return Type::getInt32Ty(C); 46 case 8u: 47 return Type::getInt64Ty(C); 48 } 49 llvm_unreachable("unsupported pointer type size"); 50 } 51 52 // struct __tgt_offload_entry { 53 // void *addr; 54 // char *name; 55 // size_t size; 56 // int32_t flags; 57 // int32_t reserved; 58 // }; 59 StructType *getEntryTy(Module &M) { 60 LLVMContext &C = M.getContext(); 61 StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry"); 62 if (!EntryTy) 63 EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C), 64 Type::getInt8PtrTy(C), getSizeTTy(M), 65 Type::getInt32Ty(C), Type::getInt32Ty(C)); 66 return EntryTy; 67 } 68 69 PointerType *getEntryPtrTy(Module &M) { 70 return PointerType::getUnqual(getEntryTy(M)); 71 } 72 73 // struct __tgt_device_image { 74 // void *ImageStart; 75 // void *ImageEnd; 76 // __tgt_offload_entry *EntriesBegin; 77 // __tgt_offload_entry *EntriesEnd; 78 // }; 79 StructType *getDeviceImageTy(Module &M) { 80 LLVMContext &C = M.getContext(); 81 StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image"); 82 if (!ImageTy) 83 ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C), 84 Type::getInt8PtrTy(C), getEntryPtrTy(M), 85 getEntryPtrTy(M)); 86 return ImageTy; 87 } 88 89 PointerType *getDeviceImagePtrTy(Module &M) { 90 return PointerType::getUnqual(getDeviceImageTy(M)); 91 } 92 93 // struct __tgt_bin_desc { 94 // int32_t NumDeviceImages; 95 // __tgt_device_image *DeviceImages; 96 // __tgt_offload_entry *HostEntriesBegin; 97 // __tgt_offload_entry *HostEntriesEnd; 98 // }; 99 StructType *getBinDescTy(Module &M) { 100 LLVMContext &C = M.getContext(); 101 StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc"); 102 if (!DescTy) 103 DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C), 104 getDeviceImagePtrTy(M), getEntryPtrTy(M), 105 getEntryPtrTy(M)); 106 return DescTy; 107 } 108 109 PointerType *getBinDescPtrTy(Module &M) { 110 return PointerType::getUnqual(getBinDescTy(M)); 111 } 112 113 /// Creates binary descriptor for the given device images. Binary descriptor 114 /// is an object that is passed to the offloading runtime at program startup 115 /// and it describes all device images available in the executable or shared 116 /// library. It is defined as follows 117 /// 118 /// __attribute__((visibility("hidden"))) 119 /// extern __tgt_offload_entry *__start_omp_offloading_entries; 120 /// __attribute__((visibility("hidden"))) 121 /// extern __tgt_offload_entry *__stop_omp_offloading_entries; 122 /// 123 /// static const char Image0[] = { <Bufs.front() contents> }; 124 /// ... 125 /// static const char ImageN[] = { <Bufs.back() contents> }; 126 /// 127 /// static const __tgt_device_image Images[] = { 128 /// { 129 /// Image0, /*ImageStart*/ 130 /// Image0 + sizeof(Image0), /*ImageEnd*/ 131 /// __start_omp_offloading_entries, /*EntriesBegin*/ 132 /// __stop_omp_offloading_entries /*EntriesEnd*/ 133 /// }, 134 /// ... 135 /// { 136 /// ImageN, /*ImageStart*/ 137 /// ImageN + sizeof(ImageN), /*ImageEnd*/ 138 /// __start_omp_offloading_entries, /*EntriesBegin*/ 139 /// __stop_omp_offloading_entries /*EntriesEnd*/ 140 /// } 141 /// }; 142 /// 143 /// static const __tgt_bin_desc BinDesc = { 144 /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/ 145 /// Images, /*DeviceImages*/ 146 /// __start_omp_offloading_entries, /*HostEntriesBegin*/ 147 /// __stop_omp_offloading_entries /*HostEntriesEnd*/ 148 /// }; 149 /// 150 /// Global variable that represents BinDesc is returned. 151 GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) { 152 LLVMContext &C = M.getContext(); 153 // Create external begin/end symbols for the offload entries table. 154 auto *EntriesB = new GlobalVariable( 155 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, 156 /*Initializer*/ nullptr, "__start_omp_offloading_entries"); 157 EntriesB->setVisibility(GlobalValue::HiddenVisibility); 158 auto *EntriesE = new GlobalVariable( 159 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, 160 /*Initializer*/ nullptr, "__stop_omp_offloading_entries"); 161 EntriesE->setVisibility(GlobalValue::HiddenVisibility); 162 163 // We assume that external begin/end symbols that we have created above will 164 // be defined by the linker. But linker will do that only if linker inputs 165 // have section with "omp_offloading_entries" name which is not guaranteed. 166 // So, we just create dummy zero sized object in the offload entries section 167 // to force linker to define those symbols. 168 auto *DummyInit = 169 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); 170 auto *DummyEntry = new GlobalVariable( 171 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, 172 "__dummy.omp_offloading.entry"); 173 DummyEntry->setSection("omp_offloading_entries"); 174 DummyEntry->setVisibility(GlobalValue::HiddenVisibility); 175 176 auto *Zero = ConstantInt::get(getSizeTTy(M), 0u); 177 Constant *ZeroZero[] = {Zero, Zero}; 178 179 // Create initializer for the images array. 180 SmallVector<Constant *, 4u> ImagesInits; 181 ImagesInits.reserve(Bufs.size()); 182 for (ArrayRef<char> Buf : Bufs) { 183 auto *Data = ConstantDataArray::get(C, Buf); 184 auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, 185 GlobalVariable::InternalLinkage, Data, 186 ".omp_offloading.device_image"); 187 Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 188 Image->setSection(".llvm.offloading"); 189 Image->setAlignment(Align(object::OffloadBinary::getAlignment())); 190 191 auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size()); 192 Constant *ZeroSize[] = {Zero, Size}; 193 194 auto *ImageB = 195 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero); 196 auto *ImageE = 197 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize); 198 199 ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB, 200 ImageE, EntriesB, EntriesE)); 201 } 202 203 // Then create images array. 204 auto *ImagesData = ConstantArray::get( 205 ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits); 206 207 auto *Images = 208 new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true, 209 GlobalValue::InternalLinkage, ImagesData, 210 ".omp_offloading.device_images"); 211 Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); 212 213 auto *ImagesB = 214 ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero); 215 216 // And finally create the binary descriptor object. 217 auto *DescInit = ConstantStruct::get( 218 getBinDescTy(M), 219 ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB, 220 EntriesB, EntriesE); 221 222 return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true, 223 GlobalValue::InternalLinkage, DescInit, 224 ".omp_offloading.descriptor"); 225 } 226 227 void createRegisterFunction(Module &M, GlobalVariable *BinDesc) { 228 LLVMContext &C = M.getContext(); 229 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 230 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, 231 ".omp_offloading.descriptor_reg", &M); 232 Func->setSection(".text.startup"); 233 234 // Get __tgt_register_lib function declaration. 235 auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), 236 /*isVarArg*/ false); 237 FunctionCallee RegFuncC = 238 M.getOrInsertFunction("__tgt_register_lib", RegFuncTy); 239 240 // Construct function body 241 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func)); 242 Builder.CreateCall(RegFuncC, BinDesc); 243 Builder.CreateRetVoid(); 244 245 // Add this function to constructors. 246 // Set priority to 1 so that __tgt_register_lib is executed AFTER 247 // __tgt_register_requires (we want to know what requirements have been 248 // asked for before we load a libomptarget plugin so that by the time the 249 // plugin is loaded it can report how many devices there are which can 250 // satisfy these requirements). 251 appendToGlobalCtors(M, Func, /*Priority*/ 1); 252 } 253 254 void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) { 255 LLVMContext &C = M.getContext(); 256 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 257 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, 258 ".omp_offloading.descriptor_unreg", &M); 259 Func->setSection(".text.startup"); 260 261 // Get __tgt_unregister_lib function declaration. 262 auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), 263 /*isVarArg*/ false); 264 FunctionCallee UnRegFuncC = 265 M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy); 266 267 // Construct function body 268 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func)); 269 Builder.CreateCall(UnRegFuncC, BinDesc); 270 Builder.CreateRetVoid(); 271 272 // Add this function to global destructors. 273 // Match priority of __tgt_register_lib 274 appendToGlobalDtors(M, Func, /*Priority*/ 1); 275 } 276 277 // struct fatbin_wrapper { 278 // int32_t magic; 279 // int32_t version; 280 // void *image; 281 // void *reserved; 282 //}; 283 StructType *getFatbinWrapperTy(Module &M) { 284 LLVMContext &C = M.getContext(); 285 StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper"); 286 if (!FatbinTy) 287 FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C), 288 Type::getInt32Ty(C), Type::getInt8PtrTy(C), 289 Type::getInt8PtrTy(C)); 290 return FatbinTy; 291 } 292 293 /// Embed the image \p Image into the module \p M so it can be found by the 294 /// runtime. 295 GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) { 296 LLVMContext &C = M.getContext(); 297 llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C); 298 llvm::Triple Triple = llvm::Triple(M.getTargetTriple()); 299 300 // Create the global string containing the fatbinary. 301 StringRef FatbinConstantSection = 302 IsHIP ? ".hip_fatbin" 303 : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin"); 304 auto *Data = ConstantDataArray::get(C, Image); 305 auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, 306 GlobalVariable::InternalLinkage, Data, 307 ".fatbin_image"); 308 Fatbin->setSection(FatbinConstantSection); 309 310 // Create the fatbinary wrapper 311 StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment" 312 : Triple.isMacOSX() ? "__NV_CUDA,__fatbin" 313 : ".nvFatBinSegment"; 314 Constant *FatbinWrapper[] = { 315 ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic), 316 ConstantInt::get(Type::getInt32Ty(C), 1), 317 ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy), 318 ConstantPointerNull::get(Type::getInt8PtrTy(C))}; 319 320 Constant *FatbinInitializer = 321 ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper); 322 323 auto *FatbinDesc = 324 new GlobalVariable(M, getFatbinWrapperTy(M), 325 /*isConstant*/ true, GlobalValue::InternalLinkage, 326 FatbinInitializer, ".fatbin_wrapper"); 327 FatbinDesc->setSection(FatbinWrapperSection); 328 FatbinDesc->setAlignment(Align(8)); 329 330 // We create a dummy entry to ensure the linker will define the begin / end 331 // symbols. The CUDA runtime should ignore the null address if we attempt to 332 // register it. 333 auto *DummyInit = 334 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); 335 auto *DummyEntry = new GlobalVariable( 336 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, 337 IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry"); 338 DummyEntry->setVisibility(GlobalValue::HiddenVisibility); 339 DummyEntry->setSection(IsHIP ? "hip_offloading_entries" 340 : "cuda_offloading_entries"); 341 342 return FatbinDesc; 343 } 344 345 /// Create the register globals function. We will iterate all of the offloading 346 /// entries stored at the begin / end symbols and register them according to 347 /// their type. This creates the following function in IR: 348 /// 349 /// extern struct __tgt_offload_entry __start_cuda_offloading_entries; 350 /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries; 351 /// 352 /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int, 353 /// void *, void *, void *, void *, int *); 354 /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t, 355 /// int64_t, int32_t, int32_t); 356 /// 357 /// void __cudaRegisterTest(void **fatbinHandle) { 358 /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries; 359 /// entry != &__stop_cuda_offloading_entries; ++entry) { 360 /// if (!entry->size) 361 /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name, 362 /// entry->name, -1, 0, 0, 0, 0, 0); 363 /// else 364 /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name, 365 /// 0, entry->size, 0, 0); 366 /// } 367 /// } 368 Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) { 369 LLVMContext &C = M.getContext(); 370 // Get the __cudaRegisterFunction function declaration. 371 auto *RegFuncTy = FunctionType::get( 372 Type::getInt32Ty(C), 373 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), 374 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), 375 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), 376 Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)}, 377 /*isVarArg*/ false); 378 FunctionCallee RegFunc = M.getOrInsertFunction( 379 IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy); 380 381 // Get the __cudaRegisterVar function declaration. 382 auto *RegVarTy = FunctionType::get( 383 Type::getVoidTy(C), 384 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), 385 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), 386 getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)}, 387 /*isVarArg*/ false); 388 FunctionCallee RegVar = M.getOrInsertFunction( 389 IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy); 390 391 // Create the references to the start / stop symbols defined by the linker. 392 auto *EntriesB = 393 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), 394 /*isConstant*/ true, GlobalValue::ExternalLinkage, 395 /*Initializer*/ nullptr, 396 IsHIP ? "__start_hip_offloading_entries" 397 : "__start_cuda_offloading_entries"); 398 EntriesB->setVisibility(GlobalValue::HiddenVisibility); 399 auto *EntriesE = 400 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), 401 /*isConstant*/ true, GlobalValue::ExternalLinkage, 402 /*Initializer*/ nullptr, 403 IsHIP ? "__stop_hip_offloading_entries" 404 : "__stop_cuda_offloading_entries"); 405 EntriesE->setVisibility(GlobalValue::HiddenVisibility); 406 407 auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C), 408 Type::getInt8PtrTy(C)->getPointerTo(), 409 /*isVarArg*/ false); 410 auto *RegGlobalsFn = 411 Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage, 412 IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M); 413 RegGlobalsFn->setSection(".text.startup"); 414 415 // Create the loop to register all the entries. 416 IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn)); 417 auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn); 418 auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn); 419 auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn); 420 auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn); 421 auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn); 422 auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn); 423 auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn); 424 auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn); 425 auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn); 426 427 auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE); 428 Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB); 429 Builder.SetInsertPoint(EntryBB); 430 auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry"); 431 auto *AddrPtr = 432 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 433 {ConstantInt::get(getSizeTTy(M), 0), 434 ConstantInt::get(Type::getInt32Ty(C), 0)}); 435 auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr"); 436 auto *NamePtr = 437 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 438 {ConstantInt::get(getSizeTTy(M), 0), 439 ConstantInt::get(Type::getInt32Ty(C), 1)}); 440 auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name"); 441 auto *SizePtr = 442 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 443 {ConstantInt::get(getSizeTTy(M), 0), 444 ConstantInt::get(Type::getInt32Ty(C), 2)}); 445 auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size"); 446 auto *FlagsPtr = 447 Builder.CreateInBoundsGEP(getEntryTy(M), Entry, 448 {ConstantInt::get(getSizeTTy(M), 0), 449 ConstantInt::get(Type::getInt32Ty(C), 3)}); 450 auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag"); 451 auto *FnCond = 452 Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M))); 453 Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB); 454 455 // Create kernel registration code. 456 Builder.SetInsertPoint(IfThenBB); 457 Builder.CreateCall(RegFunc, 458 {RegGlobalsFn->arg_begin(), Addr, Name, Name, 459 ConstantInt::get(Type::getInt32Ty(C), -1), 460 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 461 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 462 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 463 ConstantPointerNull::get(Type::getInt8PtrTy(C)), 464 ConstantPointerNull::get(Type::getInt32PtrTy(C))}); 465 Builder.CreateBr(IfEndBB); 466 Builder.SetInsertPoint(IfElseBB); 467 468 auto *Switch = Builder.CreateSwitch(Flags, IfEndBB); 469 // Create global variable registration code. 470 Builder.SetInsertPoint(SwGlobalBB); 471 Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name, 472 ConstantInt::get(Type::getInt32Ty(C), 0), Size, 473 ConstantInt::get(Type::getInt32Ty(C), 0), 474 ConstantInt::get(Type::getInt32Ty(C), 0)}); 475 Builder.CreateBr(IfEndBB); 476 Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB); 477 478 // Create managed variable registration code. 479 Builder.SetInsertPoint(SwManagedBB); 480 Builder.CreateBr(IfEndBB); 481 Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB); 482 483 // Create surface variable registration code. 484 Builder.SetInsertPoint(SwSurfaceBB); 485 Builder.CreateBr(IfEndBB); 486 Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB); 487 488 // Create texture variable registration code. 489 Builder.SetInsertPoint(SwTextureBB); 490 Builder.CreateBr(IfEndBB); 491 Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB); 492 493 Builder.SetInsertPoint(IfEndBB); 494 auto *NewEntry = Builder.CreateInBoundsGEP( 495 getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1)); 496 auto *Cmp = Builder.CreateICmpEQ( 497 NewEntry, 498 ConstantExpr::getInBoundsGetElementPtr( 499 ArrayType::get(getEntryTy(M), 0), EntriesE, 500 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), 501 ConstantInt::get(getSizeTTy(M), 0)}))); 502 Entry->addIncoming( 503 ConstantExpr::getInBoundsGetElementPtr( 504 ArrayType::get(getEntryTy(M), 0), EntriesB, 505 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), 506 ConstantInt::get(getSizeTTy(M), 0)})), 507 &RegGlobalsFn->getEntryBlock()); 508 Entry->addIncoming(NewEntry, IfEndBB); 509 Builder.CreateCondBr(Cmp, ExitBB, EntryBB); 510 Builder.SetInsertPoint(ExitBB); 511 Builder.CreateRetVoid(); 512 513 return RegGlobalsFn; 514 } 515 516 // Create the constructor and destructor to register the fatbinary with the CUDA 517 // runtime. 518 void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc, 519 bool IsHIP) { 520 LLVMContext &C = M.getContext(); 521 auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 522 auto *CtorFunc = 523 Function::Create(CtorFuncTy, GlobalValue::InternalLinkage, 524 IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M); 525 CtorFunc->setSection(".text.startup"); 526 527 auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); 528 auto *DtorFunc = 529 Function::Create(DtorFuncTy, GlobalValue::InternalLinkage, 530 IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M); 531 DtorFunc->setSection(".text.startup"); 532 533 // Get the __cudaRegisterFatBinary function declaration. 534 auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(), 535 Type::getInt8PtrTy(C), 536 /*isVarArg*/ false); 537 FunctionCallee RegFatbin = M.getOrInsertFunction( 538 IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy); 539 // Get the __cudaRegisterFatBinaryEnd function declaration. 540 auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C), 541 Type::getInt8PtrTy(C)->getPointerTo(), 542 /*isVarArg*/ false); 543 FunctionCallee RegFatbinEnd = 544 M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy); 545 // Get the __cudaUnregisterFatBinary function declaration. 546 auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C), 547 Type::getInt8PtrTy(C)->getPointerTo(), 548 /*isVarArg*/ false); 549 FunctionCallee UnregFatbin = M.getOrInsertFunction( 550 IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary", 551 UnregFatTy); 552 553 auto *AtExitTy = 554 FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(), 555 /*isVarArg*/ false); 556 FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy); 557 558 auto *BinaryHandleGlobal = new llvm::GlobalVariable( 559 M, Type::getInt8PtrTy(C)->getPointerTo(), false, 560 llvm::GlobalValue::InternalLinkage, 561 llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()), 562 IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle"); 563 564 // Create the constructor to register this image with the runtime. 565 IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc)); 566 CallInst *Handle = CtorBuilder.CreateCall( 567 RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast( 568 FatbinDesc, Type::getInt8PtrTy(C))); 569 CtorBuilder.CreateAlignedStore( 570 Handle, BinaryHandleGlobal, 571 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); 572 CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle); 573 if (!IsHIP) 574 CtorBuilder.CreateCall(RegFatbinEnd, Handle); 575 CtorBuilder.CreateCall(AtExit, DtorFunc); 576 CtorBuilder.CreateRetVoid(); 577 578 // Create the destructor to unregister the image with the runtime. We cannot 579 // use a standard global destructor after CUDA 9.2 so this must be called by 580 // `atexit()` intead. 581 IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc)); 582 LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad( 583 Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal, 584 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); 585 DtorBuilder.CreateCall(UnregFatbin, BinaryHandle); 586 DtorBuilder.CreateRetVoid(); 587 588 // Add this function to constructors. 589 appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1); 590 } 591 592 } // namespace 593 594 Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) { 595 GlobalVariable *Desc = createBinDesc(M, Images); 596 if (!Desc) 597 return createStringError(inconvertibleErrorCode(), 598 "No binary descriptors created."); 599 createRegisterFunction(M, Desc); 600 createUnregisterFunction(M, Desc); 601 return Error::success(); 602 } 603 604 Error wrapCudaBinary(Module &M, ArrayRef<char> Image) { 605 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false); 606 if (!Desc) 607 return createStringError(inconvertibleErrorCode(), 608 "No fatinbary section created."); 609 610 createRegisterFatbinFunction(M, Desc, /* IsHIP */ false); 611 return Error::success(); 612 } 613 614 Error wrapHIPBinary(Module &M, ArrayRef<char> Image) { 615 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true); 616 if (!Desc) 617 return createStringError(inconvertibleErrorCode(), 618 "No fatinbary section created."); 619 620 createRegisterFatbinFunction(M, Desc, /* IsHIP */ true); 621 return Error::success(); 622 } 623