//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does miscellaneous AMDGPU optimizations on IR just before
/// instruction selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <iterator>
#include <numeric>

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store-vectorizer, as that
// pass does not handle overlapping cases. In addition, this pass extends the
// widening to handle scalar sub-dword loads that are only naturally aligned
// rather than dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

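// Visitor that performs the late IR preparation steps right before instruction
// selection: on subtargets without native scalar sub-dword loads, it widens
// uniform sub-dword loads from the constant address space (see visitLoadInst
// below), and its run() method also drives LiveRegOptimizer to coerce illegal
// vector types that are live across basic blocks.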
class AMDGPULateCodeGenPrepare
    : public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget &ST;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  AMDGPULateCodeGenPrepare(Module &M, const GCNSubtarget &ST,
                           AssumptionCache *AC, UniformityInfo *UA)
      : Mod(&M), DL(&M.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
  bool run(Function &F);
  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

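// Coerces illegal fixed vector types that are live across basic block
// boundaries (or feed PHI nodes) into equivalently sized legal values, so that
// SelectionDAG does not scalarize them into one virtual register per element
// when they are copied between blocks.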
class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to.
  Type *ConvertToScalar;
  /// The set of visited Instructions.
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value.
  ValueToValueMap ValMap;
  /// Map of conversions from the optimal type back to the original type, per
  /// BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 ->
  /// i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to a compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an
  /// imperfect fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a
  /// problematic PHI node, we coerce all incoming values in a single
  /// invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    const auto *TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the ConvertToScalar size, then we
    // cannot do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::run(Function &F) {
  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic blocks
  // will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
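      // Subtargets that support scalar sub-dword loads natively do not need
      // the load widening below, so the visitor is skipped for them.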
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

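// Examples (with i32 as the conversion scalar): a 24-bit v3i8 maps to a single
// i32, while an 80-bit v5i16 maps to v3i32, rounding the element count up so
// that all of the original bits fit.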
Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

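// Example: a v3i8 def (24 bits) is first widened to v4i8 with a shuffle whose
// extra lane indexes the poison second operand, and the result is then bitcast
// to the 32-bit convert type computed above.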
Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector of
  // the desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

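// Examples: an i32 carrying a v3i8 value is truncated to i24 and bitcast back
// to v3i8; a v3i32 carrying a v5i16 value is bitcast to v6i16 and then
// shuffled down to its first five lanes.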
Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

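// The routine works in two stages: a worklist walk that gathers the
// problematic defs, PHI nodes, and cross-block uses reachable from the given
// instruction, and a rewrite stage that coerces the defs, rebuilds the PHIs
// over the converted type, and converts values back to the original type in
// front of each use.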
bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHI nodes and any use that crosses a BB boundary.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!isa<PHINode>(II))
          Defs.insert(II);
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (Value *Val = ValMap.lookup(IncVal))
        NewPhi->addIncoming(Val, Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the Phi
      // from the ValMap and any connected / user Phis.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from maps when
        // unwinding a previous Phi.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

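// A load qualifies for widening when it is a simple, uniform (scalar) load of
// a non-aggregate sub-DWORD type from the constant address space and is at
// least naturally aligned.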
bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address spaces.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

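// Example: a naturally aligned i16 load at byte offset 2 from a DWORD-aligned
// base becomes an aligned i32 load of the enclosing DWORD, followed by a
// logical shift right by 16 and a truncation back to i16 (the shift amount
// relies on the target being little-endian).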
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it is not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjustment, the original alignment can simply be promoted to
    // DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto *IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

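// New pass manager entry point: fetches the subtarget and the required
// analyses from the analysis manager and delegates to the shared
// implementation.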
PreservedAnalyses
AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
  UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);

  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);

  bool Changed = Impl.run(F);

  PreservedAnalyses PA = PreservedAnalyses::none();
  if (!Changed)
    return PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}

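// Legacy pass manager wrapper around the implementation above.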
class AMDGPULateCodeGenPrepareLegacy : public FunctionPass {
public:
  static char ID;

  AMDGPULateCodeGenPrepareLegacy() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool runOnFunction(Function &F) override;
};

bool AMDGPULateCodeGenPrepareLegacy::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC =
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UniformityInfo &UI =
      getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);

  return Impl.run(F);
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepareLegacy::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPrepareLegacyPass() {
  return new AMDGPULateCodeGenPrepareLegacy();
}