xref: /llvm-project/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (revision 416f1c465db62d829283f6902ef35e027e127aa7)
1 //===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 // Arguments to kernel and device functions are passed via param space,
11 // which imposes certain restrictions:
12 // http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
13 //
14 // Kernel parameters are read-only and accessible only via ld.param
15 // instruction, directly or via a pointer.
16 //
17 // Device function parameters are directly accessible via
18 // ld.param/st.param, but taking the address of one returns a pointer
19 // to a copy created in local space which *can't* be used with
20 // ld.param/st.param.
21 //
22 // Copying a byval struct into local memory in IR allows us to enforce
23 // the param space restrictions, gives the rest of IR a pointer w/o
24 // param space restrictions, and gives us an opportunity to eliminate
25 // the copy.
26 //
27 // Pointer arguments to kernel functions need more work to be lowered:
28 //
29 // 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
30 //    global address space. This allows later optimizations to emit
31 //    ld.global.*/st.global.* for accessing these pointer arguments. For
32 //    example,
33 //
34 //    define void @foo(float* %input) {
35 //      %v = load float, float* %input, align 4
36 //      ...
37 //    }
38 //
39 //    becomes
40 //
41 //    define void @foo(float* %input) {
42 //      %input2 = addrspacecast float* %input to float addrspace(1)*
43 //      %input3 = addrspacecast float addrspace(1)* %input2 to float*
44 //      %v = load float, float* %input3, align 4
45 //      ...
46 //    }
47 //
48 //    Later, NVPTXInferAddressSpaces will optimize it to
49 //
50 //    define void @foo(float* %input) {
51 //      %input2 = addrspacecast float* %input to float addrspace(1)*
52 //      %v = load float, float addrspace(1)* %input2, align 4
53 //      ...
54 //    }
55 //
56 // 2. Convert byval kernel parameters to pointers in the param address space
57 //    (so that NVPTX emits ld/st.param).  Convert pointers *within* a byval
58 //    kernel parameter to pointers in the global address space. This allows
59 //    NVPTX to emit ld/st.global.
60 //
61 //    struct S {
62 //      int *x;
63 //      int *y;
64 //    };
65 //    __global__ void foo(S s) {
66 //      int *b = s.y;
67 //      // use b
68 //    }
69 //
70 //    "b" points to the global address space. In the IR level,
71 //
72 //    define void @foo(ptr byval %input) {
73 //      %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
74 //      %b = load ptr, ptr %b_ptr
75 //      ; use %b
76 //    }
77 //
78 //    becomes
79 //
80 //    define void @foo(ptr byval %input) {
81 //      %b_param = addrspacecast ptr %input to ptr addrspace(101)
82 //      %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
83 //      %b = load ptr, ptr addrspace(101) %b_ptr
84 //      %b_global = addrspacecast ptr %b to ptr addrspace(1)
//      %b_generic = addrspacecast ptr addrspace(1) %b_global to ptr
85 //      ; use %b_generic
86 //    }
87 //
88 //    Kernel byval parameters that are used in a way that *might* mutate the
89 //    parameter get a local copy in an alloca. Mutating a "grid_constant"
90 //    parameter is undefined behaviour, so those do not require a local copy.
91 //
92 //    define void @foo(ptr byval(%struct.s) align 4 %input) {
93 //       store i32 42, ptr %input
94 //       ret void
95 //    }
96 //
97 //    becomes
98 //
99 //    define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
100 //      %input1 = alloca %struct.s, align 4
101 //      %input2 = addrspacecast ptr %input to ptr addrspace(101)
102 //      %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
103 //      store %struct.s %input3, ptr %input1, align 4
104 //      store i32 42, ptr %input1, align 4
105 //      ret void
106 //    }
107 //
108 //    If %input were passed to a device function or written to memory, we
109 //    conservatively assume that %input gets mutated and create a local copy.
110 //
111 //    Pointers to grid_constant byval kernel parameters that are passed into
112 //    calls (device functions, intrinsics, inline asm), or otherwise "escape"
113 //    (into stores/ptrtoints), are converted to the generic address space using
114 //    the `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
115 //    (available on sm_70+).
116 //
117 //    define void @foo(ptr byval(%struct.s) %input) {
118 //      ; %input is a grid_constant
119 //      %call = call i32 @escape(ptr %input)
120 //      ret void
121 //    }
122 //
123 //    becomes
124 //
125 //    define void @foo(ptr byval(%struct.s) %input) {
126 //      %input1 = addrspacecast ptr %input to ptr addrspace(101)
127 //      ; the following intrinsic converts the pointer to generic. We don't use an
128 //      ; addrspacecast, so that the generic -> param -> generic pair isn't folded away.
129 //      %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
130 //      %call = call i32 @escape(ptr %input1.gen)
131 //      ret void
132 //    }
133 //
134 // TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
135 // cancel the addrspacecast pair this pass emits.
136 //===----------------------------------------------------------------------===//
137 
138 #include "MCTargetDesc/NVPTXBaseInfo.h"
139 #include "NVPTX.h"
140 #include "NVPTXTargetMachine.h"
141 #include "NVPTXUtilities.h"
142 #include "llvm/ADT/STLExtras.h"
143 #include "llvm/Analysis/PtrUseVisitor.h"
144 #include "llvm/Analysis/ValueTracking.h"
145 #include "llvm/CodeGen/TargetPassConfig.h"
146 #include "llvm/IR/Function.h"
147 #include "llvm/IR/IRBuilder.h"
148 #include "llvm/IR/Instructions.h"
149 #include "llvm/IR/IntrinsicInst.h"
150 #include "llvm/IR/IntrinsicsNVPTX.h"
151 #include "llvm/IR/Type.h"
152 #include "llvm/InitializePasses.h"
153 #include "llvm/Pass.h"
154 #include "llvm/Support/Debug.h"
155 #include "llvm/Support/ErrorHandling.h"
156 #include <numeric>
157 #include <queue>
158 
159 #define DEBUG_TYPE "nvptx-lower-args"
160 
161 using namespace llvm;
162 
163 namespace llvm {
164 void initializeNVPTXLowerArgsPass(PassRegistry &);
165 }
166 
167 namespace {
168 class NVPTXLowerArgs : public FunctionPass {
169   bool runOnFunction(Function &F) override;
170 
171   bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
172   bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);
173 
174   // handle byval parameters
175   void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
176   // Knowing Ptr must point to the global address space, this function
177   // addrspacecasts Ptr to global and then back to generic. This allows
178   // NVPTXInferAddressSpaces to fold the global-to-generic cast into
179   // loads/stores that appear later.
180   void markPointerAsGlobal(Value *Ptr);
181 
182 public:
183   static char ID; // Pass identification, replacement for typeid
184   NVPTXLowerArgs() : FunctionPass(ID) {}
185   StringRef getPassName() const override {
186     return "Lower pointer arguments of CUDA kernels";
187   }
188   void getAnalysisUsage(AnalysisUsage &AU) const override {
189     AU.addRequired<TargetPassConfig>();
190   }
191 };
192 } // namespace
193 
194 char NVPTXLowerArgs::ID = 1;
195 
196 INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
197                       "Lower arguments (NVPTX)", false, false)
198 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
199 INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
200                     "Lower arguments (NVPTX)", false, false)
201 
202 // =============================================================================
203 // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
204 // and we can't guarantee that the only accesses are loads,
205 // then add the following instructions to the first basic block:
206 //
207 // %temp = alloca %struct.x, align 8
208 // %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
209 // %tv = load %struct.x addrspace(101)* %tempd
210 // store %struct.x %tv, %struct.x* %temp, align 8
211 //
212 // The above code allocates some space on the stack and copies the incoming
213 // struct from param space to local space.
214 // All occurrences of %d are then replaced with %temp.
215 //
216 // If we know that all users are GEPs or loads, replace them with equivalent
217 // operations in the parameter AS, so we can access them using ld.param.
218 // =============================================================================
219 
220 // For loads, replace the \p OldUse of the pointer with a use of the same
221 // pointer in the parameter AS.
222 // For "escapes" (to memory, a function call, or a ptrtoint), cast OldUse to
223 // the generic AS using cvta.param.
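// For example (illustrative IR only; the value names are made up), a GEP/load
// chain rooted at a byval argument
//   %p = getelementptr %struct.x, ptr %arg, i64 0, i32 1
//   %v = load i32, ptr %p
// is rewritten to operate on the parameter AS:
//   %arg.param = addrspacecast ptr %arg to ptr addrspace(101)
//   %p.param = getelementptr %struct.x, ptr addrspace(101) %arg.param, i64 0, i32 1
//   %v = load i32, ptr addrspace(101) %p.param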
224 static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
225                              bool IsGridConstant) {
226   Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
227   assert(I && "OldUse must be in an instruction");
228   struct IP {
229     Use *OldUse;
230     Instruction *OldInstruction;
231     Value *NewParam;
232   };
233   SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
234   SmallVector<Instruction *> InstructionsToDelete;
235 
236   auto CloneInstInParamAS = [HasCvtaParam,
237                              IsGridConstant](const IP &I) -> Value * {
238     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
239       LI->setOperand(0, I.NewParam);
240       return LI;
241     }
242     if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
243       SmallVector<Value *, 4> Indices(GEP->indices());
244       auto *NewGEP = GetElementPtrInst::Create(
245           GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),
246           GEP->getIterator());
247       NewGEP->setIsInBounds(GEP->isInBounds());
248       return NewGEP;
249     }
250     if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
251       auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
252       return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
253                                  BC->getName(), BC->getIterator());
254     }
255     if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
256       assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
257       (void)ASC;
258       // Just pass through the argument, the old ASC is no longer needed.
259       return I.NewParam;
260     }
261     if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
262       if (MI->getRawSource() == I.OldUse->get()) {
263         // convert to memcpy/memmove from param space.
264         IRBuilder<> Builder(I.OldInstruction);
265         Intrinsic::ID ID = MI->getIntrinsicID();
266 
267         CallInst *B = Builder.CreateMemTransferInst(
268             ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,
269             MI->getSourceAlign(), MI->getLength(), MI->isVolatile());
270         for (unsigned I : {0, 1})
271           if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))
272             B->addDereferenceableParamAttr(I, Bytes);
273         return B;
274       }
275       // We may be able to handle other cases if the argument is
276       // __grid_constant__
277     }
278 
279     if (HasCvtaParam) {
280       auto GetParamAddrCastToGeneric =
281           [](Value *Addr, Instruction *OriginalUser) -> Value * {
282         PointerType *ReturnTy =
283             PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
284         Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
285             OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
286             {ReturnTy, PointerType::get(OriginalUser->getContext(),
287                                         ADDRESS_SPACE_PARAM)});
288 
289         // Cast param address to generic address space
290         Value *CvtToGenCall =
291             CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
292                              OriginalUser->getIterator());
293         return CvtToGenCall;
294       };
295       auto *ParamInGenericAS =
296           GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
297 
298       // phi/select could use generic arg pointers w/o __grid_constant__
299       if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
300         for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
301           if (V.get() == I.OldUse->get())
302             PHI->setIncomingValue(Idx, ParamInGenericAS);
303         }
304       }
305       if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
306         if (SI->getTrueValue() == I.OldUse->get())
307           SI->setTrueValue(ParamInGenericAS);
308         if (SI->getFalseValue() == I.OldUse->get())
309           SI->setFalseValue(ParamInGenericAS);
310       }
311 
312       // Escapes or writes can only use generic param pointers if
313       // __grid_constant__ is in effect.
314       if (IsGridConstant) {
315         if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
316           I.OldUse->set(ParamInGenericAS);
317           return CI;
318         }
319         if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
320           // byval address is being stored, cast it to generic
321           if (SI->getValueOperand() == I.OldUse->get())
322             SI->setOperand(0, ParamInGenericAS);
323           return SI;
324         }
325         if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
326           if (PI->getPointerOperand() == I.OldUse->get())
327             PI->setOperand(0, ParamInGenericAS);
328           return PI;
329         }
330         // TODO: If we allow stores, we should allow memcpy/memset to the
331         // parameter, too.
332       }
333     }
334 
335     llvm_unreachable("Unsupported instruction");
336   };
337 
338   while (!ItemsToConvert.empty()) {
339     IP I = ItemsToConvert.pop_back_val();
340     Value *NewInst = CloneInstInParamAS(I);
341 
342     if (NewInst && NewInst != I.OldInstruction) {
343       // We've created a new instruction. Queue users of the old instruction to
344       // be converted and the instruction itself to be deleted. We can't delete
345       // the old instruction yet, because it's still in use by a load somewhere.
346       for (Use &U : I.OldInstruction->uses())
347         ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});
348 
349       InstructionsToDelete.push_back(I.OldInstruction);
350     }
351   }
352 
353   // Now we know that all argument loads are using addresses in parameter space
354   // and we can finally remove the old instructions in generic AS.  Instructions
355   // scheduled for removal should be processed in reverse order so the ones
356   // closest to the load are deleted first. Otherwise they may still be in use.
357   // E.g. if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
358   // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
359   // the BitCast.
360   for (Instruction *I : llvm::reverse(InstructionsToDelete))
361     I->eraseFromParent();
362 }
363 
364 // Adjust the alignment of arguments passed byval in the .param address space.
365 // We can increase the alignment of such arguments so that their loads can be
366 // vectorized effectively. We also traverse all loads from the byval pointer
367 // and adjust their alignment when they use a known constant offset. Such
368 // alignment changes must stay consistent with the parameter loads and stores
369 // emitted by NVPTXTargetLowering::LowerCall.
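// For example (illustrative numbers only): if the optimized argument alignment
// is 16, a load at constant offset 8 into the parameter can be raised to
// align gcd(16, 8) = 8, while a load at offset 4 can only be raised to
// align gcd(16, 4) = 4.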
370 static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
371                                     const NVPTXTargetLowering *TLI) {
372   Function *Func = Arg->getParent();
373   Type *StructType = Arg->getParamByValType();
374   const DataLayout &DL = Func->getDataLayout();
375 
376   uint64_t NewArgAlign =
377       TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
378   uint64_t CurArgAlign =
379       Arg->getAttribute(Attribute::Alignment).getValueAsInt();
380 
381   if (CurArgAlign >= NewArgAlign)
382     return;
383 
384   LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
385                     << CurArgAlign << " for " << *Arg << '\n');
386 
387   auto NewAlignAttr =
388       Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
389   Arg->removeAttr(Attribute::Alignment);
390   Arg->addAttr(NewAlignAttr);
391 
392   struct Load {
393     LoadInst *Inst;
394     uint64_t Offset;
395   };
396 
397   struct LoadContext {
398     Value *InitialVal;
399     uint64_t Offset;
400   };
401 
402   SmallVector<Load> Loads;
403   std::queue<LoadContext> Worklist;
404   Worklist.push({ArgInParamAS, 0});
405   bool IsGridConstant = isParamGridConstant(*Arg);
406 
407   while (!Worklist.empty()) {
408     LoadContext Ctx = Worklist.front();
409     Worklist.pop();
410 
411     for (User *CurUser : Ctx.InitialVal->users()) {
412       if (auto *I = dyn_cast<LoadInst>(CurUser)) {
413         Loads.push_back({I, Ctx.Offset});
414         continue;
415       }
416 
417       if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
418         Worklist.push({I, Ctx.Offset});
419         continue;
420       }
421 
422       if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
423         APInt OffsetAccumulated =
424             APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
425 
426         if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
427           continue;
428 
429         uint64_t OffsetLimit = -1;
430         uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
431         assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
432 
433         Worklist.push({I, Ctx.Offset + Offset});
434         continue;
435       }
436 
437       if (isa<MemTransferInst>(CurUser))
438         continue;
439 
440       // Calls, stores, and ptrtoints are supported for grid_constant arguments.
441       if (IsGridConstant &&
442           (isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
443            isa<PtrToIntInst>(CurUser)))
444         continue;
445 
446       llvm_unreachable("All users must be one of: load, "
447                        "bitcast, getelementptr, call, store, ptrtoint");
448     }
449   }
450 
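  // Raise each load's alignment to the gcd of the new argument alignment and
  // the load's constant offset, but never lower it below its current alignment.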
451   for (Load &CurLoad : Loads) {
452     Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
453     Align CurLoadAlign(CurLoad.Inst->getAlign());
454     CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
455   }
456 }
457 
458 namespace {
459 struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
460   using Base = PtrUseVisitor<ArgUseChecker>;
461 
462   bool IsGridConstant;
463   // Set of phi/select instructions using the Arg
464   SmallPtrSet<Instruction *, 4> Conditionals;
465 
466   ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
467       : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
468 
469   PtrInfo visitArgPtr(Argument &A) {
470     assert(A.getType()->isPointerTy());
471     IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
472     IsOffsetKnown = false;
473     Offset = APInt(IntIdxTy->getBitWidth(), 0);
474     PI.reset();
475     Conditionals.clear();
476 
477     LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
478     // Enqueue the uses of this pointer.
479     enqueueUsers(A);
480 
481     // Visit all the uses off the worklist until it is empty.
482     // Note that unlike PtrUseVisitor we intentionally do not track offsets.
483     // We're only interested in how we use the pointer.
484     while (!(Worklist.empty() || PI.isAborted())) {
485       UseToVisit ToVisit = Worklist.pop_back_val();
486       U = ToVisit.UseAndIsOffsetKnown.getPointer();
487       Instruction *I = cast<Instruction>(U->getUser());
488       if (isa<PHINode>(I) || isa<SelectInst>(I))
489         Conditionals.insert(I);
490       LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
491       Base::visit(I);
492     }
493     if (PI.isEscaped())
494       LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
495                         << "\n");
496     else if (PI.isAborted())
497       LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
498                         << "\n");
499     LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
500                       << " conditionals\n");
501     return PI;
502   }
503 
504   void visitStoreInst(StoreInst &SI) {
505     // Storing the pointer escapes it.
506     if (U->get() == SI.getValueOperand())
507       return PI.setEscapedAndAborted(&SI);
508     // Writes to the pointer are UB w/ __grid_constant__, but do not force a
509     // copy.
510     if (!IsGridConstant)
511       return PI.setAborted(&SI);
512   }
513 
514   void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
515     // ASCs to param space are no-ops and do not need a copy.
516     if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
517       return PI.setEscapedAndAborted(&ASC);
518     Base::visitAddrSpaceCastInst(ASC);
519   }
520 
521   void visitPtrToIntInst(PtrToIntInst &I) {
522     if (IsGridConstant)
523       return;
524     Base::visitPtrToIntInst(I);
525   }
526   void visitPHINodeOrSelectInst(Instruction &I) {
527     assert(isa<PHINode>(I) || isa<SelectInst>(I));
528   }
529   // PHI and select just pass through the pointers.
530   void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
531   void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }
532 
533   void visitMemTransferInst(MemTransferInst &II) {
534     if (*U == II.getRawDest() && !IsGridConstant)
535       PI.setAborted(&II);
536     // memcpy/memmove are OK when the pointer is the source. We can convert them
537     // to an AS-specific memcpy.
538   }
539 
540   void visitMemSetInst(MemSetInst &II) {
541     if (!IsGridConstant)
542       PI.setAborted(&II);
543   }
544 }; // struct ArgUseChecker
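// How handleByValParam interprets the result (sketch):
//  * neither escaped nor aborted, and no phi/select users: rewrite all uses to
//    operate directly on the parameter AS;
//  * escaped or aborted, but read-only or __grid_constant__ on sm_70+: take a
//    generic pointer to the argument via cvta.param;
//  * otherwise: create a local copy of the argument.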
545 
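// Create a local copy of a byval argument: allocate an alloca in the entry
// block, redirect all uses of the argument to it, and memcpy the argument's
// contents from the param address space into the alloca.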
546 void copyByValParam(Function &F, Argument &Arg) {
547   LLVM_DEBUG(dbgs() << "Creating a local copy of " << Arg << "\n");
548   // We have to create a temporary copy of the argument in local memory.
549   BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
550   Type *StructType = Arg.getParamByValType();
551   const DataLayout &DL = F.getDataLayout();
552   AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
553                                       Arg.getName(), FirstInst);
554   // Set the alignment to the alignment of the byval parameter, because later
555   // loads/stores assume that alignment, and we are going to replace all uses
556   // of the byval parameter with this alloca instruction.
557   AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
558                            .value_or(DL.getPrefTypeAlign(StructType)));
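  // Redirect all existing uses of the byval argument to the local copy.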
559   Arg.replaceAllUsesWith(AllocA);
560 
561   Value *ArgInParam = new AddrSpaceCastInst(
562       &Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
563       Arg.getName(), FirstInst);
564   // Be sure to propagate the alignment to this copy; LLVM doesn't know that an
565   // NVPTX addrspacecast preserves alignment.  Since params are constant, the
566   // copy is definitely not volatile.
567   const auto ArgSize = *AllocA->getAllocationSize(DL);
568   IRBuilder<> IRB(&*FirstInst);
569   IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
570                    ArgSize);
571 }
572 } // namespace
573 
574 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
575                                       Argument *Arg) {
576   Function *Func = Arg->getParent();
577   bool HasCvtaParam =
578       TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
579   bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
580   const DataLayout &DL = Func->getDataLayout();
581   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
582   Type *StructType = Arg->getParamByValType();
583   assert(StructType && "Missing byval type");
584 
585   ArgUseChecker AUC(DL, IsGridConstant);
586   ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
587   bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
588   // Easy case, accessing parameter directly is fine.
589   if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
590     // Convert all loads and intermediate operations to use parameter AS and
591     // skip creation of a local copy of the argument.
592     SmallVector<Use *, 16> UsesToUpdate;
593     for (Use &U : Arg->uses())
594       UsesToUpdate.push_back(&U);
595 
596     Value *ArgInParamAS = new AddrSpaceCastInst(
597         Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
598         Arg->getName(), FirstInst);
599     for (Use *U : UsesToUpdate)
600       convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
601     LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
602 
603     const auto *TLI =
604         cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
605 
606     adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
607 
608     return;
609   }
610 
611   // We can't access the byval arg directly and need a pointer. On sm_70+ we
612   // have the ability to take a pointer to the argument without making a local
613   // copy. However, we're still not allowed to write to it. If the user
614   // specified `__grid_constant__` for the argument, we'll consider an escaped
615   // pointer to be read-only.
616   if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
617     LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
618     // Replace all argument pointer uses (which might include a device function
619     // call) with a cast to the generic address space using cvta.param
620     // instruction, which avoids a local copy.
621     IRBuilder<> IRB(&Func->getEntryBlock().front());
622 
623     // Cast argument to param address space
624     auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
625         Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
626 
627     // Cast param address to generic address space. We do not use an
628   // addrspacecast to generic here because LLVM considers `Arg` to be in the
629     // generic address space, and a `generic -> param` cast followed by a `param
630     // -> generic` cast will be folded away. The `param -> generic` intrinsic
631     // will be correctly lowered to `cvta.param`.
632     Value *CvtToGenCall = IRB.CreateIntrinsic(
633         IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
634         CastToParam, nullptr, CastToParam->getName() + ".gen");
635 
636     Arg->replaceAllUsesWith(CvtToGenCall);
637 
638     // Do not replace Arg in the cast to param space
639     CastToParam->setOperand(0, Arg);
640   } else
641     copyByValParam(*Func, *Arg);
642 }
643 
644 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
645   if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
646     return;
647 
648   // Deciding where to emit the addrspacecast pair.
649   BasicBlock::iterator InsertPt;
650   if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
651     // Insert at the function entry if Ptr is an argument.
652     InsertPt = Arg->getParent()->getEntryBlock().begin();
653   } else {
654     // Insert right after Ptr if Ptr is an instruction.
655     InsertPt = ++cast<Instruction>(Ptr)->getIterator();
656     assert(InsertPt != InsertPt->getParent()->end() &&
657            "We don't call this function with Ptr being a terminator.");
658   }
659 
660   Instruction *PtrInGlobal = new AddrSpaceCastInst(
661       Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
662       Ptr->getName(), InsertPt);
663   Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
664                                               Ptr->getName(), InsertPt);
665   // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
666   Ptr->replaceAllUsesWith(PtrInGeneric);
667   PtrInGlobal->setOperand(0, Ptr);
668 }
669 
670 // =============================================================================
671 // Main function for this pass.
672 // =============================================================================
673 bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
674                                          Function &F) {
675   // Copying of byval aggregates + SROA may result in pointers being loaded as
676   // integers, followed by inttoptr. We may want to mark those as global, too,
677   // but only if the loaded integer is used exclusively for conversion to a
678   // pointer with inttoptr.
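  // For example (illustrative IR):
  //   %i = load i64, ptr %ptr_into_byval_arg
  //   %p = inttoptr i64 %i to ptr  ; sole user of %i, so %p is marked as global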
679   auto HandleIntToPtr = [this](Value &V) {
680     if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
681       SmallVector<User *, 16> UsersToUpdate(V.users());
682       for (User *U : UsersToUpdate)
683         markPointerAsGlobal(U);
684     }
685   };
686   if (TM.getDrvInterface() == NVPTX::CUDA) {
687     // Mark pointers in byval structs as global.
688     for (auto &B : F) {
689       for (auto &I : B) {
690         if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
691           if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
692             Value *UO = getUnderlyingObject(LI->getPointerOperand());
693             if (Argument *Arg = dyn_cast<Argument>(UO)) {
694               if (Arg->hasByValAttr()) {
695                 // LI is a load from a pointer within a byval kernel parameter.
696                 if (LI->getType()->isPointerTy())
697                   markPointerAsGlobal(LI);
698                 else
699                   HandleIntToPtr(*LI);
700               }
701             }
702           }
703         }
704       }
705     }
706   }
707 
708   LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
709   for (Argument &Arg : F.args()) {
710     if (Arg.getType()->isPointerTy()) {
711       if (Arg.hasByValAttr())
712         handleByValParam(TM, &Arg);
713       else if (TM.getDrvInterface() == NVPTX::CUDA)
714         markPointerAsGlobal(&Arg);
715     } else if (Arg.getType()->isIntegerTy() &&
716                TM.getDrvInterface() == NVPTX::CUDA) {
717       HandleIntToPtr(Arg);
718     }
719   }
720   return true;
721 }
722 
723 // Device functions only need to copy byval args into local memory.
724 bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
725                                          Function &F) {
726   LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
727   for (Argument &Arg : F.args())
728     if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
729       handleByValParam(TM, &Arg);
730   return true;
731 }
732 
733 bool NVPTXLowerArgs::runOnFunction(Function &F) {
734   auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
735 
736   return isKernelFunction(F) ? runOnKernelFunction(TM, F)
737                              : runOnDeviceFunction(TM, F);
738 }
739 
740 FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }
741 
742 static bool copyFunctionByValArgs(Function &F) {
743   LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
744                     << "\n");
745   bool Changed = false;
746   for (Argument &Arg : F.args())
747     if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
748         !(isParamGridConstant(Arg) && isKernelFunction(F))) {
749       copyByValParam(F, Arg);
750       Changed = true;
751     }
752   return Changed;
753 }
754 
755 PreservedAnalyses NVPTXCopyByValArgsPass::run(Function &F,
756                                               FunctionAnalysisManager &AM) {
757   return copyFunctionByValArgs(F) ? PreservedAnalyses::none()
758                                   : PreservedAnalyses::all();
759 }
760