//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store vectorizer, as that
// pass does not handle overlapping cases. In addition, this pass extends the
// widening to handle scalar sub-dword loads that are only naturally aligned
// rather than DWORD aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Function &F;
  const DataLayout &DL;
  const GCNSubtarget &ST;

  AssumptionCache *const AC;
  UniformityInfo &UA;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  AMDGPULateCodeGenPrepare(Function &F, const GCNSubtarget &ST,
                           AssumptionCache *AC, UniformityInfo &UA)
      : F(F), DL(F.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
  bool run();
  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
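    // i.e. the low two bits of the address are known to be zero.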
    KnownBits Known = computeKnownBits(V, DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module &Mod;
  const DataLayout &DL;
  const GCNSubtarget &ST;
  /// The scalar type to convert to.
  Type *const ConvertToScalar;
  /// The set of visited Instructions.
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value.
  ValueToValueMap ValMap;
  /// Map containing conversions from Optimal Type -> Original Type per BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and \p return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to a compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
  /// fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a problematic
  /// PHI node, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    const auto *TLI = ST.getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the convert-to scalar size, we can't
    // do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
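    // For example, on AMDGPU vectors of i8 are typically coerced (i8 is not a
    // legal type), while vectors of i32 are already legal and left unchanged.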
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module &Mod, const GCNSubtarget &ST)
      : Mod(Mod), DL(Mod.getDataLayout()), ST(ST),
        ConvertToScalar(Type::getInt32Ty(Mod.getContext())) {}
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::run() {
  // "Optimize" the virtual registers that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic
  // blocks will be scalarized and widened, with each scalar living in its own
  // register. To work around this, the optimization converts the vectors to
  // equivalent vectors of legal type (which are converted back before uses in
  // subsequent blocks), to pack the bits into fewer physical registers (used
  // in CopyToReg/CopyFromReg pairs).
  LiveRegOptimizer LRO(*F.getParent(), ST);

  bool Changed = false;

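  // Subtargets with scalar sub-dword loads do not need the sub-dword load
  // widening below, so only run the InstVisitor-based rewrites when such
  // loads are unavailable.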
  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

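// For example, with an i32 convert-to scalar, v3i8 (24 bits) converts to i32
// and v5i16 (80 bits) converts to <3 x i32>.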
Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL.getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL.getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod.getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod.getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL.getTypeSizeInBits(VTy);
  TypeSize NewSize = DL.getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector
  // of the desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

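  // Pad the original vector out to the expanded element count; the extra lanes
  // reference the implicit poison operand of the shuffle. For example, v3i8 is
  // widened to v4i8 (with a poison fourth lane) and then bitcast to i32.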
  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL.getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL.getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod.getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
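  // For example, a value held as v2i32 (64 bits) is bitcast to <8 x i8> and
  // then shuffled down to the original v7i8.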
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod.getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

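  // Walk the web of defs and uses reachable from I, collecting problematic PHI
  // nodes, the defs whose values cross basic blocks, and the uses that will
  // need to be rewritten.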
  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHI nodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!isa<PHINode>(II))
          Defs.insert(II);
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (Value *Val = ValMap.lookup(IncVal))
        NewPhi->addIncoming(Val, Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the PHI from the
      // ValMap and any connected user PHIs.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from the maps when unwinding
        // a previous PHI.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address spaces.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL.getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL.getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA.isUniform(&LI);
}

bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, DL);
  // If the base is not DWORD aligned, it is not safe to perform the following
  // transformation.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, the load's alignment can simply be promoted to
    // DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }

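  // Otherwise, widen to a DWORD-aligned i32 load of the enclosing DWORD and
  // shift/truncate the loaded value back down. For example, an i16 load at
  // Offset = 6 (Adjust = 2) becomes an i32 load at offset 4 followed by a
  // lshr of 16 bits and a trunc to i16.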
  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL.getTypeStoreSizeInBits(LI.getType());
  auto *IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
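  // The widened load reads extra bytes, so the original range metadata no
  // longer applies to the loaded value and must be dropped.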
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  Value *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt),
                      DL.typeSizeEqualsStoreSize(LI.getType()) ? IntNTy
                                                               : LI.getType()),
      LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

PreservedAnalyses
AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
  UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);

  bool Changed = AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();

  if (!Changed)
    return PreservedAnalyses::all();
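  // Only values are rewritten; the CFG is left intact.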
  PreservedAnalyses PA = PreservedAnalyses::none();
  PA.preserveSet<CFGAnalyses>();
  return PA;
}

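// Legacy pass manager wrapper around AMDGPULateCodeGenPrepare; the new pass
// manager entry point is AMDGPULateCodeGenPreparePass::run above.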
class AMDGPULateCodeGenPrepareLegacy : public FunctionPass {
public:
  static char ID;

  AMDGPULateCodeGenPrepareLegacy() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool runOnFunction(Function &F) override;
};

bool AMDGPULateCodeGenPrepareLegacy::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC =
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UniformityInfo &UI =
      getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  return AMDGPULateCodeGenPrepare(F, ST, &AC, UI).run();
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepareLegacy::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPrepareLegacyPass() {
  return new AMDGPULateCodeGenPrepareLegacy();
}