//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening must run after the load-store-vectorizer, as that pass
// does not handle overlapping cases. In addition, this pass extends the
// widening to scalar sub-dword loads that are only naturally aligned rather
// than DWORD aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// The set of Instructions to be deleted
  SmallPtrSet<Instruction *, 4> DeadInstrs;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Per-BB map containing conversions from Optimal Type -> Original Type.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to, given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
  /// fit (e.g. v2i32 -> v7i8)
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a problematic
  /// PHI node, we coerce all incoming values in a single invocation.
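  ///
  /// As an illustrative sketch (simplified, not taken from a specific test), a
  /// loop-carried value of an illegal vector type such as
  ///
  ///   %val = phi <4 x i8> [ zeroinitializer, %entry ], [ %next, %loop ]
  ///
  /// is rebuilt as a PHI of the converted type, with incoming values coerced
  /// at their definitions:
  ///
  ///   %val.tc = phi i32 [ 0, %entry ], [ %next.bc, %loop ]
  ///
  /// Uses in other blocks are converted back to <4 x i8> as needed.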
  bool optimizeLiveType(Instruction *I);

  /// Remove all instructions that have become dead (i.e. all the re-typed PHIs)
  void removeDeadInstrs();

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code
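  // (For illustration: on subtargets where an element type such as i8 is not
  // legal, a <4 x i8> value qualifies for replacement, whereas <2 x i32> does
  // not. The authoritative answer comes from the TLI->getTypeConversion check
  // below.)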
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the convert-to-scalar size, we cannot
    // do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  if (ST.hasScalarSubwordLoads())
    return false;

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual registers that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic blocks
  // are scalarized and widened, with each scalar living in its own register.
  // To work around this, the optimization converts such vectors to equivalent
  // vectors of legal type (which are converted back before uses in subsequent
  // blocks), packing the bits into fewer physical registers (used in
  // CopyToReg/CopyFromReg pairs).
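  //
  // A minimal sketch of the idea (illustrative only): given
  //
  //   bb0:
  //     %v = load <4 x i8>, ptr %p
  //     br label %bb1
  //   bb1:
  //     %e = extractelement <4 x i8> %v, i64 0
  //
  // the <4 x i8> value that crosses the block boundary is rewritten as
  //
  //   bb0:
  //     %v = load <4 x i8>, ptr %p
  //     %v.bc = bitcast <4 x i8> %v to i32
  //     br label %bb1
  //   bb1:
  //     %v.cvt = bitcast i32 %v.bc to <4 x i8>
  //     %e = extractelement <4 x i8> %v.cvt, i64 0
  //
  // so only a single 32-bit value is live across the edge (the exact value
  // names here are illustrative).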
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  for (auto &BB : F)
    for (Instruction &I : make_early_inc_range(BB)) {
      Changed |= visit(I);
      Changed |= LRO.optimizeLiveType(&I);
    }

  LRO.removeDeadInstrs();
  return Changed;
}

Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}
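
// For intuition, a few illustrative mappings produced by the routine above
// (given the 32-bit ConvertToScalar set up in the constructor):
//   <2 x i8>  (16 bits) -> i32
//   <4 x i8>  (32 bits) -> i32
//   <3 x i16> (48 bits) -> <2 x i32>
//   <8 x i8>  (64 bits) -> <2 x i32>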

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector of
  // desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}
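
// An illustrative example of the widening path above: a <3 x i8> value
// (24 bits) converts to the 32-bit scalar type, so it is first widened with a
// shufflevector and then bitcast. Roughly:
//
//   %widened = shufflevector <3 x i8> %v, <3 x i8> poison,
//                            <4 x i32> <i32 0, i32 1, i32 2, i32 3>
//   %v.bc = bitcast <4 x i8> %widened to i32
//
// Index 3 selects a lane of the poison operand, so the extra byte is
// undefined.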

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}
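
// Illustrative sketches of the two mismatch cases handled above (assuming the
// i32 convert-to scalar). Converting an i32 back to <3 x i8> goes through a
// 24-bit truncation:
//
//   %t = trunc i32 %packed to i24
//   %v = bitcast i24 %t to <3 x i8>
//
// Converting a <2 x i32> back to <7 x i8> re-expands the vector and drops the
// top lane:
//
//   %e = bitcast <2 x i32> %packed to <8 x i8>
//   %v = shufflevector <8 x i8> %e, <8 x i8> poison,
//                      <7 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6>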

bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal))
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    DeadInstrs.insert(MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op)) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          NewVal =
              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                 InsertPt, U->getParent());
          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

void LiveRegOptimizer::removeDeadInstrs() {
  // Remove instrs that have been marked dead after type-coercion.
  for (auto *I : DeadInstrs) {
    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
    I->eraseFromParent();
  }
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

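// An illustrative sketch of the rewrite performed below (not taken from a
// specific test): a uniform i16 load from constant memory at a DWORD-aligned
// base plus offset 2, e.g.
//
//   %v = load i16, ptr addrspace(4) %gep, align 2
//
// becomes a DWORD-aligned i32 load of the containing word followed by a shift
// and a truncate:
//
//   %w = load i32, ptr addrspace(4) %base, align 4
//   %s = lshr i32 %w, 16
//   %v = trunc i32 %s to i16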
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it is not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjustment, the original load can simply be promoted to the
    // better DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  RecursivelyDeleteTriviallyDeadInstructions(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}