xref: /openbsd-src/gnu/llvm/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp (revision d415bd752c734aee168c4ee86ff32e8cc249eb16)
1 //===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass replaces all the uses of LDS within non-kernel functions by their
// corresponding pointer counterparts.
11 //
// The main motivation behind this pass is to *prevent* the subsequent LDS
// lowering pass from directly packing LDS (possibly large LDS) into a struct
// type, which would otherwise cause a huge struct instance to be allocated
// within every kernel.
16 //
17 // Brief sketch of the algorithm implemented in this pass is as below:
18 //
19 //   1. Collect all the LDS defined in the module which qualify for pointer
20 //      replacement, say it is, LDSGlobals set.
21 //
22 //   2. Collect all the reachable callees for each kernel defined in the module,
23 //      say it is, KernelToCallees map.
24 //
25 //   3. FOR (each global GV from LDSGlobals set) DO
26 //        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
27 //        FOR (each kernel K in KernelToCallees map) DO
28 //           ReachableCallees = KernelToCallees[K]
29 //           ReachableAndLDSUsedCallees =
30 //              SetIntersect(LDSUsedNonKernels, ReachableCallees)
31 //           IF (ReachableAndLDSUsedCallees is not empty) THEN
32 //             Pointer = Create a pointer to point-to GV if not created.
33 //             Initialize Pointer to point-to GV within kernel K.
34 //           ENDIF
35 //        ENDFOR
36 //        Replace all uses of GV within non kernel functions by Pointer.
//      ENDFOR
38 //
39 // LLVM IR example:
40 //
41 //    Input IR:
42 //
43 //    @lds = internal addrspace(3) global [4 x i32] undef, align 16
44 //
45 //    define internal void @f0() {
46 //    entry:
47 //      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
48 //             i32 0, i32 0
49 //      ret void
50 //    }
51 //
52 //    define protected amdgpu_kernel void @k0() {
53 //    entry:
54 //      call void @f0()
55 //      ret void
56 //    }
57 //
58 //    Output IR:
59 //
60 //    @lds = internal addrspace(3) global [4 x i32] undef, align 16
61 //    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
62 //
63 //    define internal void @f0() {
64 //    entry:
65 //      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
66 //      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
67 //      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
68 //      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
69 //             i32 0, i32 0
70 //      ret void
71 //    }
72 //
73 //    define protected amdgpu_kernel void @k0() {
74 //    entry:
75 //      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
76 //            i16 addrspace(3)* @lds.ptr, align 2
77 //      call void @f0()
78 //      ret void
79 //    }
80 //
81 //===----------------------------------------------------------------------===//
82 
83 #include "AMDGPU.h"
84 #include "GCNSubtarget.h"
85 #include "Utils/AMDGPUBaseInfo.h"
86 #include "Utils/AMDGPUMemoryUtils.h"
87 #include "llvm/ADT/DenseMap.h"
88 #include "llvm/ADT/STLExtras.h"
89 #include "llvm/ADT/SetOperations.h"
90 #include "llvm/Analysis/CallGraph.h"
91 #include "llvm/CodeGen/TargetPassConfig.h"
92 #include "llvm/IR/Constants.h"
93 #include "llvm/IR/DerivedTypes.h"
94 #include "llvm/IR/IRBuilder.h"
95 #include "llvm/IR/InlineAsm.h"
96 #include "llvm/IR/Instructions.h"
97 #include "llvm/IR/IntrinsicsAMDGPU.h"
98 #include "llvm/IR/ReplaceConstant.h"
99 #include "llvm/InitializePasses.h"
100 #include "llvm/Pass.h"
101 #include "llvm/Support/Debug.h"
102 #include "llvm/Target/TargetMachine.h"
103 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
104 #include "llvm/Transforms/Utils/ModuleUtils.h"
105 #include <algorithm>
106 #include <vector>
107 
108 #define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"
109 
110 using namespace llvm;
111 
112 namespace {
113 
114 namespace AMDGPU {
115 /// Collect all the instructions where user \p U belongs to. \p U could be
116 /// instruction itself or it could be a constant expression which is used within
117 /// an instruction. If \p CollectKernelInsts is true, collect instructions only
118 /// from kernels, otherwise collect instructions only from non-kernel functions.
119 DenseMap<Function *, SmallPtrSet<Instruction *, 8>>
120 getFunctionToInstsMap(User *U, bool CollectKernelInsts);
121 
122 SmallPtrSet<Function *, 8> collectNonKernelAccessorsOfLDS(GlobalVariable *GV);
123 
124 } // namespace AMDGPU
125 
126 class ReplaceLDSUseImpl {
127   Module &M;
128   LLVMContext &Ctx;
129   const DataLayout &DL;
130   Constant *LDSMemBaseAddr;
131 
132   DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
133   DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
134   DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
135   DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
136   DenseMap<Function *, BasicBlock *> KernelToInitBB;
137   DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
138       FunctionToLDSToReplaceInst;
139 
140   // Collect LDS which requires their uses to be replaced by pointer.
collectLDSRequiringPointerReplace()141   std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
142     // Collect LDS which requires module lowering.
143     std::vector<GlobalVariable *> LDSGlobals =
144         llvm::AMDGPU::findLDSVariablesToLower(M, nullptr);
145 
146     // Remove LDS which don't qualify for replacement.
147     llvm::erase_if(LDSGlobals, [&](GlobalVariable *GV) {
148       return shouldIgnorePointerReplacement(GV);
149     });
150 
151     return LDSGlobals;
152   }
153 
154   // Returns true if uses of given LDS global within non-kernel functions should
155   // be keep as it is without pointer replacement.
shouldIgnorePointerReplacement(GlobalVariable * GV)156   bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
157     // LDS whose size is very small and doesn't exceed pointer size is not worth
158     // replacing.
159     if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
160       return true;
161 
162     // LDS which is not used from non-kernel function scope or it is used from
163     // global scope does not qualify for replacement.
164     LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
165     return LDSToNonKernels[GV].empty();
166 
167     // FIXME: When GV is used within all (or within most of the kernels), then
168     // it does not make sense to create a pointer for it.
169   }
170 
171   // Insert new global LDS pointer which points to LDS.
createLDSPointer(GlobalVariable * GV)172   GlobalVariable *createLDSPointer(GlobalVariable *GV) {
173     // LDS pointer which points to LDS is already created? Return it.
174     auto PointerEntry = LDSToPointer.insert(std::pair(GV, nullptr));
175     if (!PointerEntry.second)
176       return PointerEntry.first->second;
177 
178     // We need to create new LDS pointer which points to LDS.
179     //
180     // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
181     // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
182     auto *I16Ty = Type::getInt16Ty(Ctx);
183     GlobalVariable *LDSPointer = new GlobalVariable(
184         M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
185         GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
186         AMDGPUAS::LOCAL_ADDRESS);
187 
188     LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
189     LDSPointer->setAlignment(llvm::AMDGPU::getAlign(DL, LDSPointer));
190 
191     // Mark that an associated LDS pointer is created for LDS.
192     LDSToPointer[GV] = LDSPointer;
193 
194     return LDSPointer;
195   }
196 
197   // Split entry basic block in such a way that only lane 0 of each wave does
198   // the LDS pointer initialization, and return newly created basic block.
activateLaneZero(Function * K)199   BasicBlock *activateLaneZero(Function *K) {
200     // If the entry basic block of kernel K is already split, then return
201     // newly created basic block.
202     auto BasicBlockEntry = KernelToInitBB.insert(std::pair(K, nullptr));
203     if (!BasicBlockEntry.second)
204       return BasicBlockEntry.first->second;
205 
206     // Split entry basic block of kernel K.
207     auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
208     IRBuilder<> Builder(EI);
209 
210     Value *Mbcnt =
211         Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
212                                 {Builder.getInt32(-1), Builder.getInt32(0)});
213     Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
214     Instruction *WB = cast<Instruction>(
215         Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));
216 
217     BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();
218 
219     // Mark that the entry basic block of kernel K is split.
220     KernelToInitBB[K] = NBB;
221 
222     return NBB;
223   }
224 
225   // Within given kernel, initialize given LDS pointer to point to given LDS.
initializeLDSPointer(Function * K,GlobalVariable * GV,GlobalVariable * LDSPointer)226   void initializeLDSPointer(Function *K, GlobalVariable *GV,
227                             GlobalVariable *LDSPointer) {
228     // If LDS pointer is already initialized within K, then nothing to do.
229     auto PointerEntry = KernelToLDSPointers.insert(
230         std::pair(K, SmallPtrSet<GlobalVariable *, 8>()));
231     if (!PointerEntry.second)
232       if (PointerEntry.first->second.contains(LDSPointer))
233         return;
234 
235     // Insert instructions at EI which initialize LDS pointer to point-to LDS
236     // within kernel K.
237     //
238     // That is, convert pointer type of GV to i16, and then store this converted
239     // i16 value within LDSPointer which is of type i16*.
240     auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
241     IRBuilder<> Builder(EI);
242     Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
243                         LDSPointer);
244 
245     // Mark that LDS pointer is initialized within kernel K.
246     KernelToLDSPointers[K].insert(LDSPointer);
247   }
248 
249   // We have created an LDS pointer for LDS, and initialized it to point-to LDS
250   // within all relevant kernels. Now replace all the uses of LDS within
251   // non-kernel functions by LDS pointer.
replaceLDSUseByPointer(GlobalVariable * GV,GlobalVariable * LDSPointer)252   void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
253     SmallVector<User *, 8> LDSUsers(GV->users());
254     for (auto *U : LDSUsers) {
255       // When `U` is a constant expression, it is possible that same constant
256       // expression exists within multiple instructions, and within multiple
257       // non-kernel functions. Collect all those non-kernel functions and all
258       // those instructions within which `U` exist.
259       auto FunctionToInsts =
260           AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);
261 
262       for (const auto &FunctionToInst : FunctionToInsts) {
263         Function *F = FunctionToInst.first;
264         auto &Insts = FunctionToInst.second;
265         for (auto *I : Insts) {
266           // If `U` is a constant expression, then we need to break the
267           // associated instruction into a set of separate instructions by
268           // converting constant expressions into instructions.
269           SmallPtrSet<Instruction *, 8> UserInsts;
270 
271           if (U == I) {
272             // `U` is an instruction, conversion from constant expression to
273             // set of instructions is *not* required.
274             UserInsts.insert(I);
275           } else {
276             // `U` is a constant expression, convert it into corresponding set
277             // of instructions.
278             auto *CE = cast<ConstantExpr>(U);
279             convertConstantExprsToInstructions(I, CE, &UserInsts);
280           }
281 
282           // Go through all the user instructions, if LDS exist within them as
283           // an operand, then replace it by replace instruction.
284           for (auto *II : UserInsts) {
285             auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
286             II->replaceUsesOfWith(GV, ReplaceInst);
287           }
288         }
289       }
290     }
291   }
292 
293   // Create a set of replacement instructions which together replace LDS within
294   // non-kernel function F by accessing LDS indirectly using LDS pointer.
getReplacementInst(Function * F,GlobalVariable * GV,GlobalVariable * LDSPointer)295   Value *getReplacementInst(Function *F, GlobalVariable *GV,
296                             GlobalVariable *LDSPointer) {
297     // If the instruction which replaces LDS within F is already created, then
298     // return it.
299     auto LDSEntry = FunctionToLDSToReplaceInst.insert(
300         std::pair(F, DenseMap<GlobalVariable *, Value *>()));
301     if (!LDSEntry.second) {
302       auto ReplaceInstEntry =
303           LDSEntry.first->second.insert(std::pair(GV, nullptr));
304       if (!ReplaceInstEntry.second)
305         return ReplaceInstEntry.first->second;
306     }
307 
308     // Get the instruction insertion point within the beginning of the entry
309     // block of current non-kernel function.
310     auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
311     IRBuilder<> Builder(EI);
312 
313     // Insert required set of instructions which replace LDS within F.
314     auto *V = Builder.CreateBitCast(
315         Builder.CreateGEP(
316             Builder.getInt8Ty(), LDSMemBaseAddr,
317             Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
318         GV->getType());
319 
320     // Mark that the replacement instruction which replace LDS within F is
321     // created.
322     FunctionToLDSToReplaceInst[F][GV] = V;
323 
324     return V;
325   }
326 
327 public:
ReplaceLDSUseImpl(Module & M)328   ReplaceLDSUseImpl(Module &M)
329       : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
330     LDSMemBaseAddr = Constant::getIntegerValue(
331         PointerType::get(Type::getInt8Ty(M.getContext()),
332                          AMDGPUAS::LOCAL_ADDRESS),
333         APInt(32, 0));
334   }
335 
336   // Entry-point function which interface ReplaceLDSUseImpl with outside of the
337   // class.
338   bool replaceLDSUse();
339 
340 private:
341   // For a given LDS from collected LDS globals set, replace its non-kernel
342   // function scope uses by pointer.
343   bool replaceLDSUse(GlobalVariable *GV);
344 };
345 
346 // For given LDS from collected LDS globals set, replace its non-kernel function
347 // scope uses by pointer.
replaceLDSUse(GlobalVariable * GV)348 bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
349   // Holds all those non-kernel functions within which LDS is being accessed.
350   SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];
351 
352   // The LDS pointer which points to LDS and replaces all the uses of LDS.
353   GlobalVariable *LDSPointer = nullptr;
354 
355   // Traverse through each kernel K, check and if required, initialize the
356   // LDS pointer to point to LDS within K.
357   for (const auto &KernelToCallee : KernelToCallees) {
358     Function *K = KernelToCallee.first;
359     SmallPtrSet<Function *, 8> Callees = KernelToCallee.second;
360 
361     // Compute reachable and LDS used callees for kernel K.
362     set_intersect(Callees, LDSAccessors);
363 
364     // None of the LDS accessing non-kernel functions are reachable from
365     // kernel K. Hence, no need to initialize LDS pointer within kernel K.
366     if (Callees.empty())
367       continue;
368 
369     // We have found reachable and LDS used callees for kernel K, and we need to
370     // initialize LDS pointer within kernel K, and we need to replace LDS use
371     // within those callees by LDS pointer.
372     //
373     // But, first check if LDS pointer is already created, if not create one.
374     LDSPointer = createLDSPointer(GV);
375 
376     // Initialize LDS pointer to point to LDS within kernel K.
377     initializeLDSPointer(K, GV, LDSPointer);
378   }
379 
380   // We have not found reachable and LDS used callees for any of the kernels,
381   // and hence we have not created LDS pointer.
382   if (!LDSPointer)
383     return false;
384 
385   // We have created an LDS pointer for LDS, and initialized it to point-to LDS
386   // within all relevant kernels. Now replace all the uses of LDS within
387   // non-kernel functions by LDS pointer.
388   replaceLDSUseByPointer(GV, LDSPointer);
389 
390   return true;
391 }
392 
393 namespace AMDGPU {
394 
395 // An helper class for collecting all reachable callees for each kernel defined
396 // within the module.
397 class CollectReachableCallees {
398   Module &M;
399   CallGraph CG;
400   SmallPtrSet<CallGraphNode *, 8> AddressTakenFunctions;
401 
402   // Collect all address taken functions within the module.
collectAddressTakenFunctions()403   void collectAddressTakenFunctions() {
404     auto *ECNode = CG.getExternalCallingNode();
405 
406     for (const auto &GI : *ECNode) {
407       auto *CGN = GI.second;
408       auto *F = CGN->getFunction();
409       if (!F || F->isDeclaration() || llvm::AMDGPU::isKernelCC(F))
410         continue;
411       AddressTakenFunctions.insert(CGN);
412     }
413   }
414 
415   // For given kernel, collect all its reachable non-kernel functions.
collectReachableCallees(Function * K)416   SmallPtrSet<Function *, 8> collectReachableCallees(Function *K) {
417     SmallPtrSet<Function *, 8> ReachableCallees;
418 
419     // Call graph node which represents this kernel.
420     auto *KCGN = CG[K];
421 
422     // Go through all call graph nodes reachable from the node representing this
423     // kernel, visit all their call sites, if the call site is direct, add
424     // corresponding callee to reachable callee set, if it is indirect, resolve
425     // the indirect call site to potential reachable callees, add them to
426     // reachable callee set, and repeat the process for the newly added
427     // potential callee nodes.
428     //
429     // FIXME: Need to handle bit-casted function pointers.
430     //
431     SmallVector<CallGraphNode *, 8> CGNStack(depth_first(KCGN));
432     SmallPtrSet<CallGraphNode *, 8> VisitedCGNodes;
433     while (!CGNStack.empty()) {
434       auto *CGN = CGNStack.pop_back_val();
435 
436       if (!VisitedCGNodes.insert(CGN).second)
437         continue;
438 
439       // Ignore call graph node which does not have associated function or
440       // associated function is not a definition.
441       if (!CGN->getFunction() || CGN->getFunction()->isDeclaration())
442         continue;
443 
444       for (const auto &GI : *CGN) {
445         auto *RCB = cast<CallBase>(*GI.first);
446         auto *RCGN = GI.second;
447 
448         if (auto *DCallee = RCGN->getFunction()) {
449           ReachableCallees.insert(DCallee);
450         } else if (RCB->isIndirectCall()) {
451           auto *RCBFTy = RCB->getFunctionType();
452           for (auto *ACGN : AddressTakenFunctions) {
453             auto *ACallee = ACGN->getFunction();
454             if (ACallee->getFunctionType() == RCBFTy) {
455               ReachableCallees.insert(ACallee);
456               CGNStack.append(df_begin(ACGN), df_end(ACGN));
457             }
458           }
459         }
460       }
461     }
462 
463     return ReachableCallees;
464   }
465 
466 public:
CollectReachableCallees(Module & M)467   explicit CollectReachableCallees(Module &M) : M(M), CG(CallGraph(M)) {
468     // Collect address taken functions.
469     collectAddressTakenFunctions();
470   }
471 
collectReachableCallees(DenseMap<Function *,SmallPtrSet<Function *,8>> & KernelToCallees)472   void collectReachableCallees(
473       DenseMap<Function *, SmallPtrSet<Function *, 8>> &KernelToCallees) {
474     // Collect reachable callee set for each kernel defined in the module.
475     for (Function &F : M.functions()) {
476       if (!llvm::AMDGPU::isKernelCC(&F))
477         continue;
478       Function *K = &F;
479       KernelToCallees[K] = collectReachableCallees(K);
480     }
481   }
482 };
483 
484 /// Collect reachable callees for each kernel defined in the module \p M and
485 /// return collected callees at \p KernelToCallees.
collectReachableCallees(Module & M,DenseMap<Function *,SmallPtrSet<Function *,8>> & KernelToCallees)486 void collectReachableCallees(
487     Module &M,
488     DenseMap<Function *, SmallPtrSet<Function *, 8>> &KernelToCallees) {
489   CollectReachableCallees CRC{M};
490   CRC.collectReachableCallees(KernelToCallees);
491 }
492 
493 /// For the given LDS global \p GV, visit all its users and collect all
494 /// non-kernel functions within which \p GV is used and return collected list of
495 /// such non-kernel functions.
collectNonKernelAccessorsOfLDS(GlobalVariable * GV)496 SmallPtrSet<Function *, 8> collectNonKernelAccessorsOfLDS(GlobalVariable *GV) {
497   SmallPtrSet<Function *, 8> LDSAccessors;
498   SmallVector<User *, 8> UserStack(GV->users());
499   SmallPtrSet<User *, 8> VisitedUsers;
500 
501   while (!UserStack.empty()) {
502     auto *U = UserStack.pop_back_val();
503 
504     // `U` is already visited? continue to next one.
505     if (!VisitedUsers.insert(U).second)
506       continue;
507 
508     // `U` is a global variable which is initialized with LDS. Ignore LDS.
509     if (isa<GlobalValue>(U))
510       return SmallPtrSet<Function *, 8>();
511 
512     // Recursively explore constant users.
513     if (isa<Constant>(U)) {
514       append_range(UserStack, U->users());
515       continue;
516     }
517 
518     // `U` should be an instruction, if it belongs to a non-kernel function F,
519     // then collect F.
520     Function *F = cast<Instruction>(U)->getFunction();
521     if (!llvm::AMDGPU::isKernelCC(F))
522       LDSAccessors.insert(F);
523   }
524 
525   return LDSAccessors;
526 }
527 
528 DenseMap<Function *, SmallPtrSet<Instruction *, 8>>
getFunctionToInstsMap(User * U,bool CollectKernelInsts)529 getFunctionToInstsMap(User *U, bool CollectKernelInsts) {
530   DenseMap<Function *, SmallPtrSet<Instruction *, 8>> FunctionToInsts;
531   SmallVector<User *, 8> UserStack;
532   SmallPtrSet<User *, 8> VisitedUsers;
533 
534   UserStack.push_back(U);
535 
536   while (!UserStack.empty()) {
537     auto *UU = UserStack.pop_back_val();
538 
539     if (!VisitedUsers.insert(UU).second)
540       continue;
541 
542     if (isa<GlobalValue>(UU))
543       continue;
544 
545     if (isa<Constant>(UU)) {
546       append_range(UserStack, UU->users());
547       continue;
548     }
549 
550     auto *I = cast<Instruction>(UU);
551     Function *F = I->getFunction();
552     if (CollectKernelInsts) {
553       if (!llvm::AMDGPU::isKernelCC(F)) {
554         continue;
555       }
556     } else {
557       if (llvm::AMDGPU::isKernelCC(F)) {
558         continue;
559       }
560     }
561 
562     FunctionToInsts.insert(std::pair(F, SmallPtrSet<Instruction *, 8>()));
563     FunctionToInsts[F].insert(I);
564   }
565 
566   return FunctionToInsts;
567 }
568 
569 } // namespace AMDGPU
570 
571 // Entry-point function which interface ReplaceLDSUseImpl with outside of the
572 // class.
replaceLDSUse()573 bool ReplaceLDSUseImpl::replaceLDSUse() {
574   // Collect LDS which requires their uses to be replaced by pointer.
575   std::vector<GlobalVariable *> LDSGlobals =
576       collectLDSRequiringPointerReplace();
577 
578   // No LDS to pointer-replace. Nothing to do.
579   if (LDSGlobals.empty())
580     return false;
581 
582   // Collect reachable callee set for each kernel defined in the module.
583   AMDGPU::collectReachableCallees(M, KernelToCallees);
584 
585   if (KernelToCallees.empty()) {
586     // Either module does not have any kernel definitions, or none of the kernel
587     // has a call to non-kernel functions, or we could not resolve any of the
588     // call sites to proper non-kernel functions, because of the situations like
589     // inline asm calls. Nothing to replace.
590     return false;
591   }
592 
593   // For every LDS from collected LDS globals set, replace its non-kernel
594   // function scope use by pointer.
595   bool Changed = false;
596   for (auto *GV : LDSGlobals)
597     Changed |= replaceLDSUse(GV);
598 
599   return Changed;
600 }
601 
602 class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
603 public:
604   static char ID;
605 
AMDGPUReplaceLDSUseWithPointer()606   AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
607     initializeAMDGPUReplaceLDSUseWithPointerPass(
608         *PassRegistry::getPassRegistry());
609   }
610 
611   bool runOnModule(Module &M) override;
612 
getAnalysisUsage(AnalysisUsage & AU) const613   void getAnalysisUsage(AnalysisUsage &AU) const override {
614     AU.addRequired<TargetPassConfig>();
615   }
616 };
617 
618 } // namespace
619 
620 char AMDGPUReplaceLDSUseWithPointer::ID = 0;
621 char &llvm::AMDGPUReplaceLDSUseWithPointerID =
622     AMDGPUReplaceLDSUseWithPointer::ID;
623 
624 INITIALIZE_PASS_BEGIN(
625     AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
626     "Replace within non-kernel function use of LDS with pointer",
627     false /*only look at the cfg*/, false /*analysis pass*/)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)628 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
629 INITIALIZE_PASS_END(
630     AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
631     "Replace within non-kernel function use of LDS with pointer",
632     false /*only look at the cfg*/, false /*analysis pass*/)
633 
634 bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
635   ReplaceLDSUseImpl LDSUseReplacer{M};
636   return LDSUseReplacer.replaceLDSUse();
637 }
638 
createAMDGPUReplaceLDSUseWithPointerPass()639 ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
640   return new AMDGPUReplaceLDSUseWithPointer();
641 }
642 
643 PreservedAnalyses
run(Module & M,ModuleAnalysisManager & AM)644 AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
645   ReplaceLDSUseImpl LDSUseReplacer{M};
646   LDSUseReplacer.replaceLDSUse();
647   return PreservedAnalyses::all();
648 }
649