xref: /llvm-project/llvm/lib/Target/AMDGPU/GCNRegPressure.h (revision 9e6494c0fb29dfb5d4d2b7bf3ed7af261efee034)
1 //===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file defines the GCNRegPressure class, which tracks register pressure
11 /// by bookkeeping number of SGPR/VGPRs used, weights for large SGPR/VGPRs. It
12 /// also implements a compare function, which compares different register
13 /// pressures, and declares one with max occupancy as winner.
14 ///
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18 #define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19 
20 #include "GCNSubtarget.h"
21 #include "llvm/CodeGen/LiveIntervals.h"
22 #include "llvm/CodeGen/RegisterPressure.h"
23 #include <algorithm>
24 
25 namespace llvm {
26 
27 class MachineRegisterInfo;
28 class raw_ostream;
29 class SlotIndex;
30 
31 struct GCNRegPressure {
32   enum RegKind {
33     SGPR32,
34     SGPR_TUPLE,
35     VGPR32,
36     VGPR_TUPLE,
37     AGPR32,
38     AGPR_TUPLE,
39     TOTAL_KINDS
40   };
41 
42   GCNRegPressure() {
43     clear();
44   }
45 
46   bool empty() const { return getSGPRNum() == 0 && getVGPRNum(false) == 0; }
47 
48   void clear() { std::fill(&Value[0], &Value[TOTAL_KINDS], 0); }
49 
50   /// \returns the SGPR32 pressure
51   unsigned getSGPRNum() const { return Value[SGPR32]; }
52   /// \returns the aggregated ArchVGPR32, AccVGPR32 pressure dependent upon \p
53   /// UnifiedVGPRFile
54   unsigned getVGPRNum(bool UnifiedVGPRFile) const {
55     if (UnifiedVGPRFile) {
56       return Value[AGPR32] ? alignTo(Value[VGPR32], 4) + Value[AGPR32]
57                            : Value[VGPR32] + Value[AGPR32];
58     }
59     return std::max(Value[VGPR32], Value[AGPR32]);
60   }
61   /// \returns the ArchVGPR32 pressure
62   unsigned getArchVGPRNum() const { return Value[VGPR32]; }
63   /// \returns the AccVGPR32 pressure
64   unsigned getAGPRNum() const { return Value[AGPR32]; }
65 
66   unsigned getVGPRTuplesWeight() const { return std::max(Value[VGPR_TUPLE],
67                                                          Value[AGPR_TUPLE]); }
68   unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
69 
70   unsigned getOccupancy(const GCNSubtarget &ST) const {
71     return std::min(ST.getOccupancyWithNumSGPRs(getSGPRNum()),
72              ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
73   }
74 
75   void inc(unsigned Reg,
76            LaneBitmask PrevMask,
77            LaneBitmask NewMask,
78            const MachineRegisterInfo &MRI);
79 
80   bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
81     return getOccupancy(ST) > O.getOccupancy(ST);
82   }
83 
84   /// Compares \p this GCNRegpressure to \p O, returning true if \p this is
85   /// less. Since GCNRegpressure contains different types of pressures, and due
86   /// to target-specific pecularities (e.g. we care about occupancy rather than
87   /// raw register usage), we determine if \p this GCNRegPressure is less than
88   /// \p O based on the following tiered comparisons (in order order of
89   /// precedence):
90   /// 1. Better occupancy
91   /// 2. Less spilling (first preference to VGPR spills, then to SGPR spills)
92   /// 3. Less tuple register pressure (first preference to VGPR tuples if we
93   /// determine that SGPR pressure is not important)
94   /// 4. Less raw register pressure (first preference to VGPR tuples if we
95   /// determine that SGPR pressure is not important)
96   bool less(const MachineFunction &MF, const GCNRegPressure &O,
97             unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
98 
99   bool operator==(const GCNRegPressure &O) const {
100     return std::equal(&Value[0], &Value[TOTAL_KINDS], O.Value);
101   }
102 
103   bool operator!=(const GCNRegPressure &O) const {
104     return !(*this == O);
105   }
106 
107   GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
108     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
109       Value[I] += RHS.Value[I];
110     return *this;
111   }
112 
113   GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
114     for (unsigned I = 0; I < TOTAL_KINDS; ++I)
115       Value[I] -= RHS.Value[I];
116     return *this;
117   }
118 
119   void dump() const;
120 
121 private:
122   unsigned Value[TOTAL_KINDS];
123 
124   static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
125 
126   friend GCNRegPressure max(const GCNRegPressure &P1,
127                             const GCNRegPressure &P2);
128 
129   friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
130 };
131 
132 inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
133   GCNRegPressure Res;
134   for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I)
135     Res.Value[I] = std::max(P1.Value[I], P2.Value[I]);
136   return Res;
137 }
138 
139 inline GCNRegPressure operator+(const GCNRegPressure &P1,
140                                 const GCNRegPressure &P2) {
141   GCNRegPressure Sum = P1;
142   Sum += P2;
143   return Sum;
144 }
145 
146 inline GCNRegPressure operator-(const GCNRegPressure &P1,
147                                 const GCNRegPressure &P2) {
148   GCNRegPressure Diff = P1;
149   Diff -= P2;
150   return Diff;
151 }
152 
153 ///////////////////////////////////////////////////////////////////////////////
154 // GCNRPTracker
155 
156 class GCNRPTracker {
157 public:
158   using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
159 
160 protected:
161   const LiveIntervals &LIS;
162   LiveRegSet LiveRegs;
163   GCNRegPressure CurPressure, MaxPressure;
164   const MachineInstr *LastTrackedMI = nullptr;
165   mutable const MachineRegisterInfo *MRI = nullptr;
166 
167   GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
168 
169   void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
170              bool After);
171 
172   /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
173   void bumpDeadDefs(ArrayRef<VRegMaskOrUnit> DeadDefs);
174 
175   LaneBitmask getLastUsedLanes(Register RegUnit, SlotIndex Pos) const;
176 
177 public:
178   // reset tracker and set live register set to the specified value.
179   void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
180   // live regs for the current state
181   const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
182   const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
183 
184   void clearMaxPressure() { MaxPressure.clear(); }
185 
186   GCNRegPressure getPressure() const { return CurPressure; }
187 
188   decltype(LiveRegs) moveLiveRegs() {
189     return std::move(LiveRegs);
190   }
191 };
192 
193 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
194                                      const MachineRegisterInfo &MRI);
195 
196 ////////////////////////////////////////////////////////////////////////////////
197 // GCNUpwardRPTracker
198 
199 class GCNUpwardRPTracker : public GCNRPTracker {
200 public:
201   GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
202 
203   using GCNRPTracker::reset;
204 
205   /// reset tracker at the specified slot index \p SI.
206   void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
207     GCNRPTracker::reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
208   }
209 
210   /// reset tracker to the end of the \p MBB.
211   void reset(const MachineBasicBlock &MBB) {
212     reset(MBB.getParent()->getRegInfo(),
213           LIS.getSlotIndexes()->getMBBEndIdx(&MBB));
214   }
215 
216   /// reset tracker to the point just after \p MI (in program order).
217   void reset(const MachineInstr &MI) {
218     reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot());
219   }
220 
221   /// Move to the state of RP just before the \p MI . If \p UseInternalIterator
222   /// is set, also update the internal iterators. Setting \p UseInternalIterator
223   /// to false allows for an externally managed iterator / program order.
224   void recede(const MachineInstr &MI);
225 
226   /// \p returns whether the tracker's state after receding MI corresponds
227   /// to reported by LIS.
228   bool isValid() const;
229 
230   const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
231 
232   void resetMaxPressure() { MaxPressure = CurPressure; }
233 
234   GCNRegPressure getMaxPressureAndReset() {
235     GCNRegPressure RP = MaxPressure;
236     resetMaxPressure();
237     return RP;
238   }
239 };
240 
241 ////////////////////////////////////////////////////////////////////////////////
242 // GCNDownwardRPTracker
243 
244 class GCNDownwardRPTracker : public GCNRPTracker {
245   // Last position of reset or advanceBeforeNext
246   MachineBasicBlock::const_iterator NextMI;
247 
248   MachineBasicBlock::const_iterator MBBEnd;
249 
250 public:
251   GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
252 
253   using GCNRPTracker::reset;
254 
255   MachineBasicBlock::const_iterator getNext() const { return NextMI; }
256 
257   /// \p return MaxPressure and clear it.
258   GCNRegPressure moveMaxPressure() {
259     auto Res = MaxPressure;
260     MaxPressure.clear();
261     return Res;
262   }
263 
264   /// Reset tracker to the point before the \p MI
265   /// filling \p LiveRegs upon this point using LIS.
266   /// \p returns false if block is empty except debug values.
267   bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
268 
269   /// Move to the state right before the next MI or after the end of MBB.
270   /// \p returns false if reached end of the block.
271   /// If \p UseInternalIterator is true, then internal iterators are used and
272   /// set to process in program order. If \p UseInternalIterator is false, then
273   /// it is assumed that the tracker is using an externally managed iterator,
274   /// and advance* calls will not update the state of the iterator. In such
275   /// cases, the tracker will move to the state right before the provided \p MI
276   /// and use LIS for RP calculations.
277   bool advanceBeforeNext(MachineInstr *MI = nullptr,
278                          bool UseInternalIterator = true);
279 
280   /// Move to the state at the MI, advanceBeforeNext has to be called first.
281   /// If \p UseInternalIterator is true, then internal iterators are used and
282   /// set to process in program order. If \p UseInternalIterator is false, then
283   /// it is assumed that the tracker is using an externally managed iterator,
284   /// and advance* calls will not update the state of the iterator. In such
285   /// cases, the tracker will move to the state at the provided \p MI .
286   void advanceToNext(MachineInstr *MI = nullptr,
287                      bool UseInternalIterator = true);
288 
289   /// Move to the state at the next MI. \p returns false if reached end of
290   /// block. If \p UseInternalIterator is true, then internal iterators are used
291   /// and set to process in program order. If \p UseInternalIterator is false,
292   /// then it is assumed that the tracker is using an externally managed
293   /// iterator, and advance* calls will not update the state of the iterator. In
294   /// such cases, the tracker will move to the state right before the provided
295   /// \p MI and use LIS for RP calculations.
296   bool advance(MachineInstr *MI = nullptr, bool UseInternalIterator = true);
297 
298   /// Advance instructions until before \p End.
299   bool advance(MachineBasicBlock::const_iterator End);
300 
301   /// Reset to \p Begin and advance to \p End.
302   bool advance(MachineBasicBlock::const_iterator Begin,
303                MachineBasicBlock::const_iterator End,
304                const LiveRegSet *LiveRegsCopy = nullptr);
305 
306   /// Mostly copy/paste from CodeGen/RegisterPressure.cpp
307   /// Calculate the impact \p MI will have on CurPressure and \return the
308   /// speculated pressure. In order to support RP Speculation, this does not
309   /// rely on the implicit program ordering in the LiveIntervals.
310   GCNRegPressure bumpDownwardPressure(const MachineInstr *MI,
311                                       const SIRegisterInfo *TRI) const;
312 };
313 
314 /// \returns the LaneMask of live lanes of \p Reg at position \p SI. Only the
315 /// active lanes of \p LaneMaskFilter will be set in the return value. This is
316 /// used, for example, to limit the live lanes to a specific subreg when
317 /// calculating use masks.
318 LaneBitmask getLiveLaneMask(unsigned Reg, SlotIndex SI,
319                             const LiveIntervals &LIS,
320                             const MachineRegisterInfo &MRI,
321                             LaneBitmask LaneMaskFilter = LaneBitmask::getAll());
322 
323 LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
324                             const MachineRegisterInfo &MRI,
325                             LaneBitmask LaneMaskFilter = LaneBitmask::getAll());
326 
327 GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
328                                      const MachineRegisterInfo &MRI);
329 
330 /// creates a map MachineInstr -> LiveRegSet
331 /// R - range of iterators on instructions
332 /// After - upon entry or exit of every instruction
333 /// Note: there is no entry in the map for instructions with empty live reg set
334 /// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
335 template <typename Range>
336 DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
337 getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
338   std::vector<SlotIndex> Indexes;
339   Indexes.reserve(std::distance(R.begin(), R.end()));
340   auto &SII = *LIS.getSlotIndexes();
341   for (MachineInstr *I : R) {
342     auto SI = SII.getInstructionIndex(*I);
343     Indexes.push_back(After ? SI.getDeadSlot() : SI.getBaseIndex());
344   }
345   llvm::sort(Indexes);
346 
347   auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
348   DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
349   SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
350   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
351     auto Reg = Register::index2VirtReg(I);
352     if (!LIS.hasInterval(Reg))
353       continue;
354     auto &LI = LIS.getInterval(Reg);
355     LiveIdxs.clear();
356     if (!LI.findIndexesLiveAt(Indexes, std::back_inserter(LiveIdxs)))
357       continue;
358     if (!LI.hasSubRanges()) {
359       for (auto SI : LiveIdxs)
360         LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] =
361           MRI.getMaxLaneMaskForVReg(Reg);
362     } else
363       for (const auto &S : LI.subranges()) {
364         // constrain search for subranges by indexes live at main range
365         SRLiveIdxs.clear();
366         S.findIndexesLiveAt(LiveIdxs, std::back_inserter(SRLiveIdxs));
367         for (auto SI : SRLiveIdxs)
368           LiveRegMap[SII.getInstructionFromIndex(SI)][Reg] |= S.LaneMask;
369       }
370   }
371   return LiveRegMap;
372 }
373 
374 inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
375                                                  const LiveIntervals &LIS) {
376   return getLiveRegs(LIS.getInstructionIndex(MI).getDeadSlot(), LIS,
377                      MI.getParent()->getParent()->getRegInfo());
378 }
379 
380 inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
381                                                   const LiveIntervals &LIS) {
382   return getLiveRegs(LIS.getInstructionIndex(MI).getBaseIndex(), LIS,
383                      MI.getParent()->getParent()->getRegInfo());
384 }
385 
386 template <typename Range>
387 GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
388                               Range &&LiveRegs) {
389   GCNRegPressure Res;
390   for (const auto &RM : LiveRegs)
391     Res.inc(RM.first, LaneBitmask::getNone(), RM.second, MRI);
392   return Res;
393 }
394 
395 bool isEqual(const GCNRPTracker::LiveRegSet &S1,
396              const GCNRPTracker::LiveRegSet &S2);
397 
398 Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
399 
400 Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
401                 const MachineRegisterInfo &MRI);
402 
403 Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
404                          const GCNRPTracker::LiveRegSet &TrackedL,
405                          const TargetRegisterInfo *TRI, StringRef Pfx = "  ");
406 
407 struct GCNRegPressurePrinter : public MachineFunctionPass {
408   static char ID;
409 
410 public:
411   GCNRegPressurePrinter() : MachineFunctionPass(ID) {}
412 
413   bool runOnMachineFunction(MachineFunction &MF) override;
414 
415   void getAnalysisUsage(AnalysisUsage &AU) const override {
416     AU.addRequired<LiveIntervalsWrapperPass>();
417     AU.setPreservesAll();
418     MachineFunctionPass::getAnalysisUsage(AU);
419   }
420 };
421 
422 } // end namespace llvm
423 
424 #endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
425