1 //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file This file defines a set of schedule DAG mutations that can be used to
10 // override default scheduler behavior to enforce specific scheduling patterns.
11 // They should be used in cases where runtime performance considerations, such
12 // as inter-wavefront interactions, mean that compile-time heuristics cannot
13 // predict the optimal instruction ordering, or in kernels where optimal
14 // instruction scheduling is important enough to warrant manual intervention.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "AMDGPUIGroupLP.h"
19 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
20 #include "SIInstrInfo.h"
21 #include "SIMachineFunctionInfo.h"
22 #include "llvm/ADT/BitmaskEnum.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/CodeGen/MachineScheduler.h"
25 #include "llvm/CodeGen/TargetOpcodes.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "igrouplp"
30 
31 namespace {
32 
33 static cl::opt<bool> EnableExactSolver(
34     "amdgpu-igrouplp-exact-solver", cl::Hidden,
35     cl::desc("Whether to use the exponential time solver to fit "
36              "the instructions to the pipeline as closely as "
37              "possible."),
38     cl::init(false));
39 
40 static cl::opt<unsigned> CutoffForExact(
41     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
42     cl::desc("The maximum number of scheduling group conflicts "
43              "which we attempt to solve with the exponential time "
44              "exact solver. Problem sizes greater than this will "
45              "be solved by the less accurate greedy algorithm. Selecting "
46              "the solver by size is superseded by manually selecting "
47              "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));
48 
49 static cl::opt<uint64_t> MaxBranchesExplored(
50     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
51     cl::desc("The maximum number of branches that we are willing to explore "
52              "with the exact algorithm before giving up."));
53 
54 static cl::opt<bool> UseCostHeur(
55     "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
56     cl::desc("Whether to use the cost heuristic to make choices as we "
57              "traverse the search space using the exact solver. Defaulted "
58              "to on, and if turned off, we will use the node order -- "
59              "attempting to put the later nodes in the later sched groups. "
60              "Experimentally, results are mixed, so this should be set on a "
61              "case-by-case basis."));
62 
63 // Components of the mask that determines which instruction types may be
64 // classified into a SchedGroup.
65 enum class SchedGroupMask {
66   NONE = 0u,
67   ALU = 1u << 0,
68   VALU = 1u << 1,
69   SALU = 1u << 2,
70   MFMA = 1u << 3,
71   VMEM = 1u << 4,
72   VMEM_READ = 1u << 5,
73   VMEM_WRITE = 1u << 6,
74   DS = 1u << 7,
75   DS_READ = 1u << 8,
76   DS_WRITE = 1u << 9,
77   TRANS = 1u << 10,
78   ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
79         DS_READ | DS_WRITE | TRANS,
80   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
81 };
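// For illustration only (not part of the implementation): these bits mirror
// the mask operand of the sched_group_barrier builtin, so e.g. a group of
// eight DS_READ instructions in sync group 0 could be requested from source
// with
//   __builtin_amdgcn_sched_group_barrier(/*Mask=*/0x100, /*Size=*/8, /*SyncID=*/0);
// where 0x100 corresponds to DS_READ (1u << 8) above.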
82 
83 class SchedGroup;
84 
85 // The InstructionRule class implements a filter which determines whether or
86 // not an SU maps to a given SchedGroup. It contains complementary data
87 // structures (e.g. the Cache) to support those filters.
88 class InstructionRule {
89 protected:
90   const SIInstrInfo *TII;
91   unsigned SGID;
92   // A cache made available to the Filter to store SUnits for subsequent
93   // invocations of the Filter
94   std::optional<SmallVector<SUnit *, 4>> Cache;
95 
96 public:
97   virtual bool
98   apply(const SUnit *, const ArrayRef<SUnit *>,
99         SmallVectorImpl<SchedGroup> &) {
100     return true;
101   };
102 
103   InstructionRule(const SIInstrInfo *TII, unsigned SGID,
104                   bool NeedsCache = false)
105       : TII(TII), SGID(SGID) {
106     if (NeedsCache) {
107       Cache = SmallVector<SUnit *, 4>();
108     }
109   }
110 
111   virtual ~InstructionRule() = default;
112 };
113 
114 using SUnitsToCandidateSGsMap = DenseMap<SUnit *, SmallVector<int, 4>>;
115 
116 // Classify instructions into groups to enable fine-tuned control over the
117 // scheduler. These groups may be more specific than current SchedModel
118 // instruction classes.
119 class SchedGroup {
120 private:
121   // Mask that defines which instruction types can be classified into this
122   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
123   // and SCHED_GROUP_BARRIER.
124   SchedGroupMask SGMask;
125 
126   // Maximum number of SUnits that can be added to this group.
127   std::optional<unsigned> MaxSize;
128 
129   // SchedGroups will only synchronize with other SchedGroups that have the same
130   // SyncID.
131   int SyncID = 0;
132 
133   // SGID is used to map instructions to candidate SchedGroups
134   unsigned SGID;
135 
136   // The different rules each instruction in this SchedGroup must conform to
137   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
138 
139   // Count of the number of created SchedGroups, used to initialize SGID.
140   static unsigned NumSchedGroups;
141 
142   // Try to add an edge from SU A to SU B.
143   bool tryAddEdge(SUnit *A, SUnit *B);
144 
145   // Use SGMask to determine whether we can classify MI as a member of this
146   // SchedGroup object.
147   bool canAddMI(const MachineInstr &MI) const;
148 
149 public:
150   // Collection of SUnits that are classified as members of this group.
151   SmallVector<SUnit *, 32> Collection;
152 
153   ScheduleDAGInstrs *DAG;
154   const SIInstrInfo *TII;
155 
156   // Returns true if SU can be added to this SchedGroup.
157   bool canAddSU(SUnit &SU) const;
158 
159   // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
160   // MakePred is true, SU will be a predecessor of the SUnits in this
161   // SchedGroup, otherwise SU will be a successor.
162   void link(SUnit &SU, bool MakePred = false);
163 
164   // Add DAG dependencies, tracking which edges are added, and return the count
165   // of missed edges
166   int link(SUnit &SU, bool MakePred,
167            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
168 
169   // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
170   // Use the predicate to determine whether SU should be a predecessor (P =
171   // true) or a successor (P = false) of this SchedGroup.
172   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
173 
174   // Add DAG dependencies such that SUnits in this group shall be ordered
175   // before SUnits in OtherGroup.
176   void link(SchedGroup &OtherGroup);
177 
178   // Returns true if no more instructions may be added to this group.
179   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
180 
181   // Append a constraint that SUs must meet in order to fit into this
182   // SchedGroup. Since many rules involve the relationship between a SchedGroup
183   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
184   // time (rather than at SchedGroup initialization time).
185   void addRule(std::shared_ptr<InstructionRule> NewRule) {
186     Rules.push_back(NewRule);
187   }
188 
189   // Returns true if the SU matches all rules
190   bool allowedByRules(const SUnit *SU,
191                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
192     for (auto &Rule : Rules) {
193       if (!Rule->apply(SU, Collection, SyncPipe))
194         return false;
195     }
196     return true;
197   }
198 
199   // Add SU to the SchedGroup.
200   void add(SUnit &SU) {
201     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
202                       << format_hex((int)SGMask, 10, true) << " adding "
203                       << *SU.getInstr());
204     Collection.push_back(&SU);
205   }
206 
207   // Remove last element in the SchedGroup
208   void pop() { Collection.pop_back(); }
209 
210   // Identify and add all relevant SUs from the DAG to this SchedGroup.
211   void initSchedGroup();
212 
213   // Add instructions to the SchedGroup bottom up starting from RIter.
214   // PipelineInstrs is a set of instructions that should not be added to the
215   // SchedGroup even when the other conditions for adding it are satisfied.
216   // RIter will be added to the SchedGroup as well, and dependencies will be
217   // added so that RIter will always be scheduled at the end of the group.
218   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
219                       SUnitsToCandidateSGsMap &SyncedInstrs);
220 
221   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
222 
223   int getSyncID() { return SyncID; }
224 
225   int getSGID() { return SGID; }
226 
227   SchedGroupMask getMask() { return SGMask; }
228 
229   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
230              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
231       : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
232     SGID = NumSchedGroups++;
233   }
234 
235   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
236              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
237       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
238     SGID = NumSchedGroups++;
239   }
240 };
241 
242 using SUToCandSGsPair = std::pair<SUnit *, SmallVector<int, 4>>;
243 using SUsToCandSGsVec = SmallVector<SUToCandSGsPair, 4>;
244 
245 // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
246 // in non-trivial cases. For example, if the requested pipeline is
247 // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
248 // in the DAG, then we will have an instruction that cannot be trivially
249 // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
250 // to find a good solution to the pipeline -- a greedy algorithm and an exact
251 // algorithm. The exact algorithm has an exponential time complexity and should
252 // only be used for small problems, or for medium-sized problems where an exact
253 // solution is highly desired.
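//
// For illustration only (assuming the builtin spelling below), the example
// pipeline above could be requested from source as:
//   __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM_READ
//   __builtin_amdgcn_sched_group_barrier(0x002, 1, 0); // 1 VALU
//   __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 MFMA
//   __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM_READ
// Each VMEM_READ in the DAG is then a candidate for both the first and the
// last SchedGroup, and the solver picks the assignment that misses the fewest
// pipeline edges.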
254 class PipelineSolver {
255   [[maybe_unused]] ScheduleDAGMI *DAG;
256 
257   // Instructions that can be assigned to multiple SchedGroups
258   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
259   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
260   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
261   // The current working pipeline
262   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
263   // The pipeline that has the best solution found so far
264   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
265 
266   // Whether or not we actually have any SyncedInstrs to try to solve.
267   bool NeedsSolver = false;
268 
269   // Compute an estimate of the size of the search tree -- the true size is
270   // the product of each conflictedInst.Matches.size() across all SyncPipelines
271   unsigned computeProblemSize();
272 
273   // The cost penalty of not assigning a SU to a SchedGroup
274   int MissPenalty = 0;
275 
276   // Costs in terms of the number of edges we are unable to add
277   int BestCost = -1;
278   int CurrCost = 0;
279 
280   // Index pointing to the conflicting instruction that is currently being
281   // fitted
282   int CurrConflInstNo = 0;
283   // Index to the pipeline that is currently being fitted
284   int CurrSyncGroupIdx = 0;
285   // The first non-trivial pipeline
286   int BeginSyncGroupIdx = 0;
287 
288   // How many branches we have explored
289   uint64_t BranchesExplored = 0;
290 
291   // The direction in which we process the candidate SchedGroups per SU
292   bool IsBottomUp = true;
293 
294   // Update indices to fit next conflicting instruction
295   void advancePosition();
296   // Recede indices to attempt to find better fit for previous conflicting
297   // instruction
298   void retreatPosition();
299 
300   // The exponential time algorithm which finds the provably best fit
301   bool solveExact();
302   // The polynomial time algorithm which attempts to find a good fit
303   bool solveGreedy();
304   // Find the best SchedGroup for the current SU using the heuristic given all
305   // current information. One step in the greedy algorithm. Templated against
306   // the SchedGroup iterator (either reverse or forward).
307   template <typename T>
308   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
309                   T E);
310   // Whether or not the current solution is optimal
311   bool checkOptimal();
312   // Populate the ready list, prioritizing fewest missed edges first
313   // Templated against the SchedGroup iterator (either reverse or forward).
314   template <typename T>
315   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
316                          T E);
317   // Add edges corresponding to the SchedGroups as assigned by solver
318   void makePipeline();
319   // Link the SchedGroups in the best found pipeline.
320   // Templated against the SchedGroup iterator (either reverse or forward).
321   template <typename T> void linkSchedGroups(T I, T E);
322   // Add the edges from the SU to the other SchedGroups in pipeline, and
323   // return the number of edges missed.
324   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
325                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
326   /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
327   /// returns the cost (in terms of missed pipeline edges), and tracks the edges
328   /// added in \p AddedEdges
329   template <typename T>
330   int linkSUnit(SUnit *SU, int SGID,
331                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
332   /// Remove the edges passed via \p AddedEdges
333   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
334   // Convert the passed in maps to arrays for bidirectional iterators
335   void convertSyncMapsToArrays();
336 
337   void reset();
338 
339 public:
340   // Invoke the solver to map instructions to instruction groups. A heuristic and
341   // command-line options determine whether to use the exact or greedy algorithm.
342   void solve();
343 
344   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
345                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
346                  ScheduleDAGMI *DAG, bool IsBottomUp = true)
347       : DAG(DAG), SyncedInstrs(SyncedInstrs),
348         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
349 
350     for (auto &PipelineInstrs : SyncedInstrs) {
351       if (PipelineInstrs.second.size() > 0) {
352         NeedsSolver = true;
353         break;
354       }
355     }
356 
357     if (!NeedsSolver)
358       return;
359 
360     convertSyncMapsToArrays();
361 
362     CurrPipeline = BestPipeline;
363 
364     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
365            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
366       ++BeginSyncGroupIdx;
367 
368     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
369       return;
370   }
371 };
372 
373 void PipelineSolver::reset() {
374 
375   for (auto &SyncPipeline : CurrPipeline) {
376     for (auto &SG : SyncPipeline) {
377       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
378       SG.Collection.clear();
379       auto *SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
380         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
381       });
382       if (SchedBarr != TempCollection.end())
383         SG.Collection.push_back(*SchedBarr);
384     }
385   }
386 
387   CurrSyncGroupIdx = BeginSyncGroupIdx;
388   CurrConflInstNo = 0;
389   CurrCost = 0;
390 }
391 
392 void PipelineSolver::convertSyncMapsToArrays() {
393   for (auto &SyncPipe : SyncedSchedGroups) {
394     BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
395   }
396 
397   int PipelineIDx = SyncedInstrs.size() - 1;
398   PipelineInstrs.resize(SyncedInstrs.size());
399   for (auto &SyncInstrMap : SyncedInstrs) {
400     for (auto &SUsToCandSGs : SyncInstrMap.second) {
401       if (PipelineInstrs[PipelineIDx].size() == 0) {
402         PipelineInstrs[PipelineIDx].push_back(
403             std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
404         continue;
405       }
406       auto *SortPosition = PipelineInstrs[PipelineIDx].begin();
407       // Insert them in sorted order -- this allows for good parsing order in
408       // the greedy algorithm
409       while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
410              SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
411         ++SortPosition;
412       PipelineInstrs[PipelineIDx].insert(
413           SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
414     }
415     --PipelineIDx;
416   }
417 }
418 
419 template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
420   for (; I != E; ++I) {
421     auto &GroupA = *I;
422     for (auto J = std::next(I); J != E; ++J) {
423       auto &GroupB = *J;
424       GroupA.link(GroupB);
425     }
426   }
427 }
428 
429 void PipelineSolver::makePipeline() {
430   // Preserve the order of barriers for subsequent SchedGroupBarrier mutations
431   for (auto &SyncPipeline : BestPipeline) {
432     LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
433     for (auto &SG : SyncPipeline) {
434       LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
435                         << " has: \n");
436       SUnit *SGBarr = nullptr;
437       for (auto &SU : SG.Collection) {
438         if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
439           SGBarr = SU;
440         LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
441       }
442       // Command-line-requested IGroupLP doesn't have a SGBarr
443       if (!SGBarr)
444         continue;
445       SG.link(*SGBarr, false);
446     }
447   }
448 
449   for (auto &SyncPipeline : BestPipeline) {
450     IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
451                : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
452   }
453 }
454 
455 template <typename T>
456 int PipelineSolver::linkSUnit(
457     SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
458     T I, T E) {
459   bool MakePred = false;
460   int AddedCost = 0;
461   for (; I < E; ++I) {
462     if (I->getSGID() == SGID) {
463       MakePred = true;
464       continue;
465     }
466     auto Group = *I;
467     AddedCost += Group.link(*SU, MakePred, AddedEdges);
468     assert(AddedCost >= 0);
469   }
470   return AddedCost;
471 }
472 
473 int PipelineSolver::addEdges(
474     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
475     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
476 
477   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
478   // instructions that are the ultimate successors in the resultant mutation.
479   // Therefore, in such a configuration, the SchedGroups occurring before the
480   // candidate SGID are successors of the candidate SchedGroup, thus the current
481   // SU should be linked as a predecessor to SUs in those SchedGroups. The
482   // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
483   // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
484   // IsBottomUp (in reverse).
485   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
486                                 SyncPipeline.rend())
487                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
488                                 SyncPipeline.end());
489 }
490 
491 void PipelineSolver::removeEdges(
492     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
493   // Only remove the edges that we have added when testing
494   // the fit.
495   for (auto &PredSuccPair : EdgesToRemove) {
496     SUnit *Pred = PredSuccPair.first;
497     SUnit *Succ = PredSuccPair.second;
498 
499     auto *Match = llvm::find_if(
500         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
501     if (Match != Succ->Preds.end()) {
502       assert(Match->isArtificial());
503       Succ->removePred(*Match);
504     }
505   }
506 }
507 
508 void PipelineSolver::advancePosition() {
509   ++CurrConflInstNo;
510 
511   if (static_cast<size_t>(CurrConflInstNo) >=
512       PipelineInstrs[CurrSyncGroupIdx].size()) {
513     CurrConflInstNo = 0;
514     ++CurrSyncGroupIdx;
515     // Advance to next non-trivial pipeline
516     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
517            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
518       ++CurrSyncGroupIdx;
519   }
520 }
521 
522 void PipelineSolver::retreatPosition() {
523   assert(CurrConflInstNo >= 0);
524   assert(CurrSyncGroupIdx >= 0);
525 
526   if (CurrConflInstNo > 0) {
527     --CurrConflInstNo;
528     return;
529   }
530 
531   if (CurrConflInstNo == 0) {
532     // If we return to the starting position, we have explored
533     // the entire tree
534     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
535       return;
536 
537     --CurrSyncGroupIdx;
538     // Go to previous non-trivial pipeline
539     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
540       --CurrSyncGroupIdx;
541 
542     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
543   }
544 }
545 
546 bool PipelineSolver::checkOptimal() {
547   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
548     if (BestCost == -1 || CurrCost < BestCost) {
549       BestPipeline = CurrPipeline;
550       BestCost = CurrCost;
551       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
552     }
553     assert(BestCost >= 0);
554   }
555 
556   bool DoneExploring = false;
557   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
558     DoneExploring = true;
559 
560   return (DoneExploring || BestCost == 0);
561 }
562 
563 template <typename T>
564 void PipelineSolver::populateReadyList(
565     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
566   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
567   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
568   assert(CurrSU.second.size() >= 1);
569 
570   for (; I != E; ++I) {
571     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
572     int CandSGID = *I;
573     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
574       return SG.getSGID() == CandSGID;
575     });
576     assert(Match);
577 
578     if (UseCostHeur) {
579       if (Match->isFull()) {
580         ReadyList.push_back(std::pair(*I, MissPenalty));
581         continue;
582       }
583 
584       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
585       ReadyList.push_back(std::pair(*I, TempCost));
586       removeEdges(AddedEdges);
587     } else
588       ReadyList.push_back(std::pair(*I, -1));
589   }
590 
591   if (UseCostHeur) {
592     std::sort(ReadyList.begin(), ReadyList.end(),
593               [](std::pair<int, int> A, std::pair<int, int> B) {
594                 return A.second < B.second;
595               });
596   }
597 
598   assert(ReadyList.size() == CurrSU.second.size());
599 }
600 
601 bool PipelineSolver::solveExact() {
602   if (checkOptimal())
603     return true;
604 
605   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
606     return false;
607 
608   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
609   assert(static_cast<size_t>(CurrConflInstNo) <
610          PipelineInstrs[CurrSyncGroupIdx].size());
611   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
612   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
613                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
614 
615   // SchedGroup -> Cost pairs
616   SmallVector<std::pair<int, int>, 4> ReadyList;
617   // Prioritize the candidate sched groups in terms of lowest cost first
618   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
619                                  CurrSU.second.rend())
620              : populateReadyList(ReadyList, CurrSU.second.begin(),
621                                  CurrSU.second.end());
622 
623   auto *I = ReadyList.begin();
624   auto *E = ReadyList.end();
625   for (; I != E; ++I) {
626     // If we are trying SGs in least cost order, and the current SG is cost
627     // infeasible, then all subsequent SGs will also be cost infeasible, so we
628     // can prune.
629     if (BestCost != -1 && (CurrCost + I->second > BestCost))
630       return false;
631 
632     int CandSGID = I->first;
633     int AddedCost = 0;
634     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
635     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
636     SchedGroup *Match;
637     for (auto &SG : SyncPipeline) {
638       if (SG.getSGID() == CandSGID)
639         Match = &SG;
640     }
641 
642     if (Match->isFull())
643       continue;
644 
645     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
646       continue;
647 
648     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
649                       << (int)Match->getMask() << " and ID " << CandSGID
650                       << "\n");
651     Match->add(*CurrSU.first);
652     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
653     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
654     CurrCost += AddedCost;
655     advancePosition();
656     ++BranchesExplored;
657     bool FinishedExploring = false;
658     // If the Cost after adding edges is greater than a known solution,
659     // backtrack
660     if (CurrCost < BestCost || BestCost == -1) {
661       if (solveExact()) {
662         FinishedExploring = BestCost != 0;
663         if (!FinishedExploring)
664           return true;
665       }
666     }
667 
668     retreatPosition();
669     CurrCost -= AddedCost;
670     removeEdges(AddedEdges);
671     Match->pop();
672     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
673     if (FinishedExploring)
674       return true;
675   }
676 
677   // Try the pipeline where the current instruction is omitted
678   // Potentially if we omit a problematic instruction from the pipeline,
679   // all the other instructions can nicely fit.
680   CurrCost += MissPenalty;
681   advancePosition();
682 
683   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
684 
685   bool FinishedExploring = false;
686   if (CurrCost < BestCost || BestCost == -1) {
687     if (solveExact()) {
688       FinishedExploring = BestCost != 0;
689       if (!FinishedExploring)
690         return true;
691     }
692   }
693 
694   retreatPosition();
695   CurrCost -= MissPenalty;
696   return FinishedExploring;
697 }
698 
699 template <typename T>
700 void PipelineSolver::greedyFind(
701     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
702   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
703   int BestNodeCost = -1;
704   int TempCost;
705   SchedGroup *BestGroup = nullptr;
706   int BestGroupID = -1;
707   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
708   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
709                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
710 
711   // Since we have added the potential SchedGroups from bottom up, but
712   // traversed the DAG from top down, parse over the groups from last to
713   // first. If we fail to do this for the greedy algorithm, the solution will
714   // likely not be good in more complex cases.
715   for (; I != E; ++I) {
716     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
717     int CandSGID = *I;
718     SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
719       return SG.getSGID() == CandSGID;
720     });
721     assert(Match);
722 
723     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
724                       << (int)Match->getMask() << "\n");
725 
726     if (Match->isFull()) {
727       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
728       continue;
729     }
730     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
731       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
732       continue;
733     }
734     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
735     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
736     if (TempCost < BestNodeCost || BestNodeCost == -1) {
737       BestGroup = Match;
738       BestNodeCost = TempCost;
739       BestGroupID = CandSGID;
740     }
741     removeEdges(AddedEdges);
742     if (BestNodeCost == 0)
743       break;
744   }
745 
746   if (BestGroupID != -1) {
747     BestGroup->add(*CurrSU.first);
748     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
749     LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
750                       << (int)BestGroup->getMask() << "\n");
751     BestCost += TempCost;
752   } else
753     BestCost += MissPenalty;
754 
755   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
756 }
757 
758 bool PipelineSolver::solveGreedy() {
759   BestCost = 0;
760   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
761 
762   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
763     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
764     IsBottomUp
765         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
766         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
767     advancePosition();
768   }
769   BestPipeline = CurrPipeline;
770   removeEdges(AddedEdges);
771   return false;
772 }
773 
774 unsigned PipelineSolver::computeProblemSize() {
775   unsigned ProblemSize = 0;
776   for (auto &PipeConflicts : PipelineInstrs) {
777     ProblemSize += PipeConflicts.size();
778   }
779 
780   return ProblemSize;
781 }
782 
783 void PipelineSolver::solve() {
784   if (!NeedsSolver)
785     return;
786 
787   unsigned ProblemSize = computeProblemSize();
788   assert(ProblemSize > 0);
789 
790   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
791   MissPenalty = (ProblemSize / 2) + 1;
792 
793   LLVM_DEBUG(DAG->dump());
794   if (EnableExactSolver || BelowCutoff) {
795     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
796     solveGreedy();
797     reset();
798     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
799     if (BestCost > 0) {
800       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
801       solveExact();
802       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
803     }
804   } else { // Use the Greedy Algorithm by default
805     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
806     solveGreedy();
807   }
808 
809   makePipeline();
810   LLVM_DEBUG(dbgs() << "After applying mutation\n");
811   LLVM_DEBUG(DAG->dump());
812 }
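// For illustration only (hypothetical invocation): passing
// -mllvm -amdgpu-igrouplp-exact-solver-cutoff=10 on a DAG with 8 conflicting
// instructions gives MissPenalty = 8 / 2 + 1 = 5; the greedy pass seeds
// BestCost, and the exact solver only runs afterwards if that cost is still
// nonzero.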
813 
814 enum IGLPStrategyID : int {
815   MFMASmallGemmOptID = 0,
816   MFMASmallGemmSingleWaveOptID = 1,
817   MFMAExpInterleaveID = 2,
818   MFMAExpSimpleInterleaveID = 3
819 };
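// For illustration only (not part of the implementation): these IDs correspond
// to the immediate operand of the llvm.amdgcn.iglp.opt intrinsic, so e.g.
//   __builtin_amdgcn_iglp_opt(/*StrategyID=*/1);
// in kernel source requests MFMASmallGemmSingleWaveOptID for the surrounding
// scheduling region.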
820 
821 // Implement an IGLP scheduling strategy.
822 class IGLPStrategy {
823 protected:
824   ScheduleDAGInstrs *DAG;
825 
826   const SIInstrInfo *TII;
827 
828 public:
829   /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
830   virtual bool applyIGLPStrategy(
831       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
832       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
833       AMDGPU::SchedulingPhase Phase) = 0;
834 
835   // Returns true if this strategy should be applied to a ScheduleDAG.
836   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
837                                    AMDGPU::SchedulingPhase Phase) = 0;
838 
839   bool IsBottomUp = true;
840 
841   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
842       : DAG(DAG), TII(TII) {}
843 
844   virtual ~IGLPStrategy() = default;
845 };
846 
847 class MFMASmallGemmOpt final : public IGLPStrategy {
848 private:
849 public:
850   bool applyIGLPStrategy(
851       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
852       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
853       AMDGPU::SchedulingPhase Phase) override;
854 
855   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
856                            AMDGPU::SchedulingPhase Phase) override {
857     return true;
858   }
859 
860   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
861       : IGLPStrategy(DAG, TII) {
862     IsBottomUp = true;
863   }
864 };
865 
866 bool MFMASmallGemmOpt::applyIGLPStrategy(
867     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
868     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
869     AMDGPU::SchedulingPhase Phase) {
870   // Count the number of MFMA instructions.
871   unsigned MFMACount = 0;
872   for (const MachineInstr &I : *DAG)
873     if (TII->isMFMAorWMMA(I))
874       ++MFMACount;
875 
876   const unsigned PipelineSyncID = 0;
877   SchedGroup *SG = nullptr;
878   for (unsigned I = 0; I < MFMACount * 3; ++I) {
879     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
880         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
881     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
882 
883     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
884         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
885     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
886   }
887 
888   return true;
889 }
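// For illustration: with MFMACount == 2, the loop above requests, in sync
// group 0, the pipeline DS(2), MFMA(1) repeated six times, i.e. pairs of LDS
// instructions interleaved with single MFMAs.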
890 
891 class MFMAExpInterleaveOpt final : public IGLPStrategy {
892 private:
893   // The count of TRANS SUs involved in the interleaved pipeline
894   static unsigned TransPipeCount;
895   // The count of MFMA SUs involved in the interleaved pipeline
896   static unsigned MFMAPipeCount;
897   // The count of Add SUs involved in the interleaved pipeline
898   static unsigned AddPipeCount;
899   // The number of transitive MFMA successors for each TRANS SU
900   static unsigned MFMAEnablement;
901   // The number of transitive TRANS predecessors for each MFMA SU
902   static unsigned ExpRequirement;
903   // The count of independent "chains" of MFMA instructions in the pipeline
904   static unsigned MFMAChains;
905   // The length of each independent "chain" of MFMA instructions
906   static unsigned MFMAChainLength;
907   // Whether or not the pipeline has V_CVT instructions
908   static bool HasCvt;
909   // Whether or not there are instructions between the TRANS instruction and
910   // V_CVT
911   static bool HasChainBetweenCvt;
912   // The first occurring DS_READ which feeds an MFMA chain
913   static std::optional<unsigned> FirstPipeDSR;
914   // The MFMAPipe SUs with no MFMA predecessors
915   SmallVector<SUnit *, 4> MFMAChainSeeds;
916   // Compute the heuristics for the pipeline, returning whether or not the DAG
917   // is well suited to the mutation
918   bool analyzeDAG(const SIInstrInfo *TII);
919 
920   /// Whether or not the instruction is a transitive predecessor of an MFMA
921   /// instruction
922   class IsPipeExp final : public InstructionRule {
923   public:
924     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
925                SmallVectorImpl<SchedGroup> &SyncPipe) override {
926 
927       auto *DAG = SyncPipe[0].DAG;
928 
929       if (Cache->empty()) {
930         auto I = DAG->SUnits.rbegin();
931         auto E = DAG->SUnits.rend();
932         for (; I != E; I++) {
933           if (TII->isMFMAorWMMA(*I->getInstr()))
934             Cache->push_back(&*I);
935         }
936         if (Cache->empty())
937           return false;
938       }
939 
940       auto Reaches = any_of(*Cache, [&SU, &DAG](SUnit *TargetSU) {
941         return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
942       });
943 
944       return Reaches;
945     }
946     IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
947         : InstructionRule(TII, SGID, NeedsCache) {}
948   };
949 
950   /// Whether or not the instruction is a transitive predecessor of the
951   /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
952   class EnablesNthMFMA final : public InstructionRule {
953   private:
954     unsigned Number = 1;
955 
956   public:
957     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
958                SmallVectorImpl<SchedGroup> &SyncPipe) override {
959       bool FoundTrans = false;
960       unsigned Counter = 1;
961       auto *DAG = SyncPipe[0].DAG;
962 
963       if (Cache->empty()) {
964         SmallVector<SUnit *, 8> Worklist;
965 
966         auto I = DAG->SUnits.begin();
967         auto E = DAG->SUnits.end();
968         for (; I != E; I++) {
969           if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
970             if (Counter == Number) {
971               Cache->push_back(&*I);
972               break;
973             }
974             ++Counter;
975           }
976           if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
977             FoundTrans = true;
978         }
979         if (Cache->empty())
980           return false;
981       }
982 
983       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
984     }
985 
986     EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
987                    bool NeedsCache = false)
988         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
989   };
990 
991   /// Whether or not the instruction enables the exact MFMA that is the \p
992   /// Number th MFMA in the chain starting with \p ChainSeed
993   class EnablesNthMFMAInChain final : public InstructionRule {
994   private:
995     unsigned Number = 1;
996     SUnit *ChainSeed;
997 
998   public:
999     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1000                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1001       auto *DAG = SyncPipe[0].DAG;
1002 
1003       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1004         return false;
1005 
1006       if (Cache->empty()) {
1007         auto *TempSU = ChainSeed;
1008         auto Depth = Number;
1009         while (Depth > 0) {
1010           --Depth;
1011           bool Found = false;
1012           for (auto &Succ : TempSU->Succs) {
1013             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1014               TempSU = Succ.getSUnit();
1015               Found = true;
1016               break;
1017             }
1018           }
1019           if (!Found)
1020             return false;
1021         }
1022 
1023         Cache->push_back(TempSU);
1024       }
1025       // If we failed to find the instruction to be placed into the cache, we
1026       // would have already exited.
1027       assert(!Cache->empty());
1028 
1029       return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
1030     }
1031 
1032     EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
1033                           const SIInstrInfo *TII, unsigned SGID,
1034                           bool NeedsCache = false)
1035         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1036           ChainSeed(ChainSeed) {}
1037   };
1038 
1039   /// Whether or not the instruction has fewer than \p Size immediate successors.
1040   /// If \p HasIntermediary is true, this also tests whether all successors of
1041   /// the SUnit have fewer than \p Size successors.
1042   class LessThanNSuccs final : public InstructionRule {
1043   private:
1044     unsigned Size = 1;
1045     bool HasIntermediary = false;
1046 
1047   public:
1048     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1049                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1050       if (!SyncPipe.size())
1051         return false;
1052 
1053       auto SuccSize = std::count_if(
1054           SU->Succs.begin(), SU->Succs.end(),
1055           [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1056       if (SuccSize >= Size)
1057         return false;
1058 
1059       if (HasIntermediary) {
1060         for (auto Succ : SU->Succs) {
1061           auto SuccSize = std::count_if(
1062               Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1063               [](const SDep &SuccSucc) {
1064                 return SuccSucc.getKind() == SDep::Data;
1065               });
1066           if (SuccSize >= Size)
1067             return false;
1068         }
1069       }
1070 
1071       return true;
1072     }
1073     LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
1074                    bool HasIntermediary = false, bool NeedsCache = false)
1075         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1076           HasIntermediary(HasIntermediary) {}
1077   };
1078 
1079   /// Whether or not the instruction has greater than or equal to \p Size
1080   /// immediate successors. If \p HasIntermediary is true, this also tests
1081   /// whether any successor of the SUnit has greater than or equal to \p Size
1082   /// successors.
1083   class GreaterThanOrEqualToNSuccs final : public InstructionRule {
1084   private:
1085     unsigned Size = 1;
1086     bool HasIntermediary = false;
1087 
1088   public:
1089     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1090                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1091       if (!SyncPipe.size())
1092         return false;
1093 
1094       auto SuccSize = std::count_if(
1095           SU->Succs.begin(), SU->Succs.end(),
1096           [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
1097       if (SuccSize >= Size)
1098         return true;
1099 
1100       if (HasIntermediary) {
1101         for (auto Succ : SU->Succs) {
1102           auto SuccSize = std::count_if(
1103               Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
1104               [](const SDep &SuccSucc) {
1105                 return SuccSucc.getKind() == SDep::Data;
1106               });
1107           if (SuccSize >= Size)
1108             return true;
1109         }
1110       }
1111 
1112       return false;
1113     }
1114     GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
1115                                unsigned SGID, bool HasIntermediary = false,
1116                                bool NeedsCache = false)
1117         : InstructionRule(TII, SGID, NeedsCache), Size(Size),
1118           HasIntermediary(HasIntermediary) {}
1119   };
1120 
1121   // Whether or not the instruction is a relevant V_CVT instruction.
1122   class IsCvt final : public InstructionRule {
1123   public:
1124     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1125                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1126       auto Opc = SU->getInstr()->getOpcode();
1127       return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
1128              Opc == AMDGPU::V_CVT_I32_F32_e32;
1129     }
1130     IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1131         : InstructionRule(TII, SGID, NeedsCache) {}
1132   };
1133 
1134   // Whether or not the instruction is FMA_F32.
1135   class IsFMA final : public InstructionRule {
1136   public:
1137     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1138                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1139       return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
1140              SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
1141     }
1142     IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1143         : InstructionRule(TII, SGID, NeedsCache) {}
1144   };
1145 
1146   // Whether or not the instruction is a V_ADD_F32 instruction.
1147   class IsPipeAdd final : public InstructionRule {
1148   public:
1149     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1150                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1151       return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
1152     }
1153     IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1154         : InstructionRule(TII, SGID, NeedsCache) {}
1155   };
1156 
1157   /// Whether or not the instruction is an immediate RAW successor
1158   /// of the SchedGroup \p Distance steps before.
1159   class IsSuccOfPrevNthGroup final : public InstructionRule {
1160   private:
1161     unsigned Distance = 1;
1162 
1163   public:
1164     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1165                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1166       SchedGroup *OtherGroup = nullptr;
1167       if (!SyncPipe.size())
1168         return false;
1169 
1170       for (auto &PipeSG : SyncPipe) {
1171         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1172           OtherGroup = &PipeSG;
1173       }
1174 
1175       if (!OtherGroup)
1176         return false;
1177       if (!OtherGroup->Collection.size())
1178         return true;
1179 
1180       for (auto &OtherEle : OtherGroup->Collection) {
1181         for (auto &Succ : OtherEle->Succs) {
1182           if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
1183             return true;
1184         }
1185       }
1186 
1187       return false;
1188     }
1189     IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1190                          unsigned SGID, bool NeedsCache = false)
1191         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1192   };
1193 
1194   /// Whether or not the instruction is a transitive successor of any
1195   /// instruction in the SchedGroup \p Distance steps before.
1196   class IsReachableFromPrevNthGroup final : public InstructionRule {
1197   private:
1198     unsigned Distance = 1;
1199 
1200   public:
1201     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1202                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1203       SchedGroup *OtherGroup = nullptr;
1204       if (!SyncPipe.size())
1205         return false;
1206 
1207       for (auto &PipeSG : SyncPipe) {
1208         if ((unsigned)PipeSG.getSGID() == SGID - Distance)
1209           OtherGroup = &PipeSG;
1210       }
1211 
1212       if (!OtherGroup)
1213         return false;
1214       if (!OtherGroup->Collection.size())
1215         return true;
1216 
1217       auto *DAG = SyncPipe[0].DAG;
1218 
1219       for (auto &OtherEle : OtherGroup->Collection)
1220         if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
1221           return true;
1222 
1223       return false;
1224     }
1225     IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1226                                 unsigned SGID, bool NeedsCache = false)
1227         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1228   };
1229 
1230   /// Whether or not the instruction occurs at or after the SU with NodeNum \p Number
1231   class OccursAtOrAfterNode final : public InstructionRule {
1232   private:
1233     unsigned Number = 1;
1234 
1235   public:
1236     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1237                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1238 
1239       return SU->NodeNum >= Number;
1240     }
1241     OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
1242                         bool NeedsCache = false)
1243         : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
1244   };
1245 
1246   /// Whether or not the SU is exactly the \p Number th MFMA in the chain
1247   /// starting with \p ChainSeed
1248   class IsExactMFMA final : public InstructionRule {
1249   private:
1250     unsigned Number = 1;
1251     SUnit *ChainSeed;
1252 
1253   public:
1254     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1255                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1256       if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
1257         return false;
1258 
1259       if (Cache->empty()) {
1260         auto *TempSU = ChainSeed;
1261         auto Depth = Number;
1262         while (Depth > 0) {
1263           --Depth;
1264           bool Found = false;
1265           for (auto &Succ : TempSU->Succs) {
1266             if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
1267               TempSU = Succ.getSUnit();
1268               Found = true;
1269               break;
1270             }
1271           }
1272           if (!Found) {
1273             return false;
1274           }
1275         }
1276         Cache->push_back(TempSU);
1277       }
1278       // If we failed to find the instruction to be placed into the cache, we
1279       // would have already exited.
1280       assert(!Cache->empty());
1281 
1282       return (*Cache)[0] == SU;
1283     }
1284 
1285     IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
1286                 unsigned SGID, bool NeedsCache = false)
1287         : InstructionRule(TII, SGID, NeedsCache), Number(Number),
1288           ChainSeed(ChainSeed) {}
1289   };
1290 
1291   // Whether the instruction occurs after the first TRANS instruction. This
1292   // implies the instruction cannot be a predecessor of the first TRANS
1293   // instruction.
1294   class OccursAfterExp final : public InstructionRule {
1295   public:
1296     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1297                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1298 
1299       SmallVector<SUnit *, 12> Worklist;
1300       auto *DAG = SyncPipe[0].DAG;
1301       if (Cache->empty()) {
1302         for (auto &SU : DAG->SUnits)
1303           if (TII->isTRANS(SU.getInstr()->getOpcode())) {
1304             Cache->push_back(&SU);
1305             break;
1306           }
1307         if (Cache->empty())
1308           return false;
1309       }
1310 
1311       return SU->NodeNum > (*Cache)[0]->NodeNum;
1312     }
1313 
1314     OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
1315                    bool NeedsCache = false)
1316         : InstructionRule(TII, SGID, NeedsCache) {}
1317   };
1318 
1319 public:
1320   bool applyIGLPStrategy(
1321       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1322       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1323       AMDGPU::SchedulingPhase Phase) override;
1324 
1325   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1326                            AMDGPU::SchedulingPhase Phase) override;
1327 
1328   MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1329       : IGLPStrategy(DAG, TII) {
1330     IsBottomUp = false;
1331   }
1332 };
1333 
1334 unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
1335 unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
1336 unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
1337 unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
1338 unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
1339 unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
1340 unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
1341 bool MFMAExpInterleaveOpt::HasCvt = false;
1342 bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
1343 std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;
1344 
1345 bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
1346   SmallVector<SUnit *, 10> ExpPipeCands;
1347   SmallVector<SUnit *, 10> MFMAPipeCands;
1348   SmallVector<SUnit *, 10> MFMAPipeSUs;
1349   SmallVector<SUnit *, 10> PackSUs;
1350   SmallVector<SUnit *, 10> CvtSUs;
1351 
1352   auto isBitPack = [](unsigned Opc) {
1353     return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
1354   };
1355 
1356   auto isCvt = [](unsigned Opc) {
1357     return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
1358   };
1359 
1360   auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };
1361 
1362   AddPipeCount = 0;
1363   for (SUnit &SU : DAG->SUnits) {
1364     auto Opc = SU.getInstr()->getOpcode();
1365     if (TII->isTRANS(Opc)) {
1366       // Avoid counting a potential bonus V_EXP which all the MFMAs depend on
1367       if (SU.Succs.size() >= 7)
1368         continue;
1369       for (auto &Succ : SU.Succs) {
1370         if (Succ.getSUnit()->Succs.size() >= 7)
1371           continue;
1372       }
1373       ExpPipeCands.push_back(&SU);
1374     }
1375 
1376     if (TII->isMFMAorWMMA(*SU.getInstr()))
1377       MFMAPipeCands.push_back(&SU);
1378 
1379     if (isBitPack(Opc))
1380       PackSUs.push_back(&SU);
1381 
1382     if (isCvt(Opc))
1383       CvtSUs.push_back(&SU);
1384 
1385     if (isAdd(Opc))
1386       ++AddPipeCount;
1387   }
1388 
1389   if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
1390     return false;
1391 
1392   TransPipeCount = 0;
1393 
1394   std::optional<SUnit *> TempMFMA;
1395   std::optional<SUnit *> TempExp;
1396   // Count the number of EXPs that reach an MFMA
1397   for (auto &PredSU : ExpPipeCands) {
1398     for (auto &SuccSU : MFMAPipeCands) {
1399       if (DAG->IsReachable(SuccSU, PredSU)) {
1400         if (!TempExp) {
1401           TempExp = PredSU;
1402           TempMFMA = SuccSU;
1403         }
1404         MFMAPipeSUs.push_back(SuccSU);
1405         ++TransPipeCount;
1406         break;
1407       }
1408     }
1409   }
1410 
1411   if (!(TempExp && TempMFMA))
1412     return false;
1413 
1414   HasChainBetweenCvt = none_of((*TempExp)->Succs, [&isCvt](SDep &Succ) {
1415     return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
1416   });
1417 
1418   // Count the number of MFMAs that are reached by an EXP
1419   for (auto &SuccSU : MFMAPipeCands) {
1420     if (MFMAPipeSUs.size() &&
1421         any_of(MFMAPipeSUs, [&SuccSU](SUnit *PotentialMatch) {
1422           return PotentialMatch->NodeNum == SuccSU->NodeNum;
1423         }))
1424       continue;
1425 
1426     for (auto &PredSU : ExpPipeCands) {
1427       if (DAG->IsReachable(SuccSU, PredSU)) {
1428         MFMAPipeSUs.push_back(SuccSU);
1429         break;
1430       }
1431     }
1432   }
1433 
1434   MFMAPipeCount = MFMAPipeSUs.size();
1435 
1436   assert(TempExp && TempMFMA);
1437   assert(MFMAPipeCount > 0);
1438 
1439   std::optional<SUnit *> TempCvt;
1440   for (auto &SuccSU : CvtSUs) {
1441     if (DAG->IsReachable(SuccSU, *TempExp)) {
1442       TempCvt = SuccSU;
1443       break;
1444     }
1445   }
1446 
1447   HasCvt = false;
1448   if (TempCvt.has_value()) {
1449     for (auto &SuccSU : MFMAPipeSUs) {
1450       if (DAG->IsReachable(SuccSU, *TempCvt)) {
1451         HasCvt = true;
1452         break;
1453       }
1454     }
1455   }
1456 
1457   MFMAChains = 0;
1458   for (auto &MFMAPipeSU : MFMAPipeSUs) {
1459     if (is_contained(MFMAChainSeeds, MFMAPipeSU))
1460       continue;
1461     if (none_of(MFMAPipeSU->Preds, [&TII](SDep &Pred) {
1462           return TII->isMFMAorWMMA(*Pred.getSUnit()->getInstr());
1463         })) {
1464       MFMAChainSeeds.push_back(MFMAPipeSU);
1465       ++MFMAChains;
1466     }
1467   }
1468 
1469   if (!MFMAChains)
1470     return false;
1471 
1472   for (auto Pred : MFMAChainSeeds[0]->Preds) {
1473     if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1474         Pred.getSUnit()->getInstr()->mayLoad())
1475       FirstPipeDSR = Pred.getSUnit()->NodeNum;
1476   }
1477 
1478   MFMAChainLength = MFMAPipeCount / MFMAChains;
1479 
1480   // The number of bit pack operations that depend on a single V_EXP
1481   unsigned PackSuccCount = std::count_if(
1482       PackSUs.begin(), PackSUs.end(), [this, &TempExp](SUnit *VPack) {
1483         return DAG->IsReachable(VPack, *TempExp);
1484       });
1485 
1486   // The number of bit pack operations an MFMA depends on
1487   unsigned PackPredCount =
1488       std::count_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1489                     [&isBitPack](SDep &Pred) {
1490                       auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1491                       return isBitPack(Opc);
1492                     });
1493 
1494   auto *PackPred =
1495       std::find_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1496                    [&isBitPack](SDep &Pred) {
1497                      auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1498                      return isBitPack(Opc);
1499                    });
1500 
1501   if (PackPred == (*TempMFMA)->Preds.end())
1502     return false;
1503 
1504   MFMAEnablement = 0;
1505   ExpRequirement = 0;
1506   // How many MFMAs depend on a single bit pack operation
1507   MFMAEnablement =
1508       std::count_if(PackPred->getSUnit()->Succs.begin(),
1509                     PackPred->getSUnit()->Succs.end(), [&TII](SDep &Succ) {
1510                       return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1511                     });
1512 
1513   // The number of MFMAs that depend on a single V_EXP
1514   MFMAEnablement *= PackSuccCount;
1515 
1516   // The number of V_EXPs required to resolve all dependencies for an MFMA
1517   ExpRequirement =
1518       std::count_if(ExpPipeCands.begin(), ExpPipeCands.end(),
1519                     [this, &PackPred](SUnit *ExpBase) {
1520                       return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1521                     });
1522 
1523   ExpRequirement *= PackPredCount;
1524   return true;
1525 }
1526 
1527 bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1528                                                AMDGPU::SchedulingPhase Phase) {
1529   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1530   const SIInstrInfo *TII = ST.getInstrInfo();
1531 
1532   if (Phase != AMDGPU::SchedulingPhase::PostRA)
1533     MFMAChainSeeds.clear();
1534   if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1535     return false;
1536 
1537   return true;
1538 }
1539 
1540 bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1541     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1542     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1543     AMDGPU::SchedulingPhase Phase) {
1544 
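       // Only two pipeline shapes are handled, recognized by the EXP/MFMA
       // counts gathered by analyzeDAG(); anything else is left to the default
       // scheduler.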
1545   bool IsSmallKernelType =
1546       MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1547   bool IsLargeKernelType =
1548       MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1549 
1550   if (!(IsSmallKernelType || IsLargeKernelType))
1551     return false;
1552 
1553   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1554   const SIInstrInfo *TII = ST.getInstrInfo();
1555 
1556   unsigned PipelineSyncID = 0;
1557   SchedGroup *SG = nullptr;
1558 
1559   unsigned MFMAChain = 0;
1560   unsigned PositionInChain = 0;
1561   unsigned CurrMFMAForTransPosition = 0;
1562 
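       // Helpers that translate the running count of pipelined MFMAs into a
       // (chain, position-in-chain) pair, used to select which MFMA a given
       // TRANS/FMA group should enable.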
1563   auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1564                                  &CurrMFMAForTransPosition]() {
1565     CurrMFMAForTransPosition += MFMAEnablement;
1566     PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1567     MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1568   };
1569 
1570   auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1571     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1572     return (TempMFMAForTrans / MFMAChains);
1573   };
1574 
1575   auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1576     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1577     return TempMFMAForTrans % MFMAChains;
1578   };
1579 
1580   unsigned CurrMFMAPosition = 0;
1581   unsigned MFMAChainForMFMA = 0;
1582   unsigned PositionInChainForMFMA = 0;
1583 
1584   auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1585                                 &PositionInChainForMFMA]() {
1586     ++CurrMFMAPosition;
1587     MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1588     PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1589   };
1590 
1591   bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1592   assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1593 
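       // Which auxiliary instruction classes take part in the pipeline for
       // this kernel shape and scheduling phase.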
1594   bool UsesFMA = IsSmallKernelType || !IsPostRA;
1595   bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1596   bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1597   bool UsesVALU = IsSmallKernelType;
1598 
1599   // PHASE 1: "Prefetch"
1600   if (UsesFMA) {
1601     // First Round FMA
1602     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1603         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1604     if (!IsPostRA && MFMAChains) {
1605       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1606           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1607           true));
1608     } else
1609       SG->addRule(
1610           std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1611     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1612     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1613 
1614     // Second Round FMA
1615     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1616         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1617     if (!IsPostRA && MFMAChains) {
1618       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1619           getNextTransPositionInChain(),
1620           MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1621     } else
1622       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1623                                                    SG->getSGID(), true));
1624     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1625     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1626   }
1627 
1628   if (UsesDSRead) {
1629     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1630         SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1631     SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1632                                                       SG->getSGID()));
1633     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1634   }
1635 
1636   // First Round EXP
1637   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1638       SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1639   if (!IsPostRA && MFMAChains)
1640     SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1641         PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1642   else
1643     SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1644   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1645   SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1646                                                HasChainBetweenCvt));
1647   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1648 
1649   incrementTransPosition();
1650 
1651   // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1652   for (unsigned I = 0; I < ExpRequirement; I++) {
1653     // First Round CVT
1654     if (UsesCvt) {
1655       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1656           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1657       SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1658       if (HasChainBetweenCvt)
1659         SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1660             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1661       else
1662         SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1663             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1664       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1665     }
1666 
1667     // Third Round FMA
1668     if (UsesFMA) {
1669       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1670           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1671       if (!IsPostRA && MFMAChains) {
1672         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1673             getNextTransPositionInChain(),
1674             MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1675       } else
1676         SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1677                                                      TII, SG->getSGID(), true));
1678       SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1679       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1680     }
1681 
1682     // Second Round EXP
1683     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1684         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1685     if (!IsPostRA && MFMAChains)
1686       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1687           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1688           true));
1689     else
1690       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1691                                                    SG->getSGID(), true));
1692     SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1693     SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1694                                                  HasChainBetweenCvt));
1695     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1696   }
1697 
1698   // The "extra" EXP which enables all MFMA
1699   // TODO: UsesExtraExp
1700   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1701       SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1702   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1703   SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1704       8, TII, SG->getSGID(), HasChainBetweenCvt));
1705   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1706 
1707   // PHASE 2: Main Interleave Loop
1708 
1709   // The number of MFMAs per iteration
1710   unsigned MFMARatio =
1711       MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1712   // The number of Exps per iteration
1713   unsigned ExpRatio =
1714       MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1715   // The remaining Exps
1716   unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1717                               ? TransPipeCount - (2 * ExpRequirement)
1718                               : 0;
1719   unsigned ExpLoopCount = RemainingExp / ExpRatio;
1720   // In-loop MFMAs
1721   unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1722                             ? MFMAPipeCount - (MFMAEnablement * 2)
1723                             : 0;
1724   unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1725   unsigned VALUOps =
1726       AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1727   unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
1728 
1729   for (unsigned I = 0; I < LoopSize; I++) {
1730     if (!(I * ExpRatio % ExpRequirement))
1731       incrementTransPosition();
1732 
1733     // Round N MFMA
1734     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1735         SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1736     if (!IsPostRA && MFMAChains)
1737       SG->addRule(std::make_shared<IsExactMFMA>(
1738           PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1739           SG->getSGID(), true));
1740     else
1741       SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1742     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1743     incrementMFMAPosition();
1744 
1745     if (UsesVALU) {
1746       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1747           SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1748       SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1749       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1750     }
1751 
1752     if (UsesDSRead && !(I % 4)) {
1753       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1754           SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1755       SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1756                                                         SG->getSGID()));
1757       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1758     }
1759 
1760     // CVT, EXP, FMA Interleaving
1761     for (unsigned J = 0; J < ExpRatio; J++) {
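           // The offsets below are measured in SchedGroups and determine how
           // far back the Prev-Nth-Group rules for this round's CVT look.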
1762       auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1763       auto MaxMFMAOffset =
1764           (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1765 
1766       // Round N + 1 CVT
1767       if (UsesCvt) {
1768         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1769             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1770         SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1771         auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1772         auto DSROffset = I / 4 + 1;
1773         auto MaxDSROffset = MaxMFMAOffset / 4;
1774         // TODO: UsesExtraExp
1775         auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1776         auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1777                              std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1778                              ExpOffset;
1779         if (HasChainBetweenCvt)
1780           SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1781               CurrentOffset, TII, SG->getSGID()));
1782         else
1783           SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1784                                                              SG->getSGID()));
1785         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1786       }
1787 
1788       // Round N + 3 FMA
1789       if (UsesFMA) {
1790         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1791             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1792         if (!IsPostRA && MFMAChains)
1793           SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1794               getNextTransPositionInChain(),
1795               MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1796               true));
1797         else
1798           SG->addRule(std::make_shared<EnablesNthMFMA>(
1799               (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1800               TII, SG->getSGID(), true));
1801         SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1802         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1803       }
1804 
1805       // Round N + 2 Exp
1806       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1807           SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1808       if (!IsPostRA && MFMAChains)
1809         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1810             PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1811             true));
1812       else
1813         SG->addRule(std::make_shared<EnablesNthMFMA>(
1814             (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1815             TII, SG->getSGID(), true));
1816       SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1817       SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1818                                                    HasChainBetweenCvt));
1819       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1820     }
1821   }
1822 
1823   // PHASE 3: Remaining MFMAs
1824   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1825       SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1826   SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1827   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1828   return true;
1829 }
1830 
1831 class MFMAExpSimpleInterleaveOpt final : public IGLPStrategy {
1832 public:
1833   bool applyIGLPStrategy(
1834       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1835       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1836       AMDGPU::SchedulingPhase Phase) override;
1837 
1838   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1839                            AMDGPU::SchedulingPhase Phase) override {
1840     return true;
1841   }
1842 
1843   MFMAExpSimpleInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1844       : IGLPStrategy(DAG, TII) {
1845     IsBottomUp = true;
1846   }
1847 };
1848 
1849 bool MFMAExpSimpleInterleaveOpt::applyIGLPStrategy(
1850     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1851     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1852     AMDGPU::SchedulingPhase Phase) {
1853   // Count the number of MFMA instructions.
1854   unsigned MFMACount = 0;
1855   for (const MachineInstr &I : *DAG)
1856     if (TII->isMFMAorWMMA(I))
1857       ++MFMACount;
1858 
1859   const unsigned PipelineSyncID = 0;
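       // Emit MFMACount * 3 alternating TRANS/MFMA groups of size one each, so
       // that single TRANS ops are interleaved between the MFMAs.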
1860   for (unsigned I = 0; I < MFMACount * 3; ++I) {
1861     SchedGroup *SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1862         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1863     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1864 
1865     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1866         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1867     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1868   }
1869 
1870   return true;
1871 }
1872 
1873 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1874 private:
1875   // Whether the DS_READ is a predecessor of the region's first four MFMAs
1876   class EnablesInitialMFMA final : public InstructionRule {
1877   public:
1878     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1879                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1880       if (!SyncPipe.size())
1881         return false;
1882       int MFMAsFound = 0;
1883       if (!Cache->size()) {
1884         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1885           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
1886             ++MFMAsFound;
1887             if (MFMAsFound > 4)
1888               break;
1889             Cache->push_back(&Elt);
1890           }
1891         }
1892       }
1893 
1894       assert(Cache->size());
1895       auto *DAG = SyncPipe[0].DAG;
1896       for (auto &Elt : *Cache) {
1897         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
1898           return true;
1899       }
1900       return false;
1901     }
1902 
1903     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1904                        bool NeedsCache = false)
1905         : InstructionRule(TII, SGID, NeedsCache) {}
1906   };
1907 
1908   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1909   class IsPermForDSW final : public InstructionRule {
1910   public:
1911     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1912                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1913       auto *MI = SU->getInstr();
1914       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1915         return false;
1916 
1917       bool FitsInGroup = false;
1918       // Does the VALU have a DS_WRITE successor
1919       if (!Collection.size()) {
1920         for (auto &Succ : SU->Succs) {
1921           SUnit *SuccUnit = Succ.getSUnit();
1922           if (TII->isDS(*SuccUnit->getInstr()) &&
1923               SuccUnit->getInstr()->mayStore()) {
1924             Cache->push_back(SuccUnit);
1925             FitsInGroup = true;
1926           }
1927         }
1928         return FitsInGroup;
1929       }
1930 
1931       assert(Cache->size());
1932 
1933       // Does the VALU have a DS_WRITE successor that is the same as another
1934       // VALU already in the group? The V_PERMs all share one DS_WRITE succ.
1935       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
1936         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
1937           return ThisSucc.getSUnit() == Elt;
1938         });
1939       });
1940     }
1941 
1942     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1943         : InstructionRule(TII, SGID, NeedsCache) {}
1944   };
1945 
1946   // Whether the SU is a successor of any element in the previous SchedGroup
1947   class IsSuccOfPrevGroup final : public InstructionRule {
1948   public:
1949     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1950                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1951       SchedGroup *OtherGroup = nullptr;
1952       for (auto &PipeSG : SyncPipe) {
1953         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1954           OtherGroup = &PipeSG;
1955         }
1956       }
1957 
1958       if (!OtherGroup)
1959         return false;
1960       if (!OtherGroup->Collection.size())
1961         return true;
1962 
1963       // Does the previous VALU have this DS_WRITE as a successor?
1964       return any_of(OtherGroup->Collection, [&SU](SUnit *Elt) {
1965         return any_of(Elt->Succs,
1966                       [&SU](SDep &Succ) { return Succ.getSUnit() == SU; });
1967       });
1968     }
1969     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1970                       bool NeedsCache = false)
1971         : InstructionRule(TII, SGID, NeedsCache) {}
1972   };
1973 
1974   // Whether the combined load width of the group stays at or below 128 bits
1975   class VMEMSize final : public InstructionRule {
1976   public:
1977     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1978                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1979       auto *MI = SU->getInstr();
1980       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1981         return false;
1982       if (!Collection.size())
1983         return true;
1984 
1985       int NumBits = 0;
1986 
1987       auto TRI = TII->getRegisterInfo();
1988       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1989       for (auto &Elt : Collection) {
1990         auto Op = Elt->getInstr()->getOperand(0);
1991         auto Size =
1992             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1993         NumBits += Size;
1994       }
1995 
1996       if (NumBits < 128) {
1997         assert(TII->isVMEM(*MI) && MI->mayLoad());
1998         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1999                           MRI, MI->getOperand(0))) <=
2000             128)
2001           return true;
2002       }
2003 
2004       return false;
2005     }
2006 
2007     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
2008         : InstructionRule(TII, SGID, NeedsCache) {}
2009   };
2010 
2011   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
2012   /// that is \p Distance steps away
2013   class SharesPredWithPrevNthGroup final : public InstructionRule {
2014   private:
2015     unsigned Distance = 1;
2016 
2017   public:
2018     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2019                SmallVectorImpl<SchedGroup> &SyncPipe) override {
2020       SchedGroup *OtherGroup = nullptr;
2021       if (!SyncPipe.size())
2022         return false;
2023 
2024       if (!Cache->size()) {
2025 
2026         for (auto &PipeSG : SyncPipe) {
2027           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2028             OtherGroup = &PipeSG;
2029           }
2030         }
2031 
2032         if (!OtherGroup)
2033           return false;
2034         if (!OtherGroup->Collection.size())
2035           return true;
2036 
2037         for (auto &OtherEle : OtherGroup->Collection) {
2038           for (auto &Pred : OtherEle->Preds) {
2039             if (Pred.getSUnit()->getInstr()->getOpcode() ==
2040                 AMDGPU::V_PERM_B32_e64)
2041               Cache->push_back(Pred.getSUnit());
2042           }
2043         }
2044 
2045         // If the other group has no PERM preds, then this group won't share any
2046         if (!Cache->size())
2047           return false;
2048       }
2049 
2050       auto *DAG = SyncPipe[0].DAG;
2051       // Does the previous DS_WRITE share a V_PERM predecessor with this
2052       // VMEM_READ
2053       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2054         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2055       });
2056     }
2057     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2058                                unsigned SGID, bool NeedsCache = false)
2059         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2060   };
2061 
2062 public:
2063   bool applyIGLPStrategy(
2064       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2065       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2066       AMDGPU::SchedulingPhase Phase) override;
2067 
2068   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2069                            AMDGPU::SchedulingPhase Phase) override {
2070     return true;
2071   }
2072 
2073   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2074       : IGLPStrategy(DAG, TII) {
2075     IsBottomUp = false;
2076   }
2077 };
2078 
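     // These counters are only updated during the Initial scheduling phase;
     // later phases reuse the values computed there.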
2079 static unsigned DSWCount = 0;
2080 static unsigned DSWWithPermCount = 0;
2081 static unsigned DSWWithSharedVMEMCount = 0;
2082 
2083 bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2084     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2085     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2086     AMDGPU::SchedulingPhase Phase) {
2087   unsigned MFMACount = 0;
2088   unsigned DSRCount = 0;
2089 
2090   bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2091 
2092   assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2093                          DSWWithSharedVMEMCount == 0)) &&
2094          "DSWCounters should be zero in pre-RA scheduling!");
2095   SmallVector<SUnit *, 6> DSWithPerms;
2096   for (auto &SU : DAG->SUnits) {
2097     auto *I = SU.getInstr();
2098     if (TII->isMFMAorWMMA(*I))
2099       ++MFMACount;
2100     else if (TII->isDS(*I)) {
2101       if (I->mayLoad())
2102         ++DSRCount;
2103       else if (I->mayStore() && IsInitial) {
2104         ++DSWCount;
2105         for (auto Pred : SU.Preds) {
2106           if (Pred.getSUnit()->getInstr()->getOpcode() ==
2107               AMDGPU::V_PERM_B32_e64) {
2108             DSWithPerms.push_back(&SU);
2109             break;
2110           }
2111         }
2112       }
2113     }
2114   }
2115 
2116   if (IsInitial) {
2117     DSWWithPermCount = DSWithPerms.size();
2118     auto *I = DSWithPerms.begin();
2119     auto *E = DSWithPerms.end();
2120 
2121     // Get the count of DS_WRITEs with V_PERM predecessors which have
2122     // loop-carried (WAR) dependencies on the same VMEM_READs. Partial
2123     // overlap counts as a miss -- in other words, for a given DS_WRITE we
2124     // only consider another DS_WRITE as matching if, for every V_PERM pred
2125     // of this DS_WRITE, there is a corresponding V_PERM pred (in terms of
2126     // the VMEM_READ it uses) on the other DS_WRITE.
2127     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2128     SmallVector<SUnit *, 6> Counted;
2129     for (; I != E; I++) {
2130       SUnit *Cand = nullptr;
2131       bool MissedAny = false;
2132       for (auto &Pred : (*I)->Preds) {
2133         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2134           continue;
2135 
2136         if (Cand && llvm::is_contained(Counted, Cand))
2137           break;
2138 
2139         for (auto &Succ : Pred.getSUnit()->Succs) {
2140           auto *MI = Succ.getSUnit()->getInstr();
2141           if (!TII->isVMEM(*MI) || !MI->mayLoad())
2142             continue;
2143 
2144           if (MissedAny || !VMEMLookup.size()) {
2145             MissedAny = true;
2146             VMEMLookup[MI] = *I;
2147             continue;
2148           }
2149 
2150           auto [It, Inserted] = VMEMLookup.try_emplace(MI, *I);
2151           if (Inserted) {
2152             MissedAny = true;
2153             continue;
2154           }
2155 
2156           Cand = It->second;
2157           if (llvm::is_contained(Counted, Cand)) {
2158             MissedAny = true;
2159             break;
2160           }
2161         }
2162       }
2163       if (!MissedAny && Cand) {
2164         DSWWithSharedVMEMCount += 2;
2165         Counted.push_back(Cand);
2166         Counted.push_back(*I);
2167       }
2168     }
2169   }
2170 
2171   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2172   SchedGroup *SG;
2173   unsigned PipelineSyncID = 0;
2174   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
2175   if (DSWWithPermCount) {
2176     for (unsigned I = 0; I < MFMACount; I++) {
2177       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2178           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2179       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2180 
2181       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2182           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2183       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2184     }
2185   }
2186 
2187   PipelineSyncID = 1;
2188   // Phase 1: Break up DS_READ and MFMA clusters.
2189   // First issue DS_READs to make the initial MFMA ready, then interleave
2190   // MFMAs with DS_READ prefetches.
2191 
2192   // Make the initial MFMA ready
2193   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2194       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2195   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2196   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2197 
2198   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2199       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2200   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2201 
2202   // Interleave MFMA with DS_READ prefetch
2203   for (unsigned I = 0; I < DSRCount - 4; ++I) {
2204     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2205         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2206     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2207 
2208     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2209         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2210     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2211   }
2212 
2213   // Phase 2a: Loop-carried dependency with V_PERM
2214   // Schedule V_PERM & DS_WRITE as closely as possible to the VMEM_READ they
2215   // depend on. Interleave MFMA to keep the XDL unit busy throughout.
2216   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2217     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2218         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2219     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2220     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2221 
2222     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2223         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2224     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2225     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2226 
2227     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2228         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2229     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2230         1, TII, SG->getSGID(), true));
2231     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2232     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2233 
2234     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2235         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2236     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2237 
2238     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2239         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2240     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2241         3, TII, SG->getSGID(), true));
2242     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2243     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2244 
2245     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2246         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2247     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2248   }
2249 
2250   // Phase 2b: Loop-carried dependency without V_PERM
2251   // Schedule each DS_WRITE as closely as possible to the VMEM_READ it depends
2252   // on. Interleave MFMA to keep the XDL unit busy throughout.
2253   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2254     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2255         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2256     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2257 
2258     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2259         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2260     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2261     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2262 
2263     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2264         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2265     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2266   }
2267 
2268   // Phase 2c: Loop-carried dependency with V_PERM, where the VMEM_READs are
2269   // ultimately used by two DS_WRITEs.
2270   // Schedule V_PERM & DS_WRITE as closely as possible to the VMEM_READ they
2271   // depend on. Interleave MFMA to keep the XDL unit busy throughout.
2272 
2273   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2274     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2275         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2276     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2277     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2278 
2279     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2280         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2281     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2282     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2283 
2284     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2285         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2286     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2287 
2288     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2289         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2290     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2291     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2292 
2293     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2294         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2295     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2296     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2297 
2298     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2299         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2300     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2301 
2302     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2303         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2304     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2305         2, TII, SG->getSGID(), true));
2306     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2307     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2308 
2309     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2310         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2311     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2312 
2313     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2314         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2315     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2316         4, TII, SG->getSGID(), true));
2317     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2318     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2319 
2320     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2321         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2322     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2323   }
2324 
2325   return true;
2326 }
2327 
2328 static std::unique_ptr<IGLPStrategy>
2329 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2330                    const SIInstrInfo *TII) {
2331   switch (ID) {
2332   case MFMASmallGemmOptID:
2333     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2334   case MFMASmallGemmSingleWaveOptID:
2335     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2336   case MFMAExpInterleaveID:
2337     return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2338   case MFMAExpSimpleInterleaveID:
2339     return std::make_unique<MFMAExpSimpleInterleaveOpt>(DAG, TII);
2340   }
2341 
2342   llvm_unreachable("Unknown IGLPStrategyID");
2343 }
2344 
2345 class IGroupLPDAGMutation : public ScheduleDAGMutation {
2346 private:
2347   const SIInstrInfo *TII;
2348 
2349   ScheduleDAGMI *DAG;
2350 
2351   // Organize lists of SchedGroups by their SyncID. SchedGroups /
2352   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2353   // between them.
2354   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2355 
2356   // Used to track instructions that can be mapped to multiple sched groups
2357   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2358 
2359   // Add DAG edges that enforce SCHED_BARRIER ordering.
2360   void addSchedBarrierEdges(SUnit &SU);
2361 
2362   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2363   // not be reordered across the SCHED_BARRIER. This is used for the base
2364   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2365   // SCHED_BARRIER will always block all instructions that can be classified
2366   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2367   // and may only synchronize with some SchedGroups. Returns the inverse of
2368   // Mask. SCHED_BARRIER's mask describes which instruction types should be
2369   // allowed to be scheduled across it. Invert the mask to get the
2370   // SchedGroupMask of instructions that should be barred.
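       // For example, a SCHED_BARRIER mask that only allows VMEM (and thus the
       // implied VMEM_READ / VMEM_WRITE) to cross it inverts to a SchedGroup
       // holding ALU, VALU, SALU, MFMA, DS, DS_READ, DS_WRITE, and TRANS.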
2371   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2372 
2373   // Create SchedGroups for a SCHED_GROUP_BARRIER.
2374   void initSchedGroupBarrierPipelineStage(
2375       std::vector<SUnit>::reverse_iterator RIter);
2376 
2377   bool initIGLPOpt(SUnit &SU);
2378 
2379 public:
2380   void apply(ScheduleDAGInstrs *DAGInstrs) override;
2381 
2382   // The order in which the PipelineSolver should process the candidate
2383   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2384   // created SchedGroup first, and will consider that as the ultimate
2385   // predecessor group when linking. TOP_DOWN instead links and processes the
2386   // first created SchedGroup first.
2387   bool IsBottomUp = true;
2388 
2389   // The scheduling phase this application of IGLP corresponds with.
2390   AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2391 
2392   IGroupLPDAGMutation() = default;
2393   IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2394 };
2395 
2396 unsigned SchedGroup::NumSchedGroups = 0;
2397 
2398 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
2399   if (A != B && DAG->canAddEdge(B, A)) {
2400     DAG->addEdge(B, SDep(A, SDep::Artificial));
2401     return true;
2402   }
2403   return false;
2404 }
2405 
2406 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
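       // An instruction is accepted if it belongs to any class whose bit is
       // set in SGMask; meta instructions are never classified.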
2407   bool Result = false;
2408   if (MI.isMetaInstruction())
2409     Result = false;
2410 
2411   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2412            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2413             TII->isTRANS(MI)))
2414     Result = true;
2415 
2416   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2417            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2418     Result = true;
2419 
2420   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2421            TII->isSALU(MI))
2422     Result = true;
2423 
2424   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2425            TII->isMFMAorWMMA(MI))
2426     Result = true;
2427 
2428   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2429            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2430     Result = true;
2431 
2432   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2433            MI.mayLoad() &&
2434            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2435     Result = true;
2436 
2437   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2438            MI.mayStore() &&
2439            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2440     Result = true;
2441 
2442   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2443            TII->isDS(MI))
2444     Result = true;
2445 
2446   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2447            MI.mayLoad() && TII->isDS(MI))
2448     Result = true;
2449 
2450   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2451            MI.mayStore() && TII->isDS(MI))
2452     Result = true;
2453 
2454   else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2455            TII->isTRANS(MI))
2456     Result = true;
2457 
2458   LLVM_DEBUG(
2459       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2460              << (Result ? " could classify " : " unable to classify ") << MI);
2461 
2462   return Result;
2463 }
2464 
2465 int SchedGroup::link(SUnit &SU, bool MakePred,
2466                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2467   int MissedEdges = 0;
2468   for (auto *A : Collection) {
2469     SUnit *B = &SU;
2470     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2471       continue;
2472     if (MakePred)
2473       std::swap(A, B);
2474 
2475     if (DAG->IsReachable(B, A))
2476       continue;
2477 
2478     // tryAddEdge returns false if there is a dependency that makes adding
2479     // the A->B edge impossible; otherwise it returns true.
2480     bool Added = tryAddEdge(A, B);
2481     if (Added)
2482       AddedEdges.emplace_back(A, B);
2483     else
2484       ++MissedEdges;
2485   }
2486 
2487   return MissedEdges;
2488 }
2489 
2490 void SchedGroup::link(SUnit &SU, bool MakePred) {
2491   for (auto *A : Collection) {
2492     SUnit *B = &SU;
2493     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2494       continue;
2495     if (MakePred)
2496       std::swap(A, B);
2497 
2498     tryAddEdge(A, B);
2499   }
2500 }
2501 
2502 void SchedGroup::link(SUnit &SU,
2503                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2504   for (auto *A : Collection) {
2505     SUnit *B = &SU;
2506     if (P(A, B))
2507       std::swap(A, B);
2508 
2509     tryAddEdge(A, B);
2510   }
2511 }
2512 
2513 void SchedGroup::link(SchedGroup &OtherGroup) {
2514   for (auto *B : OtherGroup.Collection)
2515     link(*B);
2516 }
2517 
2518 bool SchedGroup::canAddSU(SUnit &SU) const {
2519   MachineInstr &MI = *SU.getInstr();
2520   if (MI.getOpcode() != TargetOpcode::BUNDLE)
2521     return canAddMI(MI);
2522 
2523   // Special case for bundled MIs.
2524   const MachineBasicBlock *MBB = MI.getParent();
2525   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2526   while (E != MBB->end() && E->isBundledWithPred())
2527     ++E;
2528 
2529   // Return true if all of the bundled MIs can be added to this group.
2530   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2531 }
2532 
2533 void SchedGroup::initSchedGroup() {
2534   for (auto &SU : DAG->SUnits) {
2535     if (isFull())
2536       break;
2537 
2538     if (canAddSU(SU))
2539       add(SU);
2540   }
2541 }
2542 
2543 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2544                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
2545   SUnit &InitSU = *RIter;
2546   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2547     auto &SU = *RIter;
2548     if (isFull())
2549       break;
2550 
2551     if (canAddSU(SU))
2552       SyncedInstrs[&SU].push_back(SGID);
2553   }
2554 
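       // The SU the group was seeded from (the SCHED_GROUP_BARRIER) is added
       // to the group itself, so grow MaxSize by one to keep room for the
       // requested number of instructions.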
2555   add(InitSU);
2556   assert(MaxSize);
2557   (*MaxSize)++;
2558 }
2559 
2560 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2561   auto I = DAG->SUnits.rbegin();
2562   auto E = DAG->SUnits.rend();
2563   for (; I != E; ++I) {
2564     auto &SU = *I;
2565     if (isFull())
2566       break;
2567     if (canAddSU(SU))
2568       SyncedInstrs[&SU].push_back(SGID);
2569   }
2570 }
2571 
2572 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2573   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2574   if (!TSchedModel || DAGInstrs->SUnits.empty())
2575     return;
2576 
2577   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2578   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2579   TII = ST.getInstrInfo();
2580   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2581   SyncedSchedGroups.clear();
2582   SyncedInstrs.clear();
2583   bool FoundSB = false;
2584   bool FoundIGLP = false;
2585   bool ShouldApplyIGLP = false;
2586   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2587     unsigned Opc = R->getInstr()->getOpcode();
2588     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2589     if (Opc == AMDGPU::SCHED_BARRIER) {
2590       addSchedBarrierEdges(*R);
2591       FoundSB = true;
2592     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2593       initSchedGroupBarrierPipelineStage(R);
2594       FoundSB = true;
2595     } else if (Opc == AMDGPU::IGLP_OPT) {
2596       if (!FoundSB && !FoundIGLP) {
2597         FoundIGLP = true;
2598         ShouldApplyIGLP = initIGLPOpt(*R);
2599       }
2600     }
2601   }
2602 
2603   if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2604     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2605     // PipelineSolver performs the mutation by adding the edges it
2606     // determined as the best
2607     PS.solve();
2608     return;
2609   }
2610 }
2611 
2612 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2613   MachineInstr &MI = *SchedBarrier.getInstr();
2614   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2615   // Build a SchedGroup from the instruction types that must not be reordered
2616   // across the SCHED_BARRIER, i.e. the inverse of the barrier's mask.
2617   LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2618                     << MI.getOperand(0).getImm() << "\n");
2619   auto InvertedMask =
2620       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2621   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2622   SG.initSchedGroup();
2623 
2624   // Preserve original instruction ordering relative to the SCHED_BARRIER.
2625   SG.link(
2626       SchedBarrier,
2627       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2628           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2629 }
2630 
2631 SchedGroupMask
2632 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2633   // Invert mask and erase bits for types of instructions that are implied to be
2634   // allowed past the SCHED_BARRIER.
2635   SchedGroupMask InvertedMask = ~Mask;
2636 
2637   // ALU implies VALU, SALU, MFMA, TRANS.
2638   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2639     InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2640                     ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2641   // VALU, SALU, MFMA, TRANS implies ALU.
2642   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2643            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2644            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2645            (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2646     InvertedMask &= ~SchedGroupMask::ALU;
2647 
2648   // VMEM implies VMEM_READ, VMEM_WRITE.
2649   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2650     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2651   // VMEM_READ, VMEM_WRITE implies VMEM.
2652   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2653            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2654     InvertedMask &= ~SchedGroupMask::VMEM;
2655 
2656   // DS implies DS_READ, DS_WRITE.
2657   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2658     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2659   // DS_READ, DS_WRITE implies DS.
2660   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2661            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2662     InvertedMask &= ~SchedGroupMask::DS;
2663 
2664   LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2665                     << "\n");
2666 
2667   return InvertedMask;
2668 }
2669 
2670 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2671     std::vector<SUnit>::reverse_iterator RIter) {
2672   // Create a SchedGroup from the SCHED_GROUP_BARRIER's mask, size, and sync
2673   // ID operands and record candidate instructions that may fill it.
2674   MachineInstr &SGB = *RIter->getInstr();
2675   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2676   int32_t SGMask = SGB.getOperand(0).getImm();
2677   int32_t Size = SGB.getOperand(1).getImm();
2678   int32_t SyncID = SGB.getOperand(2).getImm();
2679 
2680   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2681                                                     Size, SyncID, DAG, TII);
2682 
2683   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
2684 }
2685 
2686 bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2687   IGLPStrategyID StrategyID =
2688       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
2689   auto S = createIGLPStrategy(StrategyID, DAG, TII);
2690   if (!S->shouldApplyStrategy(DAG, Phase))
2691     return false;
2692 
2693   IsBottomUp = S->IsBottomUp;
2694   return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2695 }
2696 
2697 } // namespace
2698 
2699 namespace llvm {
2700 
2701 /// \p Phase specifies whether or not this is a reentry into the
2702 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2703 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
2704 /// scheduling "phases"), we can reenter this mutation framework more than once
2705 /// for a given region.
2706 std::unique_ptr<ScheduleDAGMutation>
2707 createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2708   return std::make_unique<IGroupLPDAGMutation>(Phase);
2709 }
2710 
2711 } // end namespace llvm
2712