//===-- AMDGPULateCodeGenPrepare.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store vectorizer, as that
// pass doesn't handle overlapping cases. In addition, this pass enhances the
// widening to handle cases where scalar sub-dword loads are only naturally
// aligned but not dword aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Map containing conversions from Optimal Type -> Original Type per BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to, given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 ->
  /// i32).
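  /// Types whose bit width is not a multiple of 32 are rounded up to the next
  /// multiple (e.g. v7i8, 56 bits, becomes v2i32).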
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an
  /// imperfect fit (e.g. v2i32 -> v7i8)
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-bb values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For problematic
  /// PHI nodes, we coerce all incoming values in a single invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code.
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element size is larger than the size of the scalar we convert
    // to, then we can't do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types.
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic
  // blocks will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
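  //
  // For example (illustrative), a <4 x i8> value that is live across a block
  // boundary is bitcast to i32 next to its definition and bitcast back to
  // <4 x i8> at the start of each block that uses it, so only a single 32-bit
  // value has to be copied between registers.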
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector
  // of desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
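  // For example, a v7i8 value (56 bits) that was widened to v2i32 (64 bits)
  // is bitcast to v8i8, and the shuffle below keeps only the first 7
  // elements.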
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
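  // If any incoming value was not converted (i.e. it is missing from ValMap),
  // the new PHI cannot be completed; the original PHI is kept and the
  // half-built replacement is marked dead instead.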
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal))
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    Instruction *DeadInst = Phi;
    if (MissingIncVal) {
      DeadInst = cast<Instruction>(ValMap[Phi]);
      // Do not use the dead phi.
      ValMap[Phi] = Phi;
    }
    DeadInsts.emplace_back(DeadInst);
  }
  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op)) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          NewVal =
              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                 InsertPt, U->getParent());
          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address space.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjust, the original alignment can be promoted to a better
    // one.
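    // For example, a naturally aligned i16 load at a multiple-of-4 offset
    // from a DWORD-aligned base is itself DWORD aligned.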
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}