//===-- AMDGPUCodeGenPrepare.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after load-store-vectorizer as that pass
// doesn't handle overlapping cases. In addition, this pass enhances the
// widening to handle cases where scalar sub-dword loads are naturally aligned
// only but not dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget &ST;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  AMDGPULateCodeGenPrepare(Module &M, const GCNSubtarget &ST,
                           AssumptionCache *AC, UniformityInfo *UA)
      : Mod(&M), DL(&M.getDataLayout()), ST(ST), AC(AC), UA(UA) {}
  bool run(Function &F);
  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to.
  Type *ConvertToScalar;
  /// The set of visited Instructions.
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value.
  ValueToValueMap ValMap;
  /// Map containing conversions from Optimal Type -> Original Type per BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 ->
  /// i32).
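  /// Illustrative mappings, assuming the i32 convert-to scalar set up in the
  /// constructor: a <3 x i8> (24 bits) maps to i32, while a <3 x i16>
  /// (48 bits) maps to <2 x i32>.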
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an
  /// imperfect fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For problematic
  /// PHI nodes, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    const auto *TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element type is not an integer or is wider than the convert-to
    // scalar type, we can't do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::run(Function &F) {
  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic
  // blocks will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
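  //
  // Illustrative (hypothetical) IR, assuming an illegal <4 x i8> value that
  // crosses a block boundary; the optimizer carries it as a single i32
  // instead of four scalarized registers:
  //
  //   bb0:
  //     %v = load <4 x i8>, ptr addrspace(1) %p
  //     %v.bc = bitcast <4 x i8> %v to i32         ; inserted in the def block
  //     br label %bb1
  //   bb1:
  //     %v.orig = bitcast i32 %v.bc to <4 x i8>    ; inserted in the use block
  //     %elt = extractelement <4 x i8> %v.orig, i64 0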
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector
  // of desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
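  // Illustrative case (mirroring the v2i32 -> v7i8 example above): a <7 x i8>
  // value carried as <2 x i32> is bitcast to <8 x i8> and then narrowed back
  // to <7 x i8> with a shuffle that keeps only the low 7 lanes.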
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHI nodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!isa<PHINode>(II))
          Defs.insert(II);
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
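  // (Sketch, assuming a <4 x i8> PHI: the replacement PHI has type i32, and
  // each incoming value is either the converted def looked up in ValMap or,
  // for a zeroinitializer, the zero constant of the converted type.)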
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (Value *Val = ValMap.lookup(IncVal))
        NewPhi->addIncoming(Val, Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the Phi
      // from the ValMap and any connected / user Phis.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from maps when
        // unwinding a previous Phi.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
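  // (Illustrative qualifying case: a uniform, naturally aligned i16 load from
  // a constant-address-space pointer.)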
  return UA->isUniform(&LI);
}

bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, we can simply promote the load's alignment to
    // DWORD.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto *IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

PreservedAnalyses
AMDGPULateCodeGenPreparePass::run(Function &F, FunctionAnalysisManager &FAM) {
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC = FAM.getResult<AssumptionAnalysis>(F);
  UniformityInfo &UI = FAM.getResult<UniformityInfoAnalysis>(F);

  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);

  bool Changed = Impl.run(F);

  PreservedAnalyses PA = PreservedAnalyses::none();
  if (!Changed)
    return PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}

class AMDGPULateCodeGenPrepareLegacy : public FunctionPass {
public:
  static char ID;

  AMDGPULateCodeGenPrepareLegacy() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool runOnFunction(Function &F) override;
};

bool AMDGPULateCodeGenPrepareLegacy::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AssumptionCache &AC =
      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UniformityInfo &UI =
      getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  AMDGPULateCodeGenPrepare Impl(*F.getParent(), ST, &AC, &UI);

  return Impl.run(F);
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepareLegacy, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepareLegacy::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPrepareLegacyPass() {
  return new AMDGPULateCodeGenPrepareLegacy();
}