1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPU.h" 16 #include "llvm/CodeGen/RegisterPressure.h" 17 18 using namespace llvm; 19 20 #define DEBUG_TYPE "machine-scheduler" 21 22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 23 const GCNRPTracker::LiveRegSet &S2) { 24 if (S1.size() != S2.size()) 25 return false; 26 27 for (const auto &P : S1) { 28 auto I = S2.find(P.first); 29 if (I == S2.end() || I->second != P.second) 30 return false; 31 } 32 return true; 33 } 34 35 /////////////////////////////////////////////////////////////////////////////// 36 // GCNRegPressure 37 38 unsigned GCNRegPressure::getRegKind(Register Reg, 39 const MachineRegisterInfo &MRI) { 40 assert(Reg.isVirtual()); 41 const auto *const RC = MRI.getRegClass(Reg); 42 const auto *STI = 43 static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo()); 44 return STI->isSGPRClass(RC) 45 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) 46 : STI->isAGPRClass(RC) 47 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) 48 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 49 } 50 51 void GCNRegPressure::inc(unsigned Reg, 52 LaneBitmask PrevMask, 53 LaneBitmask NewMask, 54 const MachineRegisterInfo &MRI) { 55 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 56 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 57 return; 58 59 int Sign = 1; 60 if (NewMask < PrevMask) { 61 std::swap(NewMask, PrevMask); 62 Sign = -1; 63 } 64 65 switch (auto Kind = getRegKind(Reg, MRI)) { 66 case SGPR32: 67 case VGPR32: 68 case AGPR32: 69 Value[Kind] += Sign; 70 break; 71 72 case SGPR_TUPLE: 73 case VGPR_TUPLE: 74 case AGPR_TUPLE: 75 assert(PrevMask < NewMask); 76 77 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 78 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 79 80 if (PrevMask.none()) { 81 assert(NewMask.any()); 82 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 83 Value[Kind] += 84 Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight; 85 } 86 break; 87 88 default: llvm_unreachable("Unknown register kind"); 89 } 90 } 91 92 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, 93 unsigned MaxOccupancy) const { 94 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 95 96 const auto SGPROcc = std::min(MaxOccupancy, 97 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 98 const auto VGPROcc = 99 std::min(MaxOccupancy, 100 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 101 const auto OtherSGPROcc = std::min(MaxOccupancy, 102 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 103 const auto OtherVGPROcc = 104 std::min(MaxOccupancy, 105 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()))); 106 107 const auto Occ = std::min(SGPROcc, VGPROcc); 108 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 109 110 // Give first precedence to the better occupancy. 111 if (Occ != OtherOcc) 112 return Occ > OtherOcc; 113 114 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); 115 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); 116 117 // SGPR excess pressure conditions 118 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0); 119 unsigned OtherExcessSGPR = 120 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0); 121 122 auto WaveSize = ST.getWavefrontSize(); 123 // The number of virtual VGPRs required to handle excess SGPR 124 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; 125 unsigned OtherVGPRForSGPRSpills = 126 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; 127 128 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); 129 130 // Unified excess pressure conditions, accounting for VGPRs used for SGPR 131 // spills 132 unsigned ExcessVGPR = 133 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) + 134 VGPRForSGPRSpills - MaxVGPRs), 135 0); 136 unsigned OtherExcessVGPR = 137 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) + 138 OtherVGPRForSGPRSpills - MaxVGPRs), 139 0); 140 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR 141 // spills 142 unsigned ExcessArchVGPR = std::max( 143 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs), 144 0); 145 unsigned OtherExcessArchVGPR = 146 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills - 147 MaxArchVGPRs), 148 0); 149 // AGPR excess pressure conditions 150 unsigned ExcessAGPR = std::max( 151 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) 152 : (getAGPRNum() - MaxVGPRs)), 153 0); 154 unsigned OtherExcessAGPR = std::max( 155 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) 156 : (O.getAGPRNum() - MaxVGPRs)), 157 0); 158 159 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; 160 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || 161 OtherExcessArchVGPR || OtherExcessAGPR; 162 163 // Give second precedence to the reduced number of spills to hold the register 164 // pressure. 165 if (ExcessRP || OtherExcessRP) { 166 // The difference in excess VGPR pressure, after including VGPRs used for 167 // SGPR spills 168 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - 169 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); 170 171 int SGPRDiff = OtherExcessSGPR - ExcessSGPR; 172 173 if (VGPRDiff != 0) 174 return VGPRDiff > 0; 175 if (SGPRDiff != 0) { 176 unsigned PureExcessVGPR = 177 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 178 0) + 179 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0); 180 unsigned OtherPureExcessVGPR = 181 std::max( 182 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 183 0) + 184 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0); 185 186 // If we have a special case where there is a tie in excess VGPR, but one 187 // of the pressures has VGPR usage from SGPR spills, prefer the pressure 188 // with SGPR spills. 189 if (PureExcessVGPR != OtherPureExcessVGPR) 190 return SGPRDiff < 0; 191 // If both pressures have the same excess pressure before and after 192 // accounting for SGPR spills, prefer fewer SGPR spills. 193 return SGPRDiff > 0; 194 } 195 } 196 197 bool SGPRImportant = SGPROcc < VGPROcc; 198 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 199 200 // If both pressures disagree on what is more important compare vgprs. 201 if (SGPRImportant != OtherSGPRImportant) { 202 SGPRImportant = false; 203 } 204 205 // Give third precedence to lower register tuple pressure. 206 bool SGPRFirst = SGPRImportant; 207 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 208 if (SGPRFirst) { 209 auto SW = getSGPRTuplesWeight(); 210 auto OtherSW = O.getSGPRTuplesWeight(); 211 if (SW != OtherSW) 212 return SW < OtherSW; 213 } else { 214 auto VW = getVGPRTuplesWeight(); 215 auto OtherVW = O.getVGPRTuplesWeight(); 216 if (VW != OtherVW) 217 return VW < OtherVW; 218 } 219 } 220 221 // Give final precedence to lower general RP. 222 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 223 (getVGPRNum(ST.hasGFX90AInsts()) < 224 O.getVGPRNum(ST.hasGFX90AInsts())); 225 } 226 227 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) { 228 return Printable([&RP, ST](raw_ostream &OS) { 229 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' ' 230 << "AGPRs: " << RP.getAGPRNum(); 231 if (ST) 232 OS << "(O" 233 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts())) 234 << ')'; 235 OS << ", SGPRs: " << RP.getSGPRNum(); 236 if (ST) 237 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')'; 238 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() 239 << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); 240 if (ST) 241 OS << " -> Occ: " << RP.getOccupancy(*ST); 242 OS << '\n'; 243 }); 244 } 245 246 static LaneBitmask getDefRegMask(const MachineOperand &MO, 247 const MachineRegisterInfo &MRI) { 248 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 249 250 // We don't rely on read-undef flag because in case of tentative schedule 251 // tracking it isn't set correctly yet. This works correctly however since 252 // use mask has been tracked before using LIS. 253 return MO.getSubReg() == 0 ? 254 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 255 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 256 } 257 258 static void 259 collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits, 260 const MachineInstr &MI, const LiveIntervals &LIS, 261 const MachineRegisterInfo &MRI) { 262 263 auto &TRI = *MRI.getTargetRegisterInfo(); 264 for (const auto &MO : MI.operands()) { 265 if (!MO.isReg() || !MO.getReg().isVirtual()) 266 continue; 267 if (!MO.isUse() || !MO.readsReg()) 268 continue; 269 270 Register Reg = MO.getReg(); 271 auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) { 272 return RM.RegUnit == Reg; 273 }); 274 275 auto &P = I == VRegMaskOrUnits.end() 276 ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone()) 277 : *I; 278 279 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg()) 280 : MRI.getMaxLaneMaskForVReg(Reg); 281 } 282 283 SlotIndex InstrSI; 284 for (auto &P : VRegMaskOrUnits) { 285 auto &LI = LIS.getInterval(P.RegUnit); 286 if (!LI.hasSubRanges()) 287 continue; 288 289 // For a tentative schedule LIS isn't updated yet but livemask should 290 // remain the same on any schedule. Subreg defs can be reordered but they 291 // all must dominate uses anyway. 292 if (!InstrSI) 293 InstrSI = LIS.getInstructionIndex(MI).getBaseIndex(); 294 295 P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask); 296 } 297 } 298 299 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 300 static LaneBitmask getLanesWithProperty( 301 const LiveIntervals &LIS, const MachineRegisterInfo &MRI, 302 bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, 303 LaneBitmask SafeDefault, 304 function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) { 305 if (RegUnit.isVirtual()) { 306 const LiveInterval &LI = LIS.getInterval(RegUnit); 307 LaneBitmask Result; 308 if (TrackLaneMasks && LI.hasSubRanges()) { 309 for (const LiveInterval::SubRange &SR : LI.subranges()) { 310 if (Property(SR, Pos)) 311 Result |= SR.LaneMask; 312 } 313 } else if (Property(LI, Pos)) { 314 Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit) 315 : LaneBitmask::getAll(); 316 } 317 318 return Result; 319 } 320 321 const LiveRange *LR = LIS.getCachedRegUnit(RegUnit); 322 if (LR == nullptr) 323 return SafeDefault; 324 return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone(); 325 } 326 327 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 328 /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}. 329 /// The query starts with a lane bitmask which gets lanes/bits removed for every 330 /// use we find. 331 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, 332 SlotIndex PriorUseIdx, SlotIndex NextUseIdx, 333 const MachineRegisterInfo &MRI, 334 const SIRegisterInfo *TRI, 335 const LiveIntervals *LIS, 336 bool Upward = false) { 337 for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { 338 if (MO.isUndef()) 339 continue; 340 const MachineInstr *MI = MO.getParent(); 341 SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot(); 342 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) 343 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); 344 if (!InRange) 345 continue; 346 347 unsigned SubRegIdx = MO.getSubReg(); 348 LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx); 349 LastUseMask &= ~UseMask; 350 if (LastUseMask.none()) 351 return LaneBitmask::getNone(); 352 } 353 return LastUseMask; 354 } 355 356 /////////////////////////////////////////////////////////////////////////////// 357 // GCNRPTracker 358 359 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, 360 const LiveIntervals &LIS, 361 const MachineRegisterInfo &MRI, 362 LaneBitmask LaneMaskFilter) { 363 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); 364 } 365 366 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, 367 const MachineRegisterInfo &MRI, 368 LaneBitmask LaneMaskFilter) { 369 LaneBitmask LiveMask; 370 if (LI.hasSubRanges()) { 371 for (const auto &S : LI.subranges()) 372 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) { 373 LiveMask |= S.LaneMask; 374 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); 375 } 376 } else if (LI.liveAt(SI)) { 377 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg()); 378 } 379 LiveMask &= LaneMaskFilter; 380 return LiveMask; 381 } 382 383 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 384 const LiveIntervals &LIS, 385 const MachineRegisterInfo &MRI) { 386 GCNRPTracker::LiveRegSet LiveRegs; 387 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 388 auto Reg = Register::index2VirtReg(I); 389 if (!LIS.hasInterval(Reg)) 390 continue; 391 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 392 if (LiveMask.any()) 393 LiveRegs[Reg] = LiveMask; 394 } 395 return LiveRegs; 396 } 397 398 void GCNRPTracker::reset(const MachineInstr &MI, 399 const LiveRegSet *LiveRegsCopy, 400 bool After) { 401 const MachineFunction &MF = *MI.getMF(); 402 MRI = &MF.getRegInfo(); 403 if (LiveRegsCopy) { 404 if (&LiveRegs != LiveRegsCopy) 405 LiveRegs = *LiveRegsCopy; 406 } else { 407 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 408 : getLiveRegsBefore(MI, LIS); 409 } 410 411 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 412 } 413 414 void GCNRPTracker::reset(const MachineRegisterInfo &MRI_, 415 const LiveRegSet &LiveRegs_) { 416 MRI = &MRI_; 417 LiveRegs = LiveRegs_; 418 LastTrackedMI = nullptr; 419 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_); 420 } 421 422 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp 423 LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit, 424 SlotIndex Pos) const { 425 return getLanesWithProperty( 426 LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(), 427 [](const LiveRange &LR, SlotIndex Pos) { 428 const LiveRange::Segment *S = LR.getSegmentContaining(Pos); 429 return S != nullptr && S->end == Pos.getRegSlot(); 430 }); 431 } 432 433 //////////////////////////////////////////////////////////////////////////////// 434 // GCNUpwardRPTracker 435 436 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 437 assert(MRI && "call reset first"); 438 439 LastTrackedMI = &MI; 440 441 if (MI.isDebugInstr()) 442 return; 443 444 // Kill all defs. 445 GCNRegPressure DefPressure, ECDefPressure; 446 bool HasECDefs = false; 447 for (const MachineOperand &MO : MI.all_defs()) { 448 if (!MO.getReg().isVirtual()) 449 continue; 450 451 Register Reg = MO.getReg(); 452 LaneBitmask DefMask = getDefRegMask(MO, *MRI); 453 454 // Treat a def as fully live at the moment of definition: keep a record. 455 if (MO.isEarlyClobber()) { 456 ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 457 HasECDefs = true; 458 } else 459 DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 460 461 auto I = LiveRegs.find(Reg); 462 if (I == LiveRegs.end()) 463 continue; 464 465 LaneBitmask &LiveMask = I->second; 466 LaneBitmask PrevMask = LiveMask; 467 LiveMask &= ~DefMask; 468 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 469 if (LiveMask.none()) 470 LiveRegs.erase(I); 471 } 472 473 // Update MaxPressure with defs pressure. 474 DefPressure += CurPressure; 475 if (HasECDefs) 476 DefPressure += ECDefPressure; 477 MaxPressure = max(DefPressure, MaxPressure); 478 479 // Make uses alive. 480 SmallVector<VRegMaskOrUnit, 8> RegUses; 481 collectVirtualRegUses(RegUses, MI, LIS, *MRI); 482 for (const VRegMaskOrUnit &U : RegUses) { 483 LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; 484 LaneBitmask PrevMask = LiveMask; 485 LiveMask |= U.LaneMask; 486 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 487 } 488 489 // Update MaxPressure with uses plus early-clobber defs pressure. 490 MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure) 491 : max(CurPressure, MaxPressure); 492 493 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 494 } 495 496 //////////////////////////////////////////////////////////////////////////////// 497 // GCNDownwardRPTracker 498 499 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 500 const LiveRegSet *LiveRegsCopy) { 501 MRI = &MI.getParent()->getParent()->getRegInfo(); 502 LastTrackedMI = nullptr; 503 MBBEnd = MI.getParent()->end(); 504 NextMI = &MI; 505 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 506 if (NextMI == MBBEnd) 507 return false; 508 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 509 return true; 510 } 511 512 bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI, 513 bool UseInternalIterator) { 514 assert(MRI && "call reset first"); 515 SlotIndex SI; 516 const MachineInstr *CurrMI; 517 if (UseInternalIterator) { 518 if (!LastTrackedMI) 519 return NextMI == MBBEnd; 520 521 assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); 522 CurrMI = LastTrackedMI; 523 524 SI = NextMI == MBBEnd 525 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() 526 : LIS.getInstructionIndex(*NextMI).getBaseIndex(); 527 } else { //! UseInternalIterator 528 SI = LIS.getInstructionIndex(*MI).getBaseIndex(); 529 CurrMI = MI; 530 } 531 532 assert(SI.isValid()); 533 534 // Remove dead registers or mask bits. 535 SmallSet<Register, 8> SeenRegs; 536 for (auto &MO : CurrMI->operands()) { 537 if (!MO.isReg() || !MO.getReg().isVirtual()) 538 continue; 539 if (MO.isUse() && !MO.readsReg()) 540 continue; 541 if (!UseInternalIterator && MO.isDef()) 542 continue; 543 if (!SeenRegs.insert(MO.getReg()).second) 544 continue; 545 const LiveInterval &LI = LIS.getInterval(MO.getReg()); 546 if (LI.hasSubRanges()) { 547 auto It = LiveRegs.end(); 548 for (const auto &S : LI.subranges()) { 549 if (!S.liveAt(SI)) { 550 if (It == LiveRegs.end()) { 551 It = LiveRegs.find(MO.getReg()); 552 if (It == LiveRegs.end()) 553 llvm_unreachable("register isn't live"); 554 } 555 auto PrevMask = It->second; 556 It->second &= ~S.LaneMask; 557 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI); 558 } 559 } 560 if (It != LiveRegs.end() && It->second.none()) 561 LiveRegs.erase(It); 562 } else if (!LI.liveAt(SI)) { 563 auto It = LiveRegs.find(MO.getReg()); 564 if (It == LiveRegs.end()) 565 llvm_unreachable("register isn't live"); 566 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI); 567 LiveRegs.erase(It); 568 } 569 } 570 571 MaxPressure = max(MaxPressure, CurPressure); 572 573 LastTrackedMI = nullptr; 574 575 return UseInternalIterator && (NextMI == MBBEnd); 576 } 577 578 void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, 579 bool UseInternalIterator) { 580 if (UseInternalIterator) { 581 LastTrackedMI = &*NextMI++; 582 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 583 } else { 584 LastTrackedMI = MI; 585 } 586 587 const MachineInstr *CurrMI = LastTrackedMI; 588 589 // Add new registers or mask bits. 590 for (const auto &MO : CurrMI->all_defs()) { 591 Register Reg = MO.getReg(); 592 if (!Reg.isVirtual()) 593 continue; 594 auto &LiveMask = LiveRegs[Reg]; 595 auto PrevMask = LiveMask; 596 LiveMask |= getDefRegMask(MO, *MRI); 597 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 598 } 599 600 MaxPressure = max(MaxPressure, CurPressure); 601 } 602 603 bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) { 604 if (UseInternalIterator && NextMI == MBBEnd) 605 return false; 606 607 advanceBeforeNext(MI, UseInternalIterator); 608 advanceToNext(MI, UseInternalIterator); 609 if (!UseInternalIterator) { 610 // We must remove any dead def lanes from the current RP 611 advanceBeforeNext(MI, true); 612 } 613 return true; 614 } 615 616 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 617 while (NextMI != End) 618 if (!advance()) return false; 619 return true; 620 } 621 622 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 623 MachineBasicBlock::const_iterator End, 624 const LiveRegSet *LiveRegsCopy) { 625 reset(*Begin, LiveRegsCopy); 626 return advance(End); 627 } 628 629 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 630 const GCNRPTracker::LiveRegSet &TrackedLR, 631 const TargetRegisterInfo *TRI, StringRef Pfx) { 632 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { 633 for (auto const &P : TrackedLR) { 634 auto I = LISLR.find(P.first); 635 if (I == LISLR.end()) { 636 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 637 << " isn't found in LIS reported set\n"; 638 } else if (I->second != P.second) { 639 OS << Pfx << printReg(P.first, TRI) 640 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second) 641 << ", tracked " << PrintLaneMask(P.second) << '\n'; 642 } 643 } 644 for (auto const &P : LISLR) { 645 auto I = TrackedLR.find(P.first); 646 if (I == TrackedLR.end()) { 647 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 648 << " isn't found in tracked set\n"; 649 } 650 } 651 }); 652 } 653 654 GCNRegPressure 655 GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, 656 const SIRegisterInfo *TRI) const { 657 assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction."); 658 659 SlotIndex SlotIdx; 660 SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot(); 661 662 // Account for register pressure similar to RegPressureTracker::recede(). 663 RegisterOperands RegOpers; 664 RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false); 665 RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx); 666 GCNRegPressure TempPressure = CurPressure; 667 668 for (const VRegMaskOrUnit &Use : RegOpers.Uses) { 669 Register Reg = Use.RegUnit; 670 if (!Reg.isVirtual()) 671 continue; 672 LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx); 673 if (LastUseMask.none()) 674 continue; 675 // The LastUseMask is queried from the liveness information of instruction 676 // which may be further down the schedule. Some lanes may actually not be 677 // last uses for the current position. 678 // FIXME: allow the caller to pass in the list of vreg uses that remain 679 // to be bottom-scheduled to avoid searching uses at each query. 680 SlotIndex CurrIdx; 681 const MachineBasicBlock *MBB = MI->getParent(); 682 MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( 683 LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end()); 684 if (IdxPos == MBB->end()) { 685 CurrIdx = LIS.getMBBEndIdx(MBB); 686 } else { 687 CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot(); 688 } 689 690 LastUseMask = 691 findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS); 692 if (LastUseMask.none()) 693 continue; 694 695 LaneBitmask LiveMask = 696 LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0); 697 LaneBitmask NewMask = LiveMask & ~LastUseMask; 698 TempPressure.inc(Reg, LiveMask, NewMask, *MRI); 699 } 700 701 // Generate liveness for defs. 702 for (const VRegMaskOrUnit &Def : RegOpers.Defs) { 703 Register Reg = Def.RegUnit; 704 if (!Reg.isVirtual()) 705 continue; 706 LaneBitmask LiveMask = 707 LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0); 708 LaneBitmask NewMask = LiveMask | Def.LaneMask; 709 TempPressure.inc(Reg, LiveMask, NewMask, *MRI); 710 } 711 712 return TempPressure; 713 } 714 715 bool GCNUpwardRPTracker::isValid() const { 716 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 717 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 718 const auto &TrackedLR = LiveRegs; 719 720 if (!isEqual(LISLR, TrackedLR)) { 721 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 722 " LIS reported livesets mismatch:\n" 723 << print(LISLR, *MRI); 724 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 725 return false; 726 } 727 728 auto LISPressure = getRegPressure(*MRI, LISLR); 729 if (LISPressure != CurPressure) { 730 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " 731 << print(CurPressure) << "LIS rpt: " << print(LISPressure); 732 return false; 733 } 734 return true; 735 } 736 737 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, 738 const MachineRegisterInfo &MRI) { 739 return Printable([&LiveRegs, &MRI](raw_ostream &OS) { 740 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 741 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 742 Register Reg = Register::index2VirtReg(I); 743 auto It = LiveRegs.find(Reg); 744 if (It != LiveRegs.end() && It->second.any()) 745 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 746 << PrintLaneMask(It->second); 747 } 748 OS << '\n'; 749 }); 750 } 751 752 void GCNRegPressure::dump() const { dbgs() << print(*this); } 753 754 static cl::opt<bool> UseDownwardTracker( 755 "amdgpu-print-rp-downward", 756 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), 757 cl::init(false), cl::Hidden); 758 759 char llvm::GCNRegPressurePrinter::ID = 0; 760 char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; 761 762 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true) 763 764 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and 765 // are fully covered by Mask. 766 static LaneBitmask 767 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, 768 Register Reg, SlotIndex Begin, SlotIndex End, 769 LaneBitmask Mask = LaneBitmask::getAll()) { 770 771 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { 772 auto *Segment = LR.getSegmentContaining(Begin); 773 return Segment && Segment->contains(End); 774 }; 775 776 LaneBitmask LiveThroughMask; 777 const LiveInterval &LI = LIS.getInterval(Reg); 778 if (LI.hasSubRanges()) { 779 for (auto &SR : LI.subranges()) { 780 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) 781 LiveThroughMask |= SR.LaneMask; 782 } 783 } else { 784 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); 785 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) 786 LiveThroughMask = RegMask; 787 } 788 789 return LiveThroughMask; 790 } 791 792 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { 793 const MachineRegisterInfo &MRI = MF.getRegInfo(); 794 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 795 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 796 797 auto &OS = dbgs(); 798 799 // Leading spaces are important for YAML syntax. 800 #define PFX " " 801 802 OS << "---\nname: " << MF.getName() << "\nbody: |\n"; 803 804 auto printRP = [](const GCNRegPressure &RP) { 805 return Printable([&RP](raw_ostream &OS) { 806 OS << format(PFX " %-5d", RP.getSGPRNum()) 807 << format(" %-5d", RP.getVGPRNum(false)); 808 }); 809 }; 810 811 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, 812 const GCNRPTracker::LiveRegSet &LISLR) { 813 if (LISLR != TrackedLR) { 814 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI) 815 << reportMismatch(LISLR, TrackedLR, TRI, PFX " "); 816 } 817 }; 818 819 // Register pressure before and at an instruction (in program order). 820 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; 821 822 for (auto &MBB : MF) { 823 RP.clear(); 824 RP.reserve(MBB.size()); 825 826 OS << PFX; 827 MBB.printName(OS); 828 OS << ":\n"; 829 830 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB); 831 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB); 832 833 GCNRPTracker::LiveRegSet LiveIn, LiveOut; 834 GCNRegPressure RPAtMBBEnd; 835 836 if (UseDownwardTracker) { 837 if (MBB.empty()) { 838 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI); 839 RPAtMBBEnd = getRegPressure(MRI, LiveIn); 840 } else { 841 GCNDownwardRPTracker RPT(LIS); 842 RPT.reset(MBB.front()); 843 844 LiveIn = RPT.getLiveRegs(); 845 846 while (!RPT.advanceBeforeNext()) { 847 GCNRegPressure RPBeforeMI = RPT.getPressure(); 848 RPT.advanceToNext(); 849 RP.emplace_back(RPBeforeMI, RPT.getPressure()); 850 } 851 852 LiveOut = RPT.getLiveRegs(); 853 RPAtMBBEnd = RPT.getPressure(); 854 } 855 } else { 856 GCNUpwardRPTracker RPT(LIS); 857 RPT.reset(MRI, MBBEndSlot); 858 859 LiveOut = RPT.getLiveRegs(); 860 RPAtMBBEnd = RPT.getPressure(); 861 862 for (auto &MI : reverse(MBB)) { 863 RPT.resetMaxPressure(); 864 RPT.recede(MI); 865 if (!MI.isDebugInstr()) 866 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure()); 867 } 868 869 LiveIn = RPT.getLiveRegs(); 870 } 871 872 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI); 873 if (!UseDownwardTracker) 874 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI)); 875 876 OS << PFX " SGPR VGPR\n"; 877 int I = 0; 878 for (auto &MI : MBB) { 879 if (!MI.isDebugInstr()) { 880 auto &[RPBeforeInstr, RPAtInstr] = 881 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; 882 ++I; 883 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " "; 884 } else 885 OS << PFX " "; 886 MI.print(OS); 887 } 888 OS << printRP(RPAtMBBEnd) << '\n'; 889 890 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI); 891 if (UseDownwardTracker) 892 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI)); 893 894 GCNRPTracker::LiveRegSet LiveThrough; 895 for (auto [Reg, Mask] : LiveIn) { 896 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg); 897 if (MaskIntersection.any()) { 898 LaneBitmask LTMask = getRegLiveThroughMask( 899 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection); 900 if (LTMask.any()) 901 LiveThrough[Reg] = LTMask; 902 } 903 } 904 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI); 905 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n'; 906 } 907 OS << "...\n"; 908 return false; 909 910 #undef PFX 911 } 912