//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"
#include <numeric>

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store vectorizer, as that
// pass doesn't handle overlapping cases. In addition, this pass extends the
// widening to handle scalar sub-dword loads that are only naturally aligned
// rather than dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// The set of Instructions to be deleted
  SmallPtrSet<Instruction *, 4> DeadInstrs;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Per-BB map of conversions from the optimal type back to the original
  /// type.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to, given a problematic \p
  /// OriginalType. In some instances, we may widen the type
  /// (e.g. v2i8 -> i32).
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type.
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type
  /// \p ConvertType, stripping away the MSBs in cases where there was an
  /// imperfect fit (e.g. v2i32 -> v7i8).
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-BB values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a
  /// problematic PHI node, we coerce all incoming values in a single
  /// invocation.
  bool optimizeLiveType(Instruction *I);

  /// Remove all instructions that have become dead (i.e. all the re-typed
  /// PHIs).
  void removeDeadInstrs();

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the scalar size we convert to, we
    // can't do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
  if (ST.hasScalarSubwordLoads())
    return false;

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic
  // blocks will be scalarized and widened, with each scalar living in its own
  // register. To work around this, this optimization converts the vectors to
  // equivalent vectors of legal type (which are converted back before uses in
  // subsequent blocks), to pack the bits into fewer physical registers (used
  // in CopyToReg/CopyFromReg pairs).
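  // For example, a v4i8 value defined in one block and used in another would
  // otherwise be split into four scalars, each carried across the block
  // boundary in its own register; LiveRegOptimizer instead bitcasts the value
  // to a single i32 at the def and bitcasts it back to v4i8 at the first use
  // in each consuming block.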
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  for (auto &BB : F)
    for (Instruction &I : make_early_inc_range(BB)) {
      Changed |= visit(I);
      Changed |= LRO.optimizeLiveType(&I);
    }

  LRO.removeDeadInstrs();
  return Changed;
}

Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector
  // of the desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
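  // For example, a v7i8 value that was carried in a v2i32 is first bitcast to
  // v8i8 and then shuffled down to its low seven elements.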
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

bool LiveRegOptimizer::optimizeLiveType(Instruction *I) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal))
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    // If any incoming value is still missing its converted twin, the new PHI
    // is incomplete and is discarded; otherwise the original PHI is now dead.
    DeadInstrs.insert(MissingIncVal ? cast<Instruction>(ValMap[Phi]) : Phi);
  }
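  // The defs now have packed counterparts recorded in ValMap (with retyped
  // PHIs in place of the problematic ones); what remains is to rewrite the
  // collected cross-block uses to go through those counterparts.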
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op)) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          NewVal =
              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                 InsertPt, U->getParent());
          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

void LiveRegOptimizer::removeDeadInstrs() {
  // Remove instrs that have been marked dead after type-coercion.
  for (auto *I : DeadInstrs) {
    I->replaceAllUsesWith(PoisonValue::get(I->getType()));
    I->eraseFromParent();
  }
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address spaces.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it's not safe to perform the following
  // transform.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, the original alignment can simply be promoted to
    // DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }
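
  // Otherwise, widen the access: load the DWORD containing the value, then
  // shift right by the byte offset within that DWORD and truncate back to the
  // original width. For example, an i16 load at offset 2 from a DWORD-aligned
  // base becomes an aligned i32 load at offset 0 followed by a 16-bit lshr
  // and a trunc to i16.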
  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  RecursivelyDeleteTriviallyDeadInstructions(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}