//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
#include <numeric>

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store-vectorizer, as that
// pass does not handle overlapping cases. In addition, this pass extends the
// widening to handle scalar sub-dword loads that are only naturally aligned,
// not dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Per-BB cache of conversions from the optimal type back to the original type.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
  /// fit (e.g. v2i32 -> v7i8)
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a problematic
  /// PHI node, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
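  // For example, a <4 x i8> is problematic here because its i8 elements are
  // typically not legal on AMDGPU and the vector would otherwise be widened
  // per element, whereas its 32 bits pack exactly into a single i32.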
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // We can only bit-pack integer elements that are no wider than the
    // convert-to scalar type.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic blocks
  // will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
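  //
  // Illustrative sketch (IR shapes are approximate): a <4 x i8> defined in one
  // block and used in another, e.g.
  //   bb0:  %v = ...            ; <4 x i8>
  //   bb1:  use(%v)
  // becomes
  //   bb0:  %v = ...            ; <4 x i8>
  //         %v.bc = bitcast <4 x i8> %v to i32
  //   bb1:  %v.orig = bitcast i32 %v.bc to <4 x i8>
  //         use(%v.orig)
  // so only a single 32-bit value is live across the block boundary.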
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

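// Illustrative examples, assuming the default i32 convert-to scalar:
//   <3 x i8>  (24 bits) -> i32
//   <4 x i16> (64 bits) -> <2 x i32>
//   <7 x i8>  (56 bits) -> <2 x i32>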
Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

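// Illustrative example: a <7 x i8> (56 bits) cannot be bitcast directly to the
// 64-bit <2 x i32> chosen above, so it is first widened to <8 x i8> by a
// shufflevector whose extra lane reads from the implicit poison operand, and
// the result is then bitcast to <2 x i32>.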
Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector of
  // desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

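// Illustrative example (reversing the case above): a <7 x i8> that was held in
// a <2 x i32> is bitcast to <8 x i8> and then narrowed back to <7 x i8> by a
// shufflevector that selects the low 7 lanes.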
Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

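// High-level flow, summarizing the code below: walk the def-use web rooted at
// the given instruction to collect problematic PHI nodes, the defs whose values
// are live across a block boundary, and the uses in other blocks; coerce each
// def to the optimal type right after its definition; rebuild the PHI nodes in
// the coerced type; and finally convert back to the original type at the start
// of each using block, caching one conversion per (block, value) pair.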
bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHI nodes and any use that crosses a BB boundary.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal))
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    Instruction *DeadInst = Phi;
    if (MissingIncVal) {
      DeadInst = cast<Instruction>(ValMap[Phi]);
      // Do not use the dead phi
      ValMap[Phi] = Phi;
    }
    DeadInsts.emplace_back(DeadInst);
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op)) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          NewVal =
              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                 InsertPt, U->getParent());
          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

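// Illustrative example: a uniform i16 load from constant memory at
// (dword-aligned base) + 6 is rewritten as an align(4) i32 load at base + 4,
// followed by a lshr of 16 and a trunc back to i16. When the sub-dword offset
// is zero, only the alignment is promoted.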
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If that base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, we can simply promote the original alignment to
    // DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}