xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUPromoteKernelArguments.cpp (revision 349cc55c9796c4596a5b9904cd3281af295f878f)
//===-- AMDGPUPromoteKernelArguments.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass recursively promotes generic pointer arguments of a kernel
/// into the global address space.
///
/// The pass walks a kernel's pointer arguments, then the loads from them. If a
/// loaded value is a pointer and the loaded pointer is unmodified in the kernel
/// before the load, the loaded pointer is promoted to global. The walk then
/// continues recursively through the promoted pointer.
//
//===----------------------------------------------------------------------===//
17*349cc55cSDimitry Andric 
18*349cc55cSDimitry Andric #include "AMDGPU.h"
19*349cc55cSDimitry Andric #include "llvm/ADT/SmallVector.h"
20*349cc55cSDimitry Andric #include "llvm/Analysis/MemorySSA.h"
21*349cc55cSDimitry Andric #include "llvm/IR/IRBuilder.h"
22*349cc55cSDimitry Andric #include "llvm/InitializePasses.h"
23*349cc55cSDimitry Andric 
24*349cc55cSDimitry Andric #define DEBUG_TYPE "amdgpu-promote-kernel-arguments"
25*349cc55cSDimitry Andric 
26*349cc55cSDimitry Andric using namespace llvm;
27*349cc55cSDimitry Andric 
namespace {

/// Legacy-PM pass that promotes flat (generic) pointer kernel arguments, and
/// pointers loaded through them, to the global address space by inserting
/// addrspacecast round-trips for InferAddressSpaces to exploit.
class AMDGPUPromoteKernelArguments : public FunctionPass {
  /// MemorySSA for the current function; used to prove a loaded pointer's
  /// memory is unmodified between function entry and the load.
  MemorySSA *MSSA;

  /// Insertion point for casts of the arguments themselves: the first
  /// entry-block position past the leading static allocas.
  Instruction *ArgCastInsertPt;

  /// Worklist of pointers still to be processed by promotePointer().
  SmallVector<Value *> Ptrs;

  /// Inspect users of \p Ptr and push promotable loaded pointers onto Ptrs.
  void enqueueUsers(Value *Ptr);

  /// Promote \p Ptr (if flat) via a global/flat addrspacecast round-trip.
  /// \returns true if the IR was changed.
  bool promotePointer(Value *Ptr);

public:
  static char ID;

  AMDGPUPromoteKernelArguments() : FunctionPass(ID) {}

  /// Shared implementation used by both the legacy and new pass managers.
  bool run(Function &F, MemorySSA &MSSA);

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemorySSAWrapperPass>();
    AU.setPreservesAll();
  }
};

} // end anonymous namespace
57*349cc55cSDimitry Andric 
58*349cc55cSDimitry Andric void AMDGPUPromoteKernelArguments::enqueueUsers(Value *Ptr) {
59*349cc55cSDimitry Andric   SmallVector<User *> PtrUsers(Ptr->users());
60*349cc55cSDimitry Andric 
61*349cc55cSDimitry Andric   while (!PtrUsers.empty()) {
62*349cc55cSDimitry Andric     Instruction *U = dyn_cast<Instruction>(PtrUsers.pop_back_val());
63*349cc55cSDimitry Andric     if (!U)
64*349cc55cSDimitry Andric       continue;
65*349cc55cSDimitry Andric 
66*349cc55cSDimitry Andric     switch (U->getOpcode()) {
67*349cc55cSDimitry Andric     default:
68*349cc55cSDimitry Andric       break;
69*349cc55cSDimitry Andric     case Instruction::Load: {
70*349cc55cSDimitry Andric       LoadInst *LD = cast<LoadInst>(U);
71*349cc55cSDimitry Andric       PointerType *PT = dyn_cast<PointerType>(LD->getType());
72*349cc55cSDimitry Andric       if (!PT ||
73*349cc55cSDimitry Andric           (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
74*349cc55cSDimitry Andric            PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
75*349cc55cSDimitry Andric            PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS) ||
76*349cc55cSDimitry Andric           LD->getPointerOperand()->stripInBoundsOffsets() != Ptr)
77*349cc55cSDimitry Andric         break;
78*349cc55cSDimitry Andric       const MemoryAccess *MA = MSSA->getWalker()->getClobberingMemoryAccess(LD);
79*349cc55cSDimitry Andric       // TODO: This load poprobably can be promoted to constant address space.
80*349cc55cSDimitry Andric       if (MSSA->isLiveOnEntryDef(MA))
81*349cc55cSDimitry Andric         Ptrs.push_back(LD);
82*349cc55cSDimitry Andric       break;
83*349cc55cSDimitry Andric     }
84*349cc55cSDimitry Andric     case Instruction::GetElementPtr:
85*349cc55cSDimitry Andric     case Instruction::AddrSpaceCast:
86*349cc55cSDimitry Andric     case Instruction::BitCast:
87*349cc55cSDimitry Andric       if (U->getOperand(0)->stripInBoundsOffsets() == Ptr)
88*349cc55cSDimitry Andric         PtrUsers.append(U->user_begin(), U->user_end());
89*349cc55cSDimitry Andric       break;
90*349cc55cSDimitry Andric     }
91*349cc55cSDimitry Andric   }
92*349cc55cSDimitry Andric }
93*349cc55cSDimitry Andric 
94*349cc55cSDimitry Andric bool AMDGPUPromoteKernelArguments::promotePointer(Value *Ptr) {
95*349cc55cSDimitry Andric   enqueueUsers(Ptr);
96*349cc55cSDimitry Andric 
97*349cc55cSDimitry Andric   PointerType *PT = cast<PointerType>(Ptr->getType());
98*349cc55cSDimitry Andric   if (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS)
99*349cc55cSDimitry Andric     return false;
100*349cc55cSDimitry Andric 
101*349cc55cSDimitry Andric   bool IsArg = isa<Argument>(Ptr);
102*349cc55cSDimitry Andric   IRBuilder<> B(IsArg ? ArgCastInsertPt
103*349cc55cSDimitry Andric                       : &*std::next(cast<Instruction>(Ptr)->getIterator()));
104*349cc55cSDimitry Andric 
105*349cc55cSDimitry Andric   // Cast pointer to global address space and back to flat and let
106*349cc55cSDimitry Andric   // Infer Address Spaces pass to do all necessary rewriting.
107*349cc55cSDimitry Andric   PointerType *NewPT =
108*349cc55cSDimitry Andric       PointerType::getWithSamePointeeType(PT, AMDGPUAS::GLOBAL_ADDRESS);
109*349cc55cSDimitry Andric   Value *Cast =
110*349cc55cSDimitry Andric       B.CreateAddrSpaceCast(Ptr, NewPT, Twine(Ptr->getName(), ".global"));
111*349cc55cSDimitry Andric   Value *CastBack =
112*349cc55cSDimitry Andric       B.CreateAddrSpaceCast(Cast, PT, Twine(Ptr->getName(), ".flat"));
113*349cc55cSDimitry Andric   Ptr->replaceUsesWithIf(CastBack,
114*349cc55cSDimitry Andric                          [Cast](Use &U) { return U.getUser() != Cast; });
115*349cc55cSDimitry Andric 
116*349cc55cSDimitry Andric   return true;
117*349cc55cSDimitry Andric }
118*349cc55cSDimitry Andric 
119*349cc55cSDimitry Andric // skip allocas
120*349cc55cSDimitry Andric static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
121*349cc55cSDimitry Andric   BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
122*349cc55cSDimitry Andric   for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
123*349cc55cSDimitry Andric     AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
124*349cc55cSDimitry Andric 
125*349cc55cSDimitry Andric     // If this is a dynamic alloca, the value may depend on the loaded kernargs,
126*349cc55cSDimitry Andric     // so loads will need to be inserted before it.
127*349cc55cSDimitry Andric     if (!AI || !AI->isStaticAlloca())
128*349cc55cSDimitry Andric       break;
129*349cc55cSDimitry Andric   }
130*349cc55cSDimitry Andric 
131*349cc55cSDimitry Andric   return InsPt;
132*349cc55cSDimitry Andric }
133*349cc55cSDimitry Andric 
134*349cc55cSDimitry Andric bool AMDGPUPromoteKernelArguments::run(Function &F, MemorySSA &MSSA) {
135*349cc55cSDimitry Andric   if (skipFunction(F))
136*349cc55cSDimitry Andric     return false;
137*349cc55cSDimitry Andric 
138*349cc55cSDimitry Andric   CallingConv::ID CC = F.getCallingConv();
139*349cc55cSDimitry Andric   if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
140*349cc55cSDimitry Andric     return false;
141*349cc55cSDimitry Andric 
142*349cc55cSDimitry Andric   ArgCastInsertPt = &*getInsertPt(*F.begin());
143*349cc55cSDimitry Andric   this->MSSA = &MSSA;
144*349cc55cSDimitry Andric 
145*349cc55cSDimitry Andric   for (Argument &Arg : F.args()) {
146*349cc55cSDimitry Andric     if (Arg.use_empty())
147*349cc55cSDimitry Andric       continue;
148*349cc55cSDimitry Andric 
149*349cc55cSDimitry Andric     PointerType *PT = dyn_cast<PointerType>(Arg.getType());
150*349cc55cSDimitry Andric     if (!PT || (PT->getAddressSpace() != AMDGPUAS::FLAT_ADDRESS &&
151*349cc55cSDimitry Andric                 PT->getAddressSpace() != AMDGPUAS::GLOBAL_ADDRESS &&
152*349cc55cSDimitry Andric                 PT->getAddressSpace() != AMDGPUAS::CONSTANT_ADDRESS))
153*349cc55cSDimitry Andric       continue;
154*349cc55cSDimitry Andric 
155*349cc55cSDimitry Andric     Ptrs.push_back(&Arg);
156*349cc55cSDimitry Andric   }
157*349cc55cSDimitry Andric 
158*349cc55cSDimitry Andric   bool Changed = false;
159*349cc55cSDimitry Andric   while (!Ptrs.empty()) {
160*349cc55cSDimitry Andric     Value *Ptr = Ptrs.pop_back_val();
161*349cc55cSDimitry Andric     Changed |= promotePointer(Ptr);
162*349cc55cSDimitry Andric   }
163*349cc55cSDimitry Andric 
164*349cc55cSDimitry Andric   return Changed;
165*349cc55cSDimitry Andric }
166*349cc55cSDimitry Andric 
167*349cc55cSDimitry Andric bool AMDGPUPromoteKernelArguments::runOnFunction(Function &F) {
168*349cc55cSDimitry Andric   MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
169*349cc55cSDimitry Andric   return run(F, MSSA);
170*349cc55cSDimitry Andric }
171*349cc55cSDimitry Andric 
// Legacy pass registration; declares the MemorySSA dependency so the pass
// manager schedules MemorySSAWrapperPass before this pass.
INITIALIZE_PASS_BEGIN(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                      "AMDGPU Promote Kernel Arguments", false, false)
INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
INITIALIZE_PASS_END(AMDGPUPromoteKernelArguments, DEBUG_TYPE,
                    "AMDGPU Promote Kernel Arguments", false, false)

char AMDGPUPromoteKernelArguments::ID = 0;

/// Factory for the legacy pass manager.
FunctionPass *llvm::createAMDGPUPromoteKernelArgumentsPass() {
  return new AMDGPUPromoteKernelArguments();
}
183*349cc55cSDimitry Andric 
184*349cc55cSDimitry Andric PreservedAnalyses
185*349cc55cSDimitry Andric AMDGPUPromoteKernelArgumentsPass::run(Function &F,
186*349cc55cSDimitry Andric                                       FunctionAnalysisManager &AM) {
187*349cc55cSDimitry Andric   MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
188*349cc55cSDimitry Andric   if (AMDGPUPromoteKernelArguments().run(F, MSSA)) {
189*349cc55cSDimitry Andric     PreservedAnalyses PA;
190*349cc55cSDimitry Andric     PA.preserveSet<CFGAnalyses>();
191*349cc55cSDimitry Andric     PA.preserve<MemorySSAAnalysis>();
192*349cc55cSDimitry Andric     return PA;
193*349cc55cSDimitry Andric   }
194*349cc55cSDimitry Andric   return PreservedAnalyses::all();
195*349cc55cSDimitry Andric }
196