//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
// Arguments to kernel and device functions are passed via param space,
// which imposes certain restrictions:
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
//
// Kernel parameters are read-only and accessible only via the ld.param
// instruction, directly or via a pointer.
//
// Device function parameters are directly accessible via
// ld.param/st.param, but taking the address of one returns a pointer
// to a copy created in local space which *can't* be used with
// ld.param/st.param.
//
// Copying a byval struct into local memory in IR allows us to enforce
// the param space restrictions, gives the rest of IR a pointer w/o
// param space restrictions, and gives us an opportunity to eliminate
// the copy.
//
// Pointer arguments to kernel functions need more work to be lowered:
//
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
//    global address space. This allows later optimizations to emit
//    ld.global.*/st.global.* for accessing these pointer arguments. For
//    example,
//
//    define void @foo(float* %input) {
//      %v = load float, float* %input, align 4
//      ...
//    }
//
//    becomes
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %input3 = addrspacecast float addrspace(1)* %input2 to float*
//      %v = load float, float* %input3, align 4
//      ...
//    }
//
//    Later, NVPTXInferAddressSpaces will optimize it to
//
//    define void @foo(float* %input) {
//      %input2 = addrspacecast float* %input to float addrspace(1)*
//      %v = load float, float addrspace(1)* %input2, align 4
//      ...
//    }
//
// 2. Convert byval kernel parameters to pointers in the param address space
//    (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
//    kernel parameter to pointers in the global address space. This allows
//    NVPTX to emit ld/st.global.
//
//    struct S {
//      int *x;
//      int *y;
//    };
//    __global__ void foo(S s) {
//      int *b = s.y;
//      // use b
//    }
//
//    "b" points to the global address space. At the IR level,
//
//    define void @foo(ptr byval %input) {
//      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
//      %b = load ptr, ptr %b_ptr
//      ; use %b
//    }
//
//    becomes
//
//    define void @foo(ptr byval %input) {
//      %b_param = addrspacecast ptr %input to ptr addrspace(101)
//      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
//      %b = load ptr, ptr addrspace(101) %b_ptr
//      %b_global = addrspacecast ptr %b to ptr addrspace(1)
//      ; use %b_global
//    }
//
// Create a local copy of kernel byval parameters that are used in a way that
// *might* mutate the parameter, by storing them in an alloca. Mutations to
// "grid_constant" parameters are undefined behaviour and don't require local
// copies.
//
// define void @foo(ptr byval(%struct.s) align 4 %input) {
//   store i32 42, ptr %input
//   ret void
// }
//
// becomes
//
// define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
//   %input1 = alloca %struct.s, align 4
//   %input2 = addrspacecast ptr %input to ptr addrspace(101)
//   %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
//   store %struct.s %input3, ptr %input1, align 4
//   store i32 42, ptr %input1, align 4
//   ret void
// }
//
// If %input were passed to a device function, or written to memory, we
// conservatively assume that %input gets mutated, and create a local copy.
//
// Convert param pointers to grid_constant byval kernel parameters that are
// passed into calls (device functions, intrinsics, inline asm), or otherwise
// "escape" (into stores/ptrtoints), to the generic address space using the
// `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
// (available on sm_70+).
//
// define void @foo(ptr byval(%struct.s) %input) {
//   ; %input is a grid_constant
//   %call = call i32 @escape(ptr %input)
//   ret void
// }
//
// becomes
//
// define void @foo(ptr byval(%struct.s) %input) {
//   %input1 = addrspacecast ptr %input to ptr addrspace(101)
//   ; The following intrinsic converts the pointer to generic. We don't use an
//   ; addrspacecast, to prevent a generic -> param -> generic pair from being
//   ; cancelled out.
//   %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
//   %call = call i32 @escape(ptr %input1.gen)
//   ret void
// }
//
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes
// don't cancel the addrspacecast pair this pass emits.
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/PtrUseVisitor.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <numeric>
#include <queue>

#define DEBUG_TYPE "nvptx-lower-args"

using namespace llvm;

namespace llvm {
void initializeNVPTXLowerArgsPass(PassRegistry &);
}

namespace {
class NVPTXLowerArgs : public FunctionPass {
  bool runOnFunction(Function &F) override;

  bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
  bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);

  // Handle byval parameters.
  void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
  // Knowing Ptr must point to the global address space, this function
  // addrspacecasts Ptr to global and then back to generic. This allows
  // NVPTXInferAddressSpaces to fold the global-to-generic cast into
  // loads/stores that appear later.
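  // E.g. (illustrative value names), for a generic pointer %p this emits:
  //   %p.global  = addrspacecast ptr %p to ptr addrspace(1)
  //   %p.generic = addrspacecast ptr addrspace(1) %p.global to ptr
  // and all other uses of %p are rewritten to use %p.generic.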
  void markPointerAsGlobal(Value *Ptr);

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerArgs() : FunctionPass(ID) {}
  StringRef getPassName() const override {
    return "Lower pointer arguments of CUDA kernels";
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace

char NVPTXLowerArgs::ID = 1;

INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
                      "Lower arguments (NVPTX)", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
                    "Lower arguments (NVPTX)", false, false)

// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
// then add the following instructions to the first basic block:
//
// %temp = alloca %struct.x, align 8
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
// %tv = load %struct.x addrspace(101)* %tempd
// store %struct.x %tv, %struct.x* %temp, align 8
//
// The above code allocates some space in the stack and copies the incoming
// struct from param space to local space.
// Then replace all occurrences of %d by %temp.
//
// In case we know that all users are GEPs or Loads, replace them with the same
// ones in parameter AS, so we can access them using ld.param.
// =============================================================================

// For Loads, replaces the \p OldUse of the pointer with a Use of the same
// pointer in parameter AS.
// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
// generic using cvta.param.
static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
                             bool IsGridConstant) {
  Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
  assert(I && "OldUse must be in an instruction");
  struct IP {
    Use *OldUse;
    Instruction *OldInstruction;
    Value *NewParam;
  };
  SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
  SmallVector<Instruction *> InstructionsToDelete;

  auto CloneInstInParamAS = [HasCvtaParam,
                             IsGridConstant](const IP &I) -> Value * {
    if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
      LI->setOperand(0, I.NewParam);
      return LI;
    }
    if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
      SmallVector<Value *, 4> Indices(GEP->indices());
      auto *NewGEP = GetElementPtrInst::Create(
          GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),
          GEP->getIterator());
      NewGEP->setIsInBounds(GEP->isInBounds());
      return NewGEP;
    }
    if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
      auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
      return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
                                 BC->getName(), BC->getIterator());
    }
    if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
      assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
      (void)ASC;
      // Just pass through the argument, the old ASC is no longer needed.
      return I.NewParam;
    }
    if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
      if (MI->getRawSource() == I.OldUse->get()) {
        // convert to memcpy/memmove from param space.
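        // Only the source operand changes: the destination, length, and
        // volatility are preserved, and dereferenceable attributes are
        // carried over below.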
        IRBuilder<> Builder(I.OldInstruction);
        Intrinsic::ID ID = MI->getIntrinsicID();

        CallInst *B = Builder.CreateMemTransferInst(
            ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,
            MI->getSourceAlign(), MI->getLength(), MI->isVolatile());
        for (unsigned I : {0, 1})
          if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))
            B->addDereferenceableParamAttr(I, Bytes);
        return B;
      }
      // We may be able to handle other cases if the argument is
      // __grid_constant__.
    }

    if (HasCvtaParam) {
      auto GetParamAddrCastToGeneric =
          [](Value *Addr, Instruction *OriginalUser) -> Value * {
        PointerType *ReturnTy =
            PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
        Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
            OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
            {ReturnTy, PointerType::get(OriginalUser->getContext(),
                                        ADDRESS_SPACE_PARAM)});

        // Cast param address to generic address space.
        Value *CvtToGenCall =
            CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
                             OriginalUser->getIterator());
        return CvtToGenCall;
      };
      auto *ParamInGenericAS =
          GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);

      // phi/select could use generic arg pointers w/o __grid_constant__.
      if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
        for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
          if (V.get() == I.OldUse->get())
            PHI->setIncomingValue(Idx, ParamInGenericAS);
        }
      }
      if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
        if (SI->getTrueValue() == I.OldUse->get())
          SI->setTrueValue(ParamInGenericAS);
        if (SI->getFalseValue() == I.OldUse->get())
          SI->setFalseValue(ParamInGenericAS);
      }

      // Escapes or writes can only use generic param pointers if
      // __grid_constant__ is in effect.
      if (IsGridConstant) {
        if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
          I.OldUse->set(ParamInGenericAS);
          return CI;
        }
        if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
          // The byval address is being stored; cast it to generic.
          if (SI->getValueOperand() == I.OldUse->get())
            SI->setOperand(0, ParamInGenericAS);
          return SI;
        }
        if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
          if (PI->getPointerOperand() == I.OldUse->get())
            PI->setOperand(0, ParamInGenericAS);
          return PI;
        }
        // TODO: If we allow stores, we should allow memcpy/memset to the
        // parameter, too.
      }
    }

    llvm_unreachable("Unsupported instruction");
  };

  while (!ItemsToConvert.empty()) {
    IP I = ItemsToConvert.pop_back_val();
    Value *NewInst = CloneInstInParamAS(I);

    if (NewInst && NewInst != I.OldInstruction) {
      // We've created a new instruction. Queue users of the old instruction to
      // be converted and the instruction itself to be deleted. We can't delete
      // the old instruction yet, because it's still in use by a load somewhere.
      for (Use &U : I.OldInstruction->uses())
        ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});

      InstructionsToDelete.push_back(I.OldInstruction);
    }
  }

  // Now we know that all argument loads are using addresses in parameter space
  // and we can finally remove the old instructions in generic AS. Instructions
  // scheduled for removal should be processed in reverse order so the ones
  // closest to the load are deleted first. Otherwise they may still be in use.
  // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
  // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
  // the BitCast.
  for (Instruction *I : llvm::reverse(InstructionsToDelete))
    I->eraseFromParent();
}

// Adjust alignment of arguments passed byval in .param address space. We can
// increase the alignment of such arguments so that their loads can be
// effectively vectorized. We also traverse all loads from the byval pointer
// and adjust their alignment if they use a known offset. Such alignment
// changes must be kept in sync with the parameter store and load in
// NVPTXTargetLowering::LowerCall.
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
                                    const NVPTXTargetLowering *TLI) {
  Function *Func = Arg->getParent();
  Type *StructType = Arg->getParamByValType();
  const DataLayout &DL = Func->getDataLayout();

  uint64_t NewArgAlign =
      TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
  uint64_t CurArgAlign =
      Arg->getAttribute(Attribute::Alignment).getValueAsInt();

  if (CurArgAlign >= NewArgAlign)
    return;

  LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
                    << CurArgAlign << " for " << *Arg << '\n');

  auto NewAlignAttr =
      Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
  Arg->removeAttr(Attribute::Alignment);
  Arg->addAttr(NewAlignAttr);

  struct Load {
    LoadInst *Inst;
    uint64_t Offset;
  };

  struct LoadContext {
    Value *InitialVal;
    uint64_t Offset;
  };

  SmallVector<Load> Loads;
  std::queue<LoadContext> Worklist;
  Worklist.push({ArgInParamAS, 0});
  bool IsGridConstant = isParamGridConstant(*Arg);

  while (!Worklist.empty()) {
    LoadContext Ctx = Worklist.front();
    Worklist.pop();

    for (User *CurUser : Ctx.InitialVal->users()) {
      if (auto *I = dyn_cast<LoadInst>(CurUser)) {
        Loads.push_back({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
        Worklist.push({I, Ctx.Offset});
        continue;
      }

      if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
        APInt OffsetAccumulated =
            APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));

        if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
          continue;

        uint64_t OffsetLimit = -1;
        uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
        assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");

        Worklist.push({I, Ctx.Offset + Offset});
        continue;
      }

      if (isa<MemTransferInst>(CurUser))
        continue;

      // Supported for grid_constant.
      if (IsGridConstant &&
          (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
           isa<PtrToIntInst>(CurUser)))
        continue;

      llvm_unreachable("All users must be one of: load, "
                       "bitcast, getelementptr, call, store, ptrtoint");
    }
  }

  for (Load &CurLoad : Loads) {
    Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
    Align CurLoadAlign(CurLoad.Inst->getAlign());
    CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
  }
}

namespace {
struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
  using Base = PtrUseVisitor<ArgUseChecker>;

  bool IsGridConstant;
  // Set of phi/select instructions using the Arg
  SmallPtrSet<Instruction *, 4> Conditionals;

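  // For __grid_constant__ arguments, writes through the pointer are UB and do
  // not force a local copy (see the visit* overrides below); escapes of the
  // pointer value are still recorded, but handleByValParam() treats such
  // arguments as read-only regardless.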
  ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
      : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}

  PtrInfo visitArgPtr(Argument &A) {
    assert(A.getType()->isPointerTy());
    IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
    IsOffsetKnown = false;
    Offset = APInt(IntIdxTy->getBitWidth(), 0);
    PI.reset();
    Conditionals.clear();

    LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
    // Enqueue the uses of this pointer.
    enqueueUsers(A);

    // Visit all the uses off the worklist until it is empty.
    // Note that unlike PtrUseVisitor we intentionally do not track offsets.
    // We're only interested in how we use the pointer.
    while (!(Worklist.empty() || PI.isAborted())) {
      UseToVisit ToVisit = Worklist.pop_back_val();
      U = ToVisit.UseAndIsOffsetKnown.getPointer();
      Instruction *I = cast<Instruction>(U->getUser());
      if (isa<PHINode>(I) || isa<SelectInst>(I))
        Conditionals.insert(I);
      LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
      Base::visit(I);
    }
    if (PI.isEscaped())
      LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
                        << "\n");
    else if (PI.isAborted())
      LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
                        << "\n");
    LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
                      << " conditionals\n");
    return PI;
  }

  void visitStoreInst(StoreInst &SI) {
    // Storing the pointer escapes it.
    if (U->get() == SI.getValueOperand())
      return PI.setEscapedAndAborted(&SI);
    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
    // copy.
    if (!IsGridConstant)
      return PI.setAborted(&SI);
  }

  void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
    // ASCs to param space are no-ops and do not need a copy.
    if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
      return PI.setEscapedAndAborted(&ASC);
    Base::visitAddrSpaceCastInst(ASC);
  }

  void visitPtrToIntInst(PtrToIntInst &I) {
    if (IsGridConstant)
      return;
    Base::visitPtrToIntInst(I);
  }
  void visitPHINodeOrSelectInst(Instruction &I) {
    assert(isa<PHINode>(I) || isa<SelectInst>(I));
  }
  // PHI and select just pass through the pointers.
  void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
  void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }

  void visitMemTransferInst(MemTransferInst &II) {
    if (*U == II.getRawDest() && !IsGridConstant)
      PI.setAborted(&II);
    // memcpy/memmove are OK when the pointer is the source. We can convert
    // them to AS-specific memcpy.
  }

  void visitMemSetInst(MemSetInst &II) {
    if (!IsGridConstant)
      PI.setAborted(&II);
  }
}; // struct ArgUseChecker

void copyByValParam(Function &F, Argument &Arg) {
  LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
  // We have to create a temporary copy on the stack.
  BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
  Type *StructType = Arg.getParamByValType();
  const DataLayout &DL = F.getDataLayout();
  AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
                                      Arg.getName(), FirstInst);
  // Set the alignment to the alignment of the byval parameter. This is
  // because later loads/stores assume that alignment, and we are going to
  // replace the use of the byval parameter with this alloca instruction.
  AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
                           .value_or(DL.getPrefTypeAlign(StructType)));
  Arg.replaceAllUsesWith(AllocA);

  Value *ArgInParam = new AddrSpaceCastInst(
      &Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
      Arg.getName(), FirstInst);
  // Be sure to propagate alignment to this copy; LLVM doesn't know that NVPTX
  // addrspacecast preserves alignment. Since params are constant, this copy
  // is definitely not volatile.
  const auto ArgSize = *AllocA->getAllocationSize(DL);
  IRBuilder<> IRB(&*FirstInst);
  IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
                   ArgSize);
}
} // namespace

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                      Argument *Arg) {
  Function *Func = Arg->getParent();
  bool HasCvtaParam =
      TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
  bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
  const DataLayout &DL = Func->getDataLayout();
  BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
  Type *StructType = Arg->getParamByValType();
  assert(StructType && "Missing byval type");

  ArgUseChecker AUC(DL, IsGridConstant);
  ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
  // Easy case: accessing the parameter directly is fine.
  if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
    // Convert all loads and intermediate operations to use parameter AS and
    // skip creation of a local copy of the argument.
    SmallVector<Use *, 16> UsesToUpdate;
    for (Use &U : Arg->uses())
      UsesToUpdate.push_back(&U);

    Value *ArgInParamAS = new AddrSpaceCastInst(
        Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
        Arg->getName(), FirstInst);
    for (Use *U : UsesToUpdate)
      convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
    LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

    const auto *TLI =
        cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());

    adjustByValArgAlignment(Arg, ArgInParamAS, TLI);

    return;
  }

  // We can't access the byval arg directly and need a pointer. On sm_70+ we
  // have the ability to take a pointer to the argument without making a local
  // copy. However, we're still not allowed to write to it. If the user
  // specified `__grid_constant__` for the argument, we'll consider an escaped
  // pointer as read-only.
  if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
    LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
    // Replace all argument pointer uses (which might include a device function
    // call) with a cast to the generic address space using the cvta.param
    // instruction, which avoids a local copy.
    IRBuilder<> IRB(&Func->getEntryBlock().front());

    // Cast argument to param address space.
    auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
        Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));

    // Cast param address to generic address space. We do not use an
    // addrspacecast to generic here, because LLVM considers `Arg` to be in the
    // generic address space, and a `generic -> param` cast followed by a
    // `param -> generic` cast will be folded away. The `param -> generic`
    // intrinsic will be correctly lowered to `cvta.param`.
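    // I.e. (illustrative value names):
    //   %arg.param = addrspacecast ptr %arg to ptr addrspace(101)
    //   %arg.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(
    //                  ptr addrspace(101) %arg.param)
    // and all other uses of %arg are replaced with %arg.gen.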
    Value *CvtToGenCall = IRB.CreateIntrinsic(
        IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
        CastToParam, nullptr, CastToParam->getName() + ".gen");

    Arg->replaceAllUsesWith(CvtToGenCall);

    // Do not replace Arg in the cast to param space.
    CastToParam->setOperand(0, Arg);
  } else
    copyByValParam(*Func, *Arg);
}

void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
    return;

  // Decide where to emit the addrspacecast pair.
  BasicBlock::iterator InsertPt;
  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
    // Insert at the function entry if Ptr is an argument.
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    // Insert right after Ptr if Ptr is an instruction.
    InsertPt = ++cast<Instruction>(Ptr)->getIterator();
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Instruction *PtrInGlobal = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), InsertPt);
  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
                                              Ptr->getName(), InsertPt);
  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
  Ptr->replaceAllUsesWith(PtrInGeneric);
  PtrInGlobal->setOperand(0, Ptr);
}

// =============================================================================
// Main function for this pass.
// =============================================================================
bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
                                         Function &F) {
  // Copying of byval aggregates + SROA may result in pointers being loaded as
  // integers, followed by inttoptr. We may want to mark those as global, too,
  // but only if the loaded integer is used exclusively for conversion to a
  // pointer with inttoptr.
  auto HandleIntToPtr = [this](Value &V) {
    if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
      SmallVector<User *, 16> UsersToUpdate(V.users());
      for (User *U : UsersToUpdate)
        markPointerAsGlobal(U);
    }
  };
  if (TM.getDrvInterface() == NVPTX::CUDA) {
    // Mark pointers in byval structs as global.
    for (auto &B : F) {
      for (auto &I : B) {
        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
          if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
            Value *UO = getUnderlyingObject(LI->getPointerOperand());
            if (Argument *Arg = dyn_cast<Argument>(UO)) {
              if (Arg->hasByValAttr()) {
                // LI is a load from a pointer within a byval kernel parameter.
                if (LI->getType()->isPointerTy())
                  markPointerAsGlobal(LI);
                else
                  HandleIntToPtr(*LI);
              }
            }
          }
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
  for (Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy()) {
      if (Arg.hasByValAttr())
        handleByValParam(TM, &Arg);
      else if (TM.getDrvInterface() == NVPTX::CUDA)
        markPointerAsGlobal(&Arg);
    } else if (Arg.getType()->isIntegerTy() &&
               TM.getDrvInterface() == NVPTX::CUDA) {
      HandleIntToPtr(Arg);
    }
  }
  return true;
}

// Device functions only need to copy byval args into local memory.
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
                                         Function &F) {
  LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
      handleByValParam(TM, &Arg);
  return true;
}

bool NVPTXLowerArgs::runOnFunction(Function &F) {
  auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();

  return isKernelFunction(F) ? runOnKernelFunction(TM, F)
                             : runOnDeviceFunction(TM, F);
}

FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }

static bool copyFunctionByValArgs(Function &F) {
  LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
                    << "\n");
  bool Changed = false;
  for (Argument &Arg : F.args())
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
        !(isParamGridConstant(Arg) && isKernelFunction(F))) {
      copyByValParam(F, Arg);
      Changed = true;
    }
  return Changed;
}

PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
  return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
                                  : PreservedAnalyses::all();
}