xref: /llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.cpp (revision 9e6494c0fb29dfb5d4d2b7bf3ed7af261efee034)
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file implements the GCNRegPressure class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "GCNRegPressure.h"
15 #include "AMDGPU.h"
16 #include "llvm/CodeGen/RegisterPressure.h"
17 
18 using namespace llvm;
19 
20 #define DEBUG_TYPE "machine-scheduler"
21 
22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
23                    const GCNRPTracker::LiveRegSet &S2) {
24   if (S1.size() != S2.size())
25     return false;
26 
27   for (const auto &P : S1) {
28     auto I = S2.find(P.first);
29     if (I == S2.end() || I->second != P.second)
30       return false;
31   }
32   return true;
33 }
34 
35 ///////////////////////////////////////////////////////////////////////////////
36 // GCNRegPressure
37 
38 unsigned GCNRegPressure::getRegKind(Register Reg,
39                                     const MachineRegisterInfo &MRI) {
40   assert(Reg.isVirtual());
41   const auto *const RC = MRI.getRegClass(Reg);
42   const auto *STI =
43       static_cast<const SIRegisterInfo *>(MRI.getTargetRegisterInfo());
44   return STI->isSGPRClass(RC)
45              ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
46          : STI->isAGPRClass(RC)
47              ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
48              : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
49 }
50 
51 void GCNRegPressure::inc(unsigned Reg,
52                          LaneBitmask PrevMask,
53                          LaneBitmask NewMask,
54                          const MachineRegisterInfo &MRI) {
55   if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
56       SIRegisterInfo::getNumCoveredRegs(PrevMask))
57     return;
58 
59   int Sign = 1;
60   if (NewMask < PrevMask) {
61     std::swap(NewMask, PrevMask);
62     Sign = -1;
63   }
64 
65   switch (auto Kind = getRegKind(Reg, MRI)) {
66   case SGPR32:
67   case VGPR32:
68   case AGPR32:
69     Value[Kind] += Sign;
70     break;
71 
72   case SGPR_TUPLE:
73   case VGPR_TUPLE:
74   case AGPR_TUPLE:
75     assert(PrevMask < NewMask);
76 
77     Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
78       Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
79 
80     if (PrevMask.none()) {
81       assert(NewMask.any());
82       const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
83       Value[Kind] +=
84           Sign * TRI->getRegClassWeight(MRI.getRegClass(Reg)).RegWeight;
85     }
86     break;
87 
88   default: llvm_unreachable("Unknown register kind");
89   }
90 }
91 
92 bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
93                           unsigned MaxOccupancy) const {
94   const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
95 
96   const auto SGPROcc = std::min(MaxOccupancy,
97                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
98   const auto VGPROcc =
99     std::min(MaxOccupancy,
100              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
101   const auto OtherSGPROcc = std::min(MaxOccupancy,
102                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
103   const auto OtherVGPROcc =
104     std::min(MaxOccupancy,
105              ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
106 
107   const auto Occ = std::min(SGPROcc, VGPROcc);
108   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
109 
110   // Give first precedence to the better occupancy.
111   if (Occ != OtherOcc)
112     return Occ > OtherOcc;
113 
114   unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
115   unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
116 
117   // SGPR excess pressure conditions
118   unsigned ExcessSGPR = std::max(static_cast<int>(getSGPRNum() - MaxSGPRs), 0);
119   unsigned OtherExcessSGPR =
120       std::max(static_cast<int>(O.getSGPRNum() - MaxSGPRs), 0);
121 
122   auto WaveSize = ST.getWavefrontSize();
123   // The number of virtual VGPRs required to handle excess SGPR
124   unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
125   unsigned OtherVGPRForSGPRSpills =
126       (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
127 
128   unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
129 
130   // Unified excess pressure conditions, accounting for VGPRs used for SGPR
131   // spills
132   unsigned ExcessVGPR =
133       std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) +
134                                 VGPRForSGPRSpills - MaxVGPRs),
135                0);
136   unsigned OtherExcessVGPR =
137       std::max(static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) +
138                                 OtherVGPRForSGPRSpills - MaxVGPRs),
139                0);
140   // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
141   // spills
142   unsigned ExcessArchVGPR = std::max(
143       static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills - MaxArchVGPRs),
144       0);
145   unsigned OtherExcessArchVGPR =
146       std::max(static_cast<int>(O.getVGPRNum(false) + OtherVGPRForSGPRSpills -
147                                 MaxArchVGPRs),
148                0);
149   // AGPR excess pressure conditions
150   unsigned ExcessAGPR = std::max(
151       static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
152                                            : (getAGPRNum() - MaxVGPRs)),
153       0);
154   unsigned OtherExcessAGPR = std::max(
155       static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
156                                            : (O.getAGPRNum() - MaxVGPRs)),
157       0);
158 
159   bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
160   bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
161                        OtherExcessArchVGPR || OtherExcessAGPR;
162 
163   // Give second precedence to the reduced number of spills to hold the register
164   // pressure.
165   if (ExcessRP || OtherExcessRP) {
166     // The difference in excess VGPR pressure, after including VGPRs used for
167     // SGPR spills
168     int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
169                     (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
170 
171     int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
172 
173     if (VGPRDiff != 0)
174       return VGPRDiff > 0;
175     if (SGPRDiff != 0) {
176       unsigned PureExcessVGPR =
177           std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
178                    0) +
179           std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
180       unsigned OtherPureExcessVGPR =
181           std::max(
182               static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
183               0) +
184           std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);
185 
186       // If we have a special case where there is a tie in excess VGPR, but one
187       // of the pressures has VGPR usage from SGPR spills, prefer the pressure
188       // with SGPR spills.
189       if (PureExcessVGPR != OtherPureExcessVGPR)
190         return SGPRDiff < 0;
191       // If both pressures have the same excess pressure before and after
192       // accounting for SGPR spills, prefer fewer SGPR spills.
193       return SGPRDiff > 0;
194     }
195   }
196 
197   bool SGPRImportant = SGPROcc < VGPROcc;
198   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
199 
200   // If both pressures disagree on what is more important compare vgprs.
201   if (SGPRImportant != OtherSGPRImportant) {
202     SGPRImportant = false;
203   }
204 
205   // Give third precedence to lower register tuple pressure.
206   bool SGPRFirst = SGPRImportant;
207   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
208     if (SGPRFirst) {
209       auto SW = getSGPRTuplesWeight();
210       auto OtherSW = O.getSGPRTuplesWeight();
211       if (SW != OtherSW)
212         return SW < OtherSW;
213     } else {
214       auto VW = getVGPRTuplesWeight();
215       auto OtherVW = O.getVGPRTuplesWeight();
216       if (VW != OtherVW)
217         return VW < OtherVW;
218     }
219   }
220 
221   // Give final precedence to lower general RP.
222   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
223                          (getVGPRNum(ST.hasGFX90AInsts()) <
224                           O.getVGPRNum(ST.hasGFX90AInsts()));
225 }
226 
227 Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
228   return Printable([&RP, ST](raw_ostream &OS) {
229     OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
230        << "AGPRs: " << RP.getAGPRNum();
231     if (ST)
232       OS << "(O"
233          << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
234          << ')';
235     OS << ", SGPRs: " << RP.getSGPRNum();
236     if (ST)
237       OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
238     OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
239        << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
240     if (ST)
241       OS << " -> Occ: " << RP.getOccupancy(*ST);
242     OS << '\n';
243   });
244 }
245 
246 static LaneBitmask getDefRegMask(const MachineOperand &MO,
247                                  const MachineRegisterInfo &MRI) {
248   assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
249 
250   // We don't rely on read-undef flag because in case of tentative schedule
251   // tracking it isn't set correctly yet. This works correctly however since
252   // use mask has been tracked before using LIS.
253   return MO.getSubReg() == 0 ?
254     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
255     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
256 }
257 
258 static void
259 collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
260                       const MachineInstr &MI, const LiveIntervals &LIS,
261                       const MachineRegisterInfo &MRI) {
262 
263   auto &TRI = *MRI.getTargetRegisterInfo();
264   for (const auto &MO : MI.operands()) {
265     if (!MO.isReg() || !MO.getReg().isVirtual())
266       continue;
267     if (!MO.isUse() || !MO.readsReg())
268       continue;
269 
270     Register Reg = MO.getReg();
271     auto I = llvm::find_if(VRegMaskOrUnits, [Reg](const VRegMaskOrUnit &RM) {
272       return RM.RegUnit == Reg;
273     });
274 
275     auto &P = I == VRegMaskOrUnits.end()
276                   ? VRegMaskOrUnits.emplace_back(Reg, LaneBitmask::getNone())
277                   : *I;
278 
279     P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(MO.getSubReg())
280                                  : MRI.getMaxLaneMaskForVReg(Reg);
281   }
282 
283   SlotIndex InstrSI;
284   for (auto &P : VRegMaskOrUnits) {
285     auto &LI = LIS.getInterval(P.RegUnit);
286     if (!LI.hasSubRanges())
287       continue;
288 
289     // For a tentative schedule LIS isn't updated yet but livemask should
290     // remain the same on any schedule. Subreg defs can be reordered but they
291     // all must dominate uses anyway.
292     if (!InstrSI)
293       InstrSI = LIS.getInstructionIndex(MI).getBaseIndex();
294 
295     P.LaneMask = getLiveLaneMask(LI, InstrSI, MRI, P.LaneMask);
296   }
297 }
298 
299 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
300 static LaneBitmask getLanesWithProperty(
301     const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
302     bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
303     LaneBitmask SafeDefault,
304     function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
305   if (RegUnit.isVirtual()) {
306     const LiveInterval &LI = LIS.getInterval(RegUnit);
307     LaneBitmask Result;
308     if (TrackLaneMasks && LI.hasSubRanges()) {
309       for (const LiveInterval::SubRange &SR : LI.subranges()) {
310         if (Property(SR, Pos))
311           Result |= SR.LaneMask;
312       }
313     } else if (Property(LI, Pos)) {
314       Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
315                               : LaneBitmask::getAll();
316     }
317 
318     return Result;
319   }
320 
321   const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
322   if (LR == nullptr)
323     return SafeDefault;
324   return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
325 }
326 
327 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
328 /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
329 /// The query starts with a lane bitmask which gets lanes/bits removed for every
330 /// use we find.
331 static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
332                                   SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
333                                   const MachineRegisterInfo &MRI,
334                                   const SIRegisterInfo *TRI,
335                                   const LiveIntervals *LIS,
336                                   bool Upward = false) {
337   for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
338     if (MO.isUndef())
339       continue;
340     const MachineInstr *MI = MO.getParent();
341     SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
342     bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
343                           : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
344     if (!InRange)
345       continue;
346 
347     unsigned SubRegIdx = MO.getSubReg();
348     LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
349     LastUseMask &= ~UseMask;
350     if (LastUseMask.none())
351       return LaneBitmask::getNone();
352   }
353   return LastUseMask;
354 }
355 
356 ///////////////////////////////////////////////////////////////////////////////
357 // GCNRPTracker
358 
359 LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
360                                   const LiveIntervals &LIS,
361                                   const MachineRegisterInfo &MRI,
362                                   LaneBitmask LaneMaskFilter) {
363   return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI, LaneMaskFilter);
364 }
365 
366 LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
367                                   const MachineRegisterInfo &MRI,
368                                   LaneBitmask LaneMaskFilter) {
369   LaneBitmask LiveMask;
370   if (LI.hasSubRanges()) {
371     for (const auto &S : LI.subranges())
372       if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(SI)) {
373         LiveMask |= S.LaneMask;
374         assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
375       }
376   } else if (LI.liveAt(SI)) {
377     LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
378   }
379   LiveMask &= LaneMaskFilter;
380   return LiveMask;
381 }
382 
383 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
384                                            const LiveIntervals &LIS,
385                                            const MachineRegisterInfo &MRI) {
386   GCNRPTracker::LiveRegSet LiveRegs;
387   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
388     auto Reg = Register::index2VirtReg(I);
389     if (!LIS.hasInterval(Reg))
390       continue;
391     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
392     if (LiveMask.any())
393       LiveRegs[Reg] = LiveMask;
394   }
395   return LiveRegs;
396 }
397 
398 void GCNRPTracker::reset(const MachineInstr &MI,
399                          const LiveRegSet *LiveRegsCopy,
400                          bool After) {
401   const MachineFunction &MF = *MI.getMF();
402   MRI = &MF.getRegInfo();
403   if (LiveRegsCopy) {
404     if (&LiveRegs != LiveRegsCopy)
405       LiveRegs = *LiveRegsCopy;
406   } else {
407     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
408                      : getLiveRegsBefore(MI, LIS);
409   }
410 
411   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
412 }
413 
414 void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
415                          const LiveRegSet &LiveRegs_) {
416   MRI = &MRI_;
417   LiveRegs = LiveRegs_;
418   LastTrackedMI = nullptr;
419   MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
420 }
421 
422 /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
423 LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
424                                            SlotIndex Pos) const {
425   return getLanesWithProperty(
426       LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
427       [](const LiveRange &LR, SlotIndex Pos) {
428         const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
429         return S != nullptr && S->end == Pos.getRegSlot();
430       });
431 }
432 
433 ////////////////////////////////////////////////////////////////////////////////
434 // GCNUpwardRPTracker
435 
436 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
437   assert(MRI && "call reset first");
438 
439   LastTrackedMI = &MI;
440 
441   if (MI.isDebugInstr())
442     return;
443 
444   // Kill all defs.
445   GCNRegPressure DefPressure, ECDefPressure;
446   bool HasECDefs = false;
447   for (const MachineOperand &MO : MI.all_defs()) {
448     if (!MO.getReg().isVirtual())
449       continue;
450 
451     Register Reg = MO.getReg();
452     LaneBitmask DefMask = getDefRegMask(MO, *MRI);
453 
454     // Treat a def as fully live at the moment of definition: keep a record.
455     if (MO.isEarlyClobber()) {
456       ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
457       HasECDefs = true;
458     } else
459       DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
460 
461     auto I = LiveRegs.find(Reg);
462     if (I == LiveRegs.end())
463       continue;
464 
465     LaneBitmask &LiveMask = I->second;
466     LaneBitmask PrevMask = LiveMask;
467     LiveMask &= ~DefMask;
468     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
469     if (LiveMask.none())
470       LiveRegs.erase(I);
471   }
472 
473   // Update MaxPressure with defs pressure.
474   DefPressure += CurPressure;
475   if (HasECDefs)
476     DefPressure += ECDefPressure;
477   MaxPressure = max(DefPressure, MaxPressure);
478 
479   // Make uses alive.
480   SmallVector<VRegMaskOrUnit, 8> RegUses;
481   collectVirtualRegUses(RegUses, MI, LIS, *MRI);
482   for (const VRegMaskOrUnit &U : RegUses) {
483     LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
484     LaneBitmask PrevMask = LiveMask;
485     LiveMask |= U.LaneMask;
486     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
487   }
488 
489   // Update MaxPressure with uses plus early-clobber defs pressure.
490   MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
491                           : max(CurPressure, MaxPressure);
492 
493   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
494 }
495 
496 ////////////////////////////////////////////////////////////////////////////////
497 // GCNDownwardRPTracker
498 
499 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
500                                  const LiveRegSet *LiveRegsCopy) {
501   MRI = &MI.getParent()->getParent()->getRegInfo();
502   LastTrackedMI = nullptr;
503   MBBEnd = MI.getParent()->end();
504   NextMI = &MI;
505   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
506   if (NextMI == MBBEnd)
507     return false;
508   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
509   return true;
510 }
511 
512 bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
513                                              bool UseInternalIterator) {
514   assert(MRI && "call reset first");
515   SlotIndex SI;
516   const MachineInstr *CurrMI;
517   if (UseInternalIterator) {
518     if (!LastTrackedMI)
519       return NextMI == MBBEnd;
520 
521     assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
522     CurrMI = LastTrackedMI;
523 
524     SI = NextMI == MBBEnd
525              ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
526              : LIS.getInstructionIndex(*NextMI).getBaseIndex();
527   } else { //! UseInternalIterator
528     SI = LIS.getInstructionIndex(*MI).getBaseIndex();
529     CurrMI = MI;
530   }
531 
532   assert(SI.isValid());
533 
534   // Remove dead registers or mask bits.
535   SmallSet<Register, 8> SeenRegs;
536   for (auto &MO : CurrMI->operands()) {
537     if (!MO.isReg() || !MO.getReg().isVirtual())
538       continue;
539     if (MO.isUse() && !MO.readsReg())
540       continue;
541     if (!UseInternalIterator && MO.isDef())
542       continue;
543     if (!SeenRegs.insert(MO.getReg()).second)
544       continue;
545     const LiveInterval &LI = LIS.getInterval(MO.getReg());
546     if (LI.hasSubRanges()) {
547       auto It = LiveRegs.end();
548       for (const auto &S : LI.subranges()) {
549         if (!S.liveAt(SI)) {
550           if (It == LiveRegs.end()) {
551             It = LiveRegs.find(MO.getReg());
552             if (It == LiveRegs.end())
553               llvm_unreachable("register isn't live");
554           }
555           auto PrevMask = It->second;
556           It->second &= ~S.LaneMask;
557           CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
558         }
559       }
560       if (It != LiveRegs.end() && It->second.none())
561         LiveRegs.erase(It);
562     } else if (!LI.liveAt(SI)) {
563       auto It = LiveRegs.find(MO.getReg());
564       if (It == LiveRegs.end())
565         llvm_unreachable("register isn't live");
566       CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
567       LiveRegs.erase(It);
568     }
569   }
570 
571   MaxPressure = max(MaxPressure, CurPressure);
572 
573   LastTrackedMI = nullptr;
574 
575   return UseInternalIterator && (NextMI == MBBEnd);
576 }
577 
578 void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
579                                          bool UseInternalIterator) {
580   if (UseInternalIterator) {
581     LastTrackedMI = &*NextMI++;
582     NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
583   } else {
584     LastTrackedMI = MI;
585   }
586 
587   const MachineInstr *CurrMI = LastTrackedMI;
588 
589   // Add new registers or mask bits.
590   for (const auto &MO : CurrMI->all_defs()) {
591     Register Reg = MO.getReg();
592     if (!Reg.isVirtual())
593       continue;
594     auto &LiveMask = LiveRegs[Reg];
595     auto PrevMask = LiveMask;
596     LiveMask |= getDefRegMask(MO, *MRI);
597     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
598   }
599 
600   MaxPressure = max(MaxPressure, CurPressure);
601 }
602 
603 bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
604   if (UseInternalIterator && NextMI == MBBEnd)
605     return false;
606 
607   advanceBeforeNext(MI, UseInternalIterator);
608   advanceToNext(MI, UseInternalIterator);
609   if (!UseInternalIterator) {
610     // We must remove any dead def lanes from the current RP
611     advanceBeforeNext(MI, true);
612   }
613   return true;
614 }
615 
616 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
617   while (NextMI != End)
618     if (!advance()) return false;
619   return true;
620 }
621 
622 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
623                                    MachineBasicBlock::const_iterator End,
624                                    const LiveRegSet *LiveRegsCopy) {
625   reset(*Begin, LiveRegsCopy);
626   return advance(End);
627 }
628 
629 Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
630                                const GCNRPTracker::LiveRegSet &TrackedLR,
631                                const TargetRegisterInfo *TRI, StringRef Pfx) {
632   return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
633     for (auto const &P : TrackedLR) {
634       auto I = LISLR.find(P.first);
635       if (I == LISLR.end()) {
636         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
637            << " isn't found in LIS reported set\n";
638       } else if (I->second != P.second) {
639         OS << Pfx << printReg(P.first, TRI)
640            << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
641            << ", tracked " << PrintLaneMask(P.second) << '\n';
642       }
643     }
644     for (auto const &P : LISLR) {
645       auto I = TrackedLR.find(P.first);
646       if (I == TrackedLR.end()) {
647         OS << Pfx << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
648            << " isn't found in tracked set\n";
649       }
650     }
651   });
652 }
653 
654 GCNRegPressure
655 GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
656                                            const SIRegisterInfo *TRI) const {
657   assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
658 
659   SlotIndex SlotIdx;
660   SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();
661 
662   // Account for register pressure similar to RegPressureTracker::recede().
663   RegisterOperands RegOpers;
664   RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
665   RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
666   GCNRegPressure TempPressure = CurPressure;
667 
668   for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
669     Register Reg = Use.RegUnit;
670     if (!Reg.isVirtual())
671       continue;
672     LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
673     if (LastUseMask.none())
674       continue;
675     // The LastUseMask is queried from the liveness information of instruction
676     // which may be further down the schedule. Some lanes may actually not be
677     // last uses for the current position.
678     // FIXME: allow the caller to pass in the list of vreg uses that remain
679     // to be bottom-scheduled to avoid searching uses at each query.
680     SlotIndex CurrIdx;
681     const MachineBasicBlock *MBB = MI->getParent();
682     MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
683         LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
684     if (IdxPos == MBB->end()) {
685       CurrIdx = LIS.getMBBEndIdx(MBB);
686     } else {
687       CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
688     }
689 
690     LastUseMask =
691         findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
692     if (LastUseMask.none())
693       continue;
694 
695     LaneBitmask LiveMask =
696         LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
697     LaneBitmask NewMask = LiveMask & ~LastUseMask;
698     TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
699   }
700 
701   // Generate liveness for defs.
702   for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
703     Register Reg = Def.RegUnit;
704     if (!Reg.isVirtual())
705       continue;
706     LaneBitmask LiveMask =
707         LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
708     LaneBitmask NewMask = LiveMask | Def.LaneMask;
709     TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
710   }
711 
712   return TempPressure;
713 }
714 
715 bool GCNUpwardRPTracker::isValid() const {
716   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
717   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
718   const auto &TrackedLR = LiveRegs;
719 
720   if (!isEqual(LISLR, TrackedLR)) {
721     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
722               " LIS reported livesets mismatch:\n"
723            << print(LISLR, *MRI);
724     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
725     return false;
726   }
727 
728   auto LISPressure = getRegPressure(*MRI, LISLR);
729   if (LISPressure != CurPressure) {
730     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
731            << print(CurPressure) << "LIS rpt: " << print(LISPressure);
732     return false;
733   }
734   return true;
735 }
736 
737 Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
738                       const MachineRegisterInfo &MRI) {
739   return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
740     const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
741     for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
742       Register Reg = Register::index2VirtReg(I);
743       auto It = LiveRegs.find(Reg);
744       if (It != LiveRegs.end() && It->second.any())
745         OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
746            << PrintLaneMask(It->second);
747     }
748     OS << '\n';
749   });
750 }
751 
752 void GCNRegPressure::dump() const { dbgs() << print(*this); }
753 
754 static cl::opt<bool> UseDownwardTracker(
755     "amdgpu-print-rp-downward",
756     cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
757     cl::init(false), cl::Hidden);
758 
759 char llvm::GCNRegPressurePrinter::ID = 0;
760 char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;
761 
762 INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
763 
764 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and
765 // are fully covered by Mask.
766 static LaneBitmask
767 getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS,
768                       Register Reg, SlotIndex Begin, SlotIndex End,
769                       LaneBitmask Mask = LaneBitmask::getAll()) {
770 
771   auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
772     auto *Segment = LR.getSegmentContaining(Begin);
773     return Segment && Segment->contains(End);
774   };
775 
776   LaneBitmask LiveThroughMask;
777   const LiveInterval &LI = LIS.getInterval(Reg);
778   if (LI.hasSubRanges()) {
779     for (auto &SR : LI.subranges()) {
780       if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
781         LiveThroughMask |= SR.LaneMask;
782     }
783   } else {
784     LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
785     if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
786       LiveThroughMask = RegMask;
787   }
788 
789   return LiveThroughMask;
790 }
791 
792 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
793   const MachineRegisterInfo &MRI = MF.getRegInfo();
794   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
795   const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
796 
797   auto &OS = dbgs();
798 
799 // Leading spaces are important for YAML syntax.
800 #define PFX "  "
801 
802   OS << "---\nname: " << MF.getName() << "\nbody:             |\n";
803 
804   auto printRP = [](const GCNRegPressure &RP) {
805     return Printable([&RP](raw_ostream &OS) {
806       OS << format(PFX "  %-5d", RP.getSGPRNum())
807          << format(" %-5d", RP.getVGPRNum(false));
808     });
809   };
810 
811   auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
812                                     const GCNRPTracker::LiveRegSet &LISLR) {
813     if (LISLR != TrackedLR) {
814       OS << PFX "  mis LIS: " << llvm::print(LISLR, MRI)
815          << reportMismatch(LISLR, TrackedLR, TRI, PFX "    ");
816     }
817   };
818 
819   // Register pressure before and at an instruction (in program order).
820   SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;
821 
822   for (auto &MBB : MF) {
823     RP.clear();
824     RP.reserve(MBB.size());
825 
826     OS << PFX;
827     MBB.printName(OS);
828     OS << ":\n";
829 
830     SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(&MBB);
831     SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(&MBB);
832 
833     GCNRPTracker::LiveRegSet LiveIn, LiveOut;
834     GCNRegPressure RPAtMBBEnd;
835 
836     if (UseDownwardTracker) {
837       if (MBB.empty()) {
838         LiveIn = LiveOut = getLiveRegs(MBBStartSlot, LIS, MRI);
839         RPAtMBBEnd = getRegPressure(MRI, LiveIn);
840       } else {
841         GCNDownwardRPTracker RPT(LIS);
842         RPT.reset(MBB.front());
843 
844         LiveIn = RPT.getLiveRegs();
845 
846         while (!RPT.advanceBeforeNext()) {
847           GCNRegPressure RPBeforeMI = RPT.getPressure();
848           RPT.advanceToNext();
849           RP.emplace_back(RPBeforeMI, RPT.getPressure());
850         }
851 
852         LiveOut = RPT.getLiveRegs();
853         RPAtMBBEnd = RPT.getPressure();
854       }
855     } else {
856       GCNUpwardRPTracker RPT(LIS);
857       RPT.reset(MRI, MBBEndSlot);
858 
859       LiveOut = RPT.getLiveRegs();
860       RPAtMBBEnd = RPT.getPressure();
861 
862       for (auto &MI : reverse(MBB)) {
863         RPT.resetMaxPressure();
864         RPT.recede(MI);
865         if (!MI.isDebugInstr())
866           RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
867       }
868 
869       LiveIn = RPT.getLiveRegs();
870     }
871 
872     OS << PFX "  Live-in: " << llvm::print(LiveIn, MRI);
873     if (!UseDownwardTracker)
874       ReportLISMismatchIfAny(LiveIn, getLiveRegs(MBBStartSlot, LIS, MRI));
875 
876     OS << PFX "  SGPR  VGPR\n";
877     int I = 0;
878     for (auto &MI : MBB) {
879       if (!MI.isDebugInstr()) {
880         auto &[RPBeforeInstr, RPAtInstr] =
881             RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
882         ++I;
883         OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << "  ";
884       } else
885         OS << PFX "               ";
886       MI.print(OS);
887     }
888     OS << printRP(RPAtMBBEnd) << '\n';
889 
890     OS << PFX "  Live-out:" << llvm::print(LiveOut, MRI);
891     if (UseDownwardTracker)
892       ReportLISMismatchIfAny(LiveOut, getLiveRegs(MBBEndSlot, LIS, MRI));
893 
894     GCNRPTracker::LiveRegSet LiveThrough;
895     for (auto [Reg, Mask] : LiveIn) {
896       LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Reg);
897       if (MaskIntersection.any()) {
898         LaneBitmask LTMask = getRegLiveThroughMask(
899             MRI, LIS, Reg, MBBStartSlot, MBBEndSlot, MaskIntersection);
900         if (LTMask.any())
901           LiveThrough[Reg] = LTMask;
902       }
903     }
904     OS << PFX "  Live-thr:" << llvm::print(LiveThrough, MRI);
905     OS << printRP(getRegPressure(MRI, LiveThrough)) << '\n';
906   }
907   OS << "...\n";
908   return false;
909 
910 #undef PFX
911 }
912