1 //===- GCNRegPressure.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file implements the GCNRegPressure class. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "GCNRegPressure.h" 15 #include "AMDGPU.h" 16 #include "llvm/CodeGen/RegisterPressure.h" 17 18 using namespace llvm; 19 20 #define DEBUG_TYPE "machine-scheduler" 21 22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, 23 const GCNRPTracker::LiveRegSet &S2) { 24 if (S1.size() != S2.size()) 25 return false; 26 27 for (const auto &P : S1) { 28 auto I = S2.find(P.first); 29 if (I == S2.end() || I->second != P.second) 30 return false; 31 } 32 return true; 33 } 34 35 /////////////////////////////////////////////////////////////////////////////// 36 // GCNRegPressure 37 38 unsigned GCNRegPressure::getRegKind(Register Reg, 39 const MachineRegisterInfo &MRI) { 40 assert(Reg.isVirtual()); 41 const auto *const RC = MRI.getRegClass(Reg); 42 const auto *STI = 43 static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo()); 44 return STI->isSGPRClass(RC) 45 ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) 46 : STI->isAGPRClass(RC) 47 ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) 48 : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE); 49 } 50 51 void GCNRegPressure::inc(unsigned Reg, 52 LaneBitmask PrevMask, 53 LaneBitmask NewMask, 54 const MachineRegisterInfo &MRI) { 55 if (SIRegisterInfo::getNumCoveredRegs(NewMask) == 56 SIRegisterInfo::getNumCoveredRegs(PrevMask)) 57 return; 58 59 int Sign = 1; 60 if (NewMask < PrevMask) { 61 std::swap(NewMask, PrevMask); 62 Sign = -1; 63 } 64 65 switch (auto Kind = getRegKind(Reg, MRI)) { 66 case SGPR32: 67 case VGPR32: 68 case AGPR32: 69 Value[Kind] += Sign; 70 break; 71 72 case SGPR_TUPLE: 73 case VGPR_TUPLE: 74 case AGPR_TUPLE: 75 assert(PrevMask < NewMask); 76 77 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += 78 Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); 79 80 if (PrevMask.none()) { 81 assert(NewMask.any()); 82 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 83 Value[Kind] += 84 Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight; 85 } 86 break; 87 88 default: llvm_unreachable("Unknown register kind"); 89 } 90 } 91 92 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, 93 unsigned MaxOccupancy) const { 94 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); 95 96 const auto SGPROcc = std::min(MaxOccupancy, 97 ST.getOccupancyWithNumSGPRs(getSGPRNum())); 98 const auto VGPROcc = 99 std::min(MaxOccupancy, 100 ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()))); 101 const auto OtherSGPROcc = std::min(MaxOccupancy, 102 ST.getOccupancyWithNumSGPRs(O.getSGPRNum())); 103 const auto OtherVGPROcc = 104 std::min(MaxOccupancy, 105 ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()))); 106 107 const auto Occ = std::min(SGPROcc, VGPROcc); 108 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc); 109 110 // Give first precedence to the better occupancy. 111 if (Occ != OtherOcc) 112 return Occ > OtherOcc; 113 114 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); 115 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); 116 117 // SGPR excess pressure conditions 118 unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0); 119 unsigned OtherExcessSGPR = 120 std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0); 121 122 auto WaveSize = ST.getWavefrontSize(); 123 // The number of virtual VGPRs required to handle excess SGPR 124 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; 125 unsigned OtherVGPRForSGPRSpills = 126 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; 127 128 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); 129 130 // Unified excess pressure conditions, accounting for VGPRs used for SGPR 131 // spills 132 unsigned ExcessVGPR = 133 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) + 134 VGPRForSGPRSpills - MaxVGPRs), 135 0); 136 unsigned OtherExcessVGPR = 137 std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) + 138 OtherVGPRForSGPRSpills - MaxVGPRs), 139 0); 140 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR 141 // spills 142 unsigned ExcessArchVGPR = std::max( 143 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs), 144 0); 145 unsigned OtherExcessArchVGPR = 146 std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills - 147 MaxArchVGPRs), 148 0); 149 // AGPR excess pressure conditions 150 unsigned ExcessAGPR = std::max( 151 static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) 152 : (getAGPRNum() - MaxVGPRs)), 153 0); 154 unsigned OtherExcessAGPR = std::max( 155 static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) 156 : (O.getAGPRNum() - MaxVGPRs)), 157 0); 158 159 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; 160 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || 161 OtherExcessArchVGPR || OtherExcessAGPR; 162 163 // Give second precedence to the reduced number of spills to hold the register 164 // pressure. 165 if (ExcessRP || OtherExcessRP) { 166 // The difference in excess VGPR pressure, after including VGPRs used for 167 // SGPR spills 168 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - 169 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); 170 171 int SGPRDiff = OtherExcessSGPR - ExcessSGPR; 172 173 if (VGPRDiff != 0) 174 return VGPRDiff > 0; 175 if (SGPRDiff != 0) { 176 unsigned PureExcessVGPR = 177 std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 178 0) + 179 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0); 180 unsigned OtherPureExcessVGPR = 181 std::max( 182 static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs), 183 0) + 184 std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0); 185 186 // If we have a special case where there is a tie in excess VGPR, but one 187 // of the pressures has VGPR usage from SGPR spills, prefer the pressure 188 // with SGPR spills. 189 if (PureExcessVGPR != OtherPureExcessVGPR) 190 return SGPRDiff < 0; 191 // If both pressures have the same excess pressure before and after 192 // accounting for SGPR spills, prefer fewer SGPR spills. 193 return SGPRDiff > 0; 194 } 195 } 196 197 bool SGPRImportant = SGPROcc < VGPROcc; 198 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; 199 200 // If both pressures disagree on what is more important compare vgprs. 201 if (SGPRImportant != OtherSGPRImportant) { 202 SGPRImportant = false; 203 } 204 205 // Give third precedence to lower register tuple pressure. 206 bool SGPRFirst = SGPRImportant; 207 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { 208 if (SGPRFirst) { 209 auto SW = getSGPRTuplesWeight(); 210 auto OtherSW = O.getSGPRTuplesWeight(); 211 if (SW != OtherSW) 212 return SW < OtherSW; 213 } else { 214 auto VW = getVGPRTuplesWeight(); 215 auto OtherVW = O.getVGPRTuplesWeight(); 216 if (VW != OtherVW) 217 return VW < OtherVW; 218 } 219 } 220 221 // Give final precedence to lower general RP. 222 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): 223 (getVGPRNum(ST.hasGFX90AInsts()) < 224 O.getVGPRNum(ST.hasGFX90AInsts())); 225 } 226 227 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) { 228 return Printable([&RP, ST](raw_ostream &OS) { 229 OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' ' 230 << "AGPRs: " << RP.getAGPRNum(); 231 if (ST) 232 OS << "(O" 233 << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts())) 234 << ')'; 235 OS << ", SGPRs: " << RP.getSGPRNum(); 236 if (ST) 237 OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')'; 238 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() 239 << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); 240 if (ST) 241 OS << " -> Occ: " << RP.getOccupancy(*ST); 242 OS << '\n'; 243 }); 244 } 245 246 static LaneBitmask getDefRegMask(const MachineOperand &MO, 247 const MachineRegisterInfo &MRI) { 248 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); 249 250 // We don't rely on read-undef flag because in case of tentative schedule 251 // tracking it isn't set correctly yet. This works correctly however since 252 // use mask has been tracked before using LIS. 253 return MO.getSubReg() == 0 ? 254 MRI.getMaxLaneMaskForVReg(MO.getReg()) : 255 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg()); 256 } 257 258 static void 259 collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs, 260 const MachineInstr &MI, const LiveIntervals &LIS, 261 const MachineRegisterInfo &MRI) { 262 263 auto &TRI = *MRI.getTargetRegisterInfo(); 264 for (const auto &MO : MI.operands()) { 265 if (!MO.isReg() || !MO.getReg().isVirtual()) 266 continue; 267 if (!MO.isUse() || !MO.readsReg()) 268 continue; 269 270 Register Reg = MO.getReg(); 271 auto I = llvm::find_if(RegMaskPairs, [Reg](const RegisterMaskPair &RM) { 272 return RM.RegUnit == Reg; 273 }); 274 275 auto &P = I == RegMaskPairs.end() 276 ? RegMaskPairs.emplace_back(Reg, LaneBitmask::getNone()) 277 : *I; 278 279 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg()) 280 : MRI.getMaxLaneMaskForVReg(Reg); 281 } 282 283 SlotIndex InstrSI; 284 for (auto &P : RegMaskPairs) { 285 auto &LI = LIS.getInterval(P.RegUnit); 286 if (!LI.hasSubRanges()) 287 continue; 288 289 // For a tentative schedule LIS isn't updated yet but livemask should 290 // remain the same on any schedule. Subreg defs can be reordered but they 291 // all must dominate uses anyway. 292 if (!InstrSI) 293 InstrSI = LIS.getInstructionIndex(MI).getBaseIndex(); 294 295 P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask); 296 } 297 } 298 299 /////////////////////////////////////////////////////////////////////////////// 300 // GCNRPTracker 301 302 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, 303 const LiveIntervals &LIS, 304 const MachineRegisterInfo &MRI, 305 LaneBitmask LaneMaskFilter) { 306 return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); 307 } 308 309 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, 310 const MachineRegisterInfo &MRI, 311 LaneBitmask LaneMaskFilter) { 312 LaneBitmask LiveMask; 313 if (LI.hasSubRanges()) { 314 for (const auto &S : LI.subranges()) 315 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) { 316 LiveMask |= S.LaneMask; 317 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); 318 } 319 } else if (LI.liveAt(SI)) { 320 LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg()); 321 } 322 LiveMask &= LaneMaskFilter; 323 return LiveMask; 324 } 325 326 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, 327 const LiveIntervals &LIS, 328 const MachineRegisterInfo &MRI) { 329 GCNRPTracker::LiveRegSet LiveRegs; 330 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 331 auto Reg = Register::index2VirtReg(I); 332 if (!LIS.hasInterval(Reg)) 333 continue; 334 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); 335 if (LiveMask.any()) 336 LiveRegs[Reg] = LiveMask; 337 } 338 return LiveRegs; 339 } 340 341 void GCNRPTracker::reset(const MachineInstr &MI, 342 const LiveRegSet *LiveRegsCopy, 343 bool After) { 344 const MachineFunction &MF = *MI.getMF(); 345 MRI = &MF.getRegInfo(); 346 if (LiveRegsCopy) { 347 if (&LiveRegs != LiveRegsCopy) 348 LiveRegs = *LiveRegsCopy; 349 } else { 350 LiveRegs = After ? getLiveRegsAfter(MI, LIS) 351 : getLiveRegsBefore(MI, LIS); 352 } 353 354 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs); 355 } 356 357 //////////////////////////////////////////////////////////////////////////////// 358 // GCNUpwardRPTracker 359 360 void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_, 361 const LiveRegSet &LiveRegs_) { 362 MRI = &MRI_; 363 LiveRegs = LiveRegs_; 364 LastTrackedMI = nullptr; 365 MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_); 366 } 367 368 void GCNUpwardRPTracker::recede(const MachineInstr &MI) { 369 assert(MRI && "call reset first"); 370 371 LastTrackedMI = &MI; 372 373 if (MI.isDebugInstr()) 374 return; 375 376 // Kill all defs. 377 GCNRegPressure DefPressure, ECDefPressure; 378 bool HasECDefs = false; 379 for (const MachineOperand &MO : MI.all_defs()) { 380 if (!MO.getReg().isVirtual()) 381 continue; 382 383 Register Reg = MO.getReg(); 384 LaneBitmask DefMask = getDefRegMask(MO, *MRI); 385 386 // Treat a def as fully live at the moment of definition: keep a record. 387 if (MO.isEarlyClobber()) { 388 ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 389 HasECDefs = true; 390 } else 391 DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI); 392 393 auto I = LiveRegs.find(Reg); 394 if (I == LiveRegs.end()) 395 continue; 396 397 LaneBitmask &LiveMask = I->second; 398 LaneBitmask PrevMask = LiveMask; 399 LiveMask &= ~DefMask; 400 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 401 if (LiveMask.none()) 402 LiveRegs.erase(I); 403 } 404 405 // Update MaxPressure with defs pressure. 406 DefPressure += CurPressure; 407 if (HasECDefs) 408 DefPressure += ECDefPressure; 409 MaxPressure = max(DefPressure, MaxPressure); 410 411 // Make uses alive. 412 SmallVector<RegisterMaskPair, 8> RegUses; 413 collectVirtualRegUses(RegUses, MI, LIS, *MRI); 414 for (const RegisterMaskPair &U : RegUses) { 415 LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; 416 LaneBitmask PrevMask = LiveMask; 417 LiveMask |= U.LaneMask; 418 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI); 419 } 420 421 // Update MaxPressure with uses plus early-clobber defs pressure. 422 MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure) 423 : max(CurPressure, MaxPressure); 424 425 assert(CurPressure == getRegPressure(*MRI, LiveRegs)); 426 } 427 428 //////////////////////////////////////////////////////////////////////////////// 429 // GCNDownwardRPTracker 430 431 bool GCNDownwardRPTracker::reset(const MachineInstr &MI, 432 const LiveRegSet *LiveRegsCopy) { 433 MRI = &MI.getParent()->getParent()->getRegInfo(); 434 LastTrackedMI = nullptr; 435 MBBEnd = MI.getParent()->end(); 436 NextMI = &MI; 437 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 438 if (NextMI == MBBEnd) 439 return false; 440 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false); 441 return true; 442 } 443 444 bool GCNDownwardRPTracker::advanceBeforeNext() { 445 assert(MRI && "call reset first"); 446 if (!LastTrackedMI) 447 return NextMI == MBBEnd; 448 449 assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); 450 451 SlotIndex SI = NextMI == MBBEnd 452 ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot() 453 : LIS.getInstructionIndex(*NextMI).getBaseIndex(); 454 assert(SI.isValid()); 455 456 // Remove dead registers or mask bits. 457 SmallSet<Register, 8> SeenRegs; 458 for (auto &MO : LastTrackedMI->operands()) { 459 if (!MO.isReg() || !MO.getReg().isVirtual()) 460 continue; 461 if (MO.isUse() && !MO.readsReg()) 462 continue; 463 if (!SeenRegs.insert(MO.getReg()).second) 464 continue; 465 const LiveInterval &LI = LIS.getInterval(MO.getReg()); 466 if (LI.hasSubRanges()) { 467 auto It = LiveRegs.end(); 468 for (const auto &S : LI.subranges()) { 469 if (!S.liveAt(SI)) { 470 if (It == LiveRegs.end()) { 471 It = LiveRegs.find(MO.getReg()); 472 if (It == LiveRegs.end()) 473 llvm_unreachable("register isn't live"); 474 } 475 auto PrevMask = It->second; 476 It->second &= ~S.LaneMask; 477 CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI); 478 } 479 } 480 if (It != LiveRegs.end() && It->second.none()) 481 LiveRegs.erase(It); 482 } else if (!LI.liveAt(SI)) { 483 auto It = LiveRegs.find(MO.getReg()); 484 if (It == LiveRegs.end()) 485 llvm_unreachable("register isn't live"); 486 CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI); 487 LiveRegs.erase(It); 488 } 489 } 490 491 MaxPressure = max(MaxPressure, CurPressure); 492 493 LastTrackedMI = nullptr; 494 495 return NextMI == MBBEnd; 496 } 497 498 void GCNDownwardRPTracker::advanceToNext() { 499 LastTrackedMI = &*NextMI++; 500 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd); 501 502 // Add new registers or mask bits. 503 for (const auto &MO : LastTrackedMI->all_defs()) { 504 Register Reg = MO.getReg(); 505 if (!Reg.isVirtual()) 506 continue; 507 auto &LiveMask = LiveRegs[Reg]; 508 auto PrevMask = LiveMask; 509 LiveMask |= getDefRegMask(MO, *MRI); 510 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI); 511 } 512 513 MaxPressure = max(MaxPressure, CurPressure); 514 } 515 516 bool GCNDownwardRPTracker::advance() { 517 if (NextMI == MBBEnd) 518 return false; 519 advanceBeforeNext(); 520 advanceToNext(); 521 return true; 522 } 523 524 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { 525 while (NextMI != End) 526 if (!advance()) return false; 527 return true; 528 } 529 530 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, 531 MachineBasicBlock::const_iterator End, 532 const LiveRegSet *LiveRegsCopy) { 533 reset(*Begin, LiveRegsCopy); 534 return advance(End); 535 } 536 537 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, 538 const GCNRPTracker::LiveRegSet &TrackedLR, 539 const TargetRegisterInfo *TRI, StringRef Pfx) { 540 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { 541 for (auto const &P : TrackedLR) { 542 auto I = LISLR.find(P.first); 543 if (I == LISLR.end()) { 544 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 545 << " isn't found in LIS reported set\n"; 546 } else if (I->second != P.second) { 547 OS << Pfx << printReg(P.first, TRI) 548 << " masks doesn't match: LIS reported " << PrintLaneMask(I->second) 549 << ", tracked " << PrintLaneMask(P.second) << '\n'; 550 } 551 } 552 for (auto const &P : LISLR) { 553 auto I = TrackedLR.find(P.first); 554 if (I == TrackedLR.end()) { 555 OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second) 556 << " isn't found in tracked set\n"; 557 } 558 } 559 }); 560 } 561 562 bool GCNUpwardRPTracker::isValid() const { 563 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex(); 564 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI); 565 const auto &TrackedLR = LiveRegs; 566 567 if (!isEqual(LISLR, TrackedLR)) { 568 dbgs() << "\nGCNUpwardRPTracker error: Tracked and" 569 " LIS reported livesets mismatch:\n" 570 << print(LISLR, *MRI); 571 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo()); 572 return false; 573 } 574 575 auto LISPressure = getRegPressure(*MRI, LISLR); 576 if (LISPressure != CurPressure) { 577 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " 578 << print(CurPressure) << "LIS rpt: " << print(LISPressure); 579 return false; 580 } 581 return true; 582 } 583 584 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, 585 const MachineRegisterInfo &MRI) { 586 return Printable([&LiveRegs, &MRI](raw_ostream &OS) { 587 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 588 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 589 Register Reg = Register::index2VirtReg(I); 590 auto It = LiveRegs.find(Reg); 591 if (It != LiveRegs.end() && It->second.any()) 592 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':' 593 << PrintLaneMask(It->second); 594 } 595 OS << '\n'; 596 }); 597 } 598 599 void GCNRegPressure::dump() const { dbgs() << print(*this); } 600 601 static cl::opt<bool> UseDownwardTracker( 602 "amdgpu-print-rp-downward", 603 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"), 604 cl::init(false), cl::Hidden); 605 606 char llvm::GCNRegPressurePrinter::ID = 0; 607 char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; 608 609 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true) 610 611 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and 612 // are fully covered by Mask. 613 static LaneBitmask 614 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, 615 Register Reg, SlotIndex Begin, SlotIndex End, 616 LaneBitmask Mask = LaneBitmask::getAll()) { 617 618 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { 619 auto *Segment = LR.getSegmentContaining(Begin); 620 return Segment && Segment->contains(End); 621 }; 622 623 LaneBitmask LiveThroughMask; 624 const LiveInterval &LI = LIS.getInterval(Reg); 625 if (LI.hasSubRanges()) { 626 for (auto &SR : LI.subranges()) { 627 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) 628 LiveThroughMask |= SR.LaneMask; 629 } 630 } else { 631 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); 632 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) 633 LiveThroughMask = RegMask; 634 } 635 636 return LiveThroughMask; 637 } 638 639 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { 640 const MachineRegisterInfo &MRI = MF.getRegInfo(); 641 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); 642 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 643 644 auto &OS = dbgs(); 645 646 // Leading spaces are important for YAML syntax. 647 #define PFX " " 648 649 OS << "---\nname: " << MF.getName() << "\nbody: |\n"; 650 651 auto printRP = [](const GCNRegPressure &RP) { 652 return Printable([&RP](raw_ostream &OS) { 653 OS << format(PFX " %-5d", RP.getSGPRNum()) 654 << format(" %-5d", RP.getVGPRNum(false)); 655 }); 656 }; 657 658 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, 659 const GCNRPTracker::LiveRegSet &LISLR) { 660 if (LISLR != TrackedLR) { 661 OS << PFX " mis LIS: " << llvm::print(LISLR, MRI) 662 << reportMismatch(LISLR, TrackedLR, TRI, PFX " "); 663 } 664 }; 665 666 // Register pressure before and at an instruction (in program order). 667 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; 668 669 for (auto &MBB : MF) { 670 RP.clear(); 671 RP.reserve(MBB.size()); 672 673 OS << PFX; 674 MBB.printName(OS); 675 OS << ":\n"; 676 677 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB); 678 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB); 679 680 GCNRPTracker::LiveRegSet LiveIn, LiveOut; 681 GCNRegPressure RPAtMBBEnd; 682 683 if (UseDownwardTracker) { 684 if (MBB.empty()) { 685 LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI); 686 RPAtMBBEnd = getRegPressure(MRI, LiveIn); 687 } else { 688 GCNDownwardRPTracker RPT(LIS); 689 RPT.reset(MBB.front()); 690 691 LiveIn = RPT.getLiveRegs(); 692 693 while (!RPT.advanceBeforeNext()) { 694 GCNRegPressure RPBeforeMI = RPT.getPressure(); 695 RPT.advanceToNext(); 696 RP.emplace_back(RPBeforeMI, RPT.getPressure()); 697 } 698 699 LiveOut = RPT.getLiveRegs(); 700 RPAtMBBEnd = RPT.getPressure(); 701 } 702 } else { 703 GCNUpwardRPTracker RPT(LIS); 704 RPT.reset(MRI, MBBEndSlot); 705 706 LiveOut = RPT.getLiveRegs(); 707 RPAtMBBEnd = RPT.getPressure(); 708 709 for (auto &MI : reverse(MBB)) { 710 RPT.resetMaxPressure(); 711 RPT.recede(MI); 712 if (!MI.isDebugInstr()) 713 RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure()); 714 } 715 716 LiveIn = RPT.getLiveRegs(); 717 } 718 719 OS << PFX " Live-in: " << llvm::print(LiveIn, MRI); 720 if (!UseDownwardTracker) 721 ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI)); 722 723 OS << PFX " SGPR VGPR\n"; 724 int I = 0; 725 for (auto &MI : MBB) { 726 if (!MI.isDebugInstr()) { 727 auto &[RPBeforeInstr, RPAtInstr] = 728 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; 729 ++I; 730 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " "; 731 } else 732 OS << PFX " "; 733 MI.print(OS); 734 } 735 OS << printRP(RPAtMBBEnd) << '\n'; 736 737 OS << PFX " Live-out:" << llvm::print(LiveOut, MRI); 738 if (UseDownwardTracker) 739 ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI)); 740 741 GCNRPTracker::LiveRegSet LiveThrough; 742 for (auto [Reg, Mask] : LiveIn) { 743 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg); 744 if (MaskIntersection.any()) { 745 LaneBitmask LTMask = getRegLiveThroughMask( 746 MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection); 747 if (LTMask.any()) 748 LiveThrough[Reg] = LTMask; 749 } 750 } 751 OS << PFX " Live-thr:" << llvm::print(LiveThrough, MRI); 752 OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n'; 753 } 754 OS << "...\n"; 755 return false; 756 757 #undef PFX 758 }