//===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file This file defines a set of schedule DAG mutations that can be used to
// override default scheduler behavior to enforce specific scheduling patterns.
// They should be used in cases where runtime performance considerations, such
// as inter-wavefront interactions, mean that compile-time heuristics cannot
// predict the optimal instruction ordering, or in kernels where optimal
// instruction scheduling is important enough to warrant manual intervention.
//
//===----------------------------------------------------------------------===//

#include "AMDGPUIGroupLP.h"
#include "AMDGPUTargetMachine.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "SIInstrInfo.h"
#include "SIMachineFunctionInfo.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineScheduler.h"
#include "llvm/CodeGen/TargetOpcodes.h"

using namespace llvm;

#define DEBUG_TYPE "igrouplp"

namespace {

static cl::opt<bool> EnableExactSolver(
    "amdgpu-igrouplp-exact-solver", cl::Hidden,
    cl::desc("Whether to use the exponential time solver to fit "
             "the instructions to the pipeline as closely as "
             "possible."),
    cl::init(false));

static cl::opt<unsigned> CutoffForExact(
    "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
    cl::desc("The maximum number of scheduling group conflicts "
             "which we attempt to solve with the exponential time "
             "exact solver. Problem sizes greater than this will "
             "be solved by the less accurate greedy algorithm. Selecting "
             "solver by size is superseded by manually selecting "
             "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));

static cl::opt<uint64_t> MaxBranchesExplored(
    "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
    cl::desc("The number of branches that we are willing to explore with "
             "the exact algorithm before giving up."));

static cl::opt<bool> UseCostHeur(
    "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
    cl::desc("Whether to use the cost heuristic to make choices as we "
             "traverse the search space using the exact solver. Defaulted "
             "to on, and if turned off, we will use the node order -- "
             "attempting to put the later nodes in the later sched groups. "
             "Experimentally, results are mixed, so this should be set on a "
             "case-by-case basis."));

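// Illustrative sketch (not from this file; the invocation details are
// assumptions): the options above are ordinary llc flags, so forcing the
// exact solver with a branch budget might look like
//   llc -mtriple=amdgcn -amdgpu-igrouplp-exact-solver \
//       -amdgpu-igrouplp-exact-solver-max-branches=100000 kernel.ll
// where kernel.ll is a hypothetical input.
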
// Components of the mask that determines which instruction types may be
// classified into a SchedGroup.
enum class SchedGroupMask {
  NONE = 0u,
  ALU = 1u << 0,
  VALU = 1u << 1,
  SALU = 1u << 2,
  MFMA = 1u << 3,
  VMEM = 1u << 4,
  VMEM_READ = 1u << 5,
  VMEM_WRITE = 1u << 6,
  DS = 1u << 7,
  DS_READ = 1u << 8,
  DS_WRITE = 1u << 9,
  TRANS = 1u << 10,
  ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
        DS_READ | DS_WRITE | TRANS,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
};
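
// Illustrative sketch (an assumption, based on the SCHED_GROUP_BARRIER mask
// convention referenced below): a kernel could request a group of up to two
// DS instructions followed by one MFMA with
//   __builtin_amdgcn_sched_group_barrier(/*Mask=*/0x0080, /*Size=*/2, /*SyncID=*/0); // DS
//   __builtin_amdgcn_sched_group_barrier(/*Mask=*/0x0008, /*Size=*/1, /*SyncID=*/0); // MFMA
// where 0x0080 is DS (1u << 7) and 0x0008 is MFMA (1u << 3) in the enum above.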

class SchedGroup;

// InstructionRule class is used to enact a filter which determines whether or
// not an SU maps to a given SchedGroup. It contains complementary data
// structures (e.g. Cache) to help those filters.
class InstructionRule {
protected:
  const SIInstrInfo *TII;
  unsigned SGID;
  // A cache made available to the Filter to store SUnits for subsequent
  // invocations of the Filter
  std::optional<SmallVector<SUnit *, 4>> Cache;

public:
  virtual bool
  apply(const SUnit *, const ArrayRef<SUnit *>,
        SmallVectorImpl<SchedGroup> &) {
    return true;
  }

  InstructionRule(const SIInstrInfo *TII, unsigned SGID,
                  bool NeedsCache = false)
      : TII(TII), SGID(SGID) {
    if (NeedsCache) {
      Cache = SmallVector<SUnit *, 4>();
    }
  }

  virtual ~InstructionRule() = default;
};

typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;

// Classify instructions into groups to enable fine tuned control over the
// scheduler. These groups may be more specific than current SchedModel
// instruction classes.
class SchedGroup {
private:
  // Mask that defines which instruction types can be classified into this
  // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
  // and SCHED_GROUP_BARRIER.
  SchedGroupMask SGMask;

  // Maximum number of SUnits that can be added to this group.
  std::optional<unsigned> MaxSize;

  // SchedGroups will only synchronize with other SchedGroups that have the same
  // SyncID.
  int SyncID = 0;

  // SGID is used to map instructions to candidate SchedGroups
  unsigned SGID;

  // The different rules each instruction in this SchedGroup must conform to
  SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;

  // Count of the number of created SchedGroups, used to initialize SGID.
  static unsigned NumSchedGroups;

  // Try to add an edge from SU A to SU B.
  bool tryAddEdge(SUnit *A, SUnit *B);

  // Use SGMask to determine whether we can classify MI as a member of this
  // SchedGroup object.
  bool canAddMI(const MachineInstr &MI) const;

public:
  // Collection of SUnits that are classified as members of this group.
  SmallVector<SUnit *, 32> Collection;

  ScheduleDAGInstrs *DAG;
  const SIInstrInfo *TII;

  // Returns true if SU can be added to this SchedGroup.
  bool canAddSU(SUnit &SU) const;

  // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
  // MakePred is true, SU will be a predecessor of the SUnits in this
  // SchedGroup, otherwise SU will be a successor.
  void link(SUnit &SU, bool MakePred = false);

  // Add DAG dependencies and track which edges are added, and the count of
  // missed edges.
  int link(SUnit &SU, bool MakePred,
           std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);

  // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
  // Use the predicate to determine whether SU should be a predecessor (P =
  // true) or a successor (P = false) of this SchedGroup.
  void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);

  // Add DAG dependencies such that SUnits in this group shall be ordered
  // before SUnits in OtherGroup.
  void link(SchedGroup &OtherGroup);

  // Returns true if no more instructions may be added to this group.
  bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }

  // Append a constraint that SUs must meet in order to fit into this
  // SchedGroup. Since many rules involve the relationship between a SchedGroup
  // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
  // time (rather than SchedGroup init time).
  void addRule(std::shared_ptr<InstructionRule> NewRule) {
    Rules.push_back(NewRule);
  }

  // Returns true if the SU matches all rules.
  bool allowedByRules(const SUnit *SU,
                      SmallVectorImpl<SchedGroup> &SyncPipe) const {
    if (Rules.empty())
      return true;
    for (size_t I = 0; I < Rules.size(); I++) {
      auto TheRule = Rules[I].get();
      if (!TheRule->apply(SU, Collection, SyncPipe)) {
        return false;
      }
    }
    return true;
  }

  // Add SU to the SchedGroup.
  void add(SUnit &SU) {
    LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
                      << format_hex((int)SGMask, 10, true) << " adding "
                      << *SU.getInstr());
    Collection.push_back(&SU);
  }

  // Remove the last element in the SchedGroup.
  void pop() { Collection.pop_back(); }

  // Identify and add all relevant SUs from the DAG to this SchedGroup.
  void initSchedGroup();

  // Add instructions to the SchedGroup bottom up starting from RIter.
  // PipelineInstrs is a set of instructions that should not be added to the
  // SchedGroup even when the other conditions for adding it are satisfied.
  // RIter will be added to the SchedGroup as well, and dependencies will be
  // added so that RIter will always be scheduled at the end of the group.
  void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
                      SUnitsToCandidateSGsMap &SyncedInstrs);

  void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);

  int getSyncID() { return SyncID; }

  int getSGID() { return SGID; }

  SchedGroupMask getMask() { return SGMask; }

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }
};
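
// Illustrative sketch (hypothetical values): a strategy that wants "at most
// two DS instructions, synchronized with pipeline 0" would construct
//   SchedGroup SG(SchedGroupMask::DS, /*MaxSize=*/2, /*SyncID=*/0, DAG, TII);
// and then call one of the initSchedGroup overloads to populate it, as the
// strategies later in this file do.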

// Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
  assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);

  while (!SU.Preds.empty())
    for (auto &P : SU.Preds)
      SU.removePred(P);

  while (!SU.Succs.empty())
    for (auto &S : SU.Succs)
      for (auto &SP : S.getSUnit()->Preds)
        if (SP.getSUnit() == &SU)
          S.getSUnit()->removePred(SP);
}

typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;

// The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
// in non-trivial cases. For example, if the requested pipeline is
// {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
// in the DAG, then we will have an instruction that cannot be trivially
// assigned to a SchedGroup. The PipelineSolver class implements two algorithms
// to find a good solution to the pipeline -- a greedy algorithm and an exact
// algorithm. The exact algorithm has an exponential time complexity and should
// only be used for small or medium sized problems where an exact
// solution is highly desired.
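// Illustrative sketch (SGIDs hypothetical): in the pipeline above, a
// VMEM_READ SU has two candidate SchedGroups -- say the ones with SGIDs 0
// and 3 -- so its SUnitsToCandidateSGsMap entry would be {SU, {0, 3}}, and
// the solver decides which assignment (or omitting the SU entirely) costs
// the fewest missed pipeline edges.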
class PipelineSolver {
  ScheduleDAGMI *DAG;

  // Instructions that can be assigned to multiple SchedGroups
  DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
  SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
  DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
  // The current working pipeline
  SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
  // The pipeline that has the best solution found so far
  SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;

  // Whether or not we actually have any SyncedInstrs to try to solve.
  bool NeedsSolver = false;

  // Compute an estimate of the size of the search tree -- the true size is
  // the product of each conflictedInst.Matches.size() across all SyncPipelines
  unsigned computeProblemSize();

  // The cost penalty of not assigning a SU to a SchedGroup
  int MissPenalty = 0;

  // Costs in terms of the number of edges we are unable to add
  int BestCost = -1;
  int CurrCost = 0;

  // Index pointing to the conflicting instruction that is currently being
  // fitted
  int CurrConflInstNo = 0;
  // Index to the pipeline that is currently being fitted
  int CurrSyncGroupIdx = 0;
  // The first non-trivial pipeline
  int BeginSyncGroupIdx = 0;

  // How many branches we have explored
  uint64_t BranchesExplored = 0;

  // The direction in which we process the candidate SchedGroups per SU
  bool IsBottomUp = true;

  // Update indices to fit next conflicting instruction
  void advancePosition();
  // Recede indices to attempt to find better fit for previous conflicting
  // instruction
  void retreatPosition();

  // The exponential time algorithm which finds the provably best fit
  bool solveExact();
  // The polynomial time algorithm which attempts to find a good fit
  bool solveGreedy();
  // Find the best SchedGroup for the current SU using the heuristic given all
  // current information. One step in the greedy algorithm. Templated against
  // the SchedGroup iterator (either reverse or forward).
  template <typename T>
  void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
                  T E);
  // Whether or not the current solution is optimal
  bool checkOptimal();
  // Populate the ready list, prioritizing fewest missed edges first.
  // Templated against the SchedGroup iterator (either reverse or forward).
  template <typename T>
  void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
                         T E);
  // Add edges corresponding to the SchedGroups as assigned by solver
  void makePipeline();
  // Link the SchedGroups in the best found pipeline.
  // Templated against the SchedGroup iterator (either reverse or forward).
  template <typename T> void linkSchedGroups(T I, T E);
  // Add the edges from the SU to the other SchedGroups in pipeline, and
  // return the number of edges missed.
  int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
               std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
  /// Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
  /// returns the cost (in terms of missed pipeline edges), and tracks the edges
  /// added in \p AddedEdges
  template <typename T>
  int linkSUnit(SUnit *SU, int SGID,
                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
  /// Remove the edges passed via \p AddedEdges
  void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
  // Convert the passed in maps to arrays for bidirectional iterators
  void convertSyncMapsToArrays();

  void reset();

public:
  // Invoke the solver to map instructions to instruction groups. A heuristic
  // and a command-line option determine whether to use the exact or the
  // greedy algorithm.
  void solve();

  PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
                 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
                 ScheduleDAGMI *DAG, bool IsBottomUp = true)
      : DAG(DAG), SyncedInstrs(SyncedInstrs),
        SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
    for (auto &PipelineInstrs : SyncedInstrs) {
      if (PipelineInstrs.second.size() > 0) {
        NeedsSolver = true;
        break;
      }
    }

    if (!NeedsSolver)
      return;

    convertSyncMapsToArrays();

    CurrPipeline = BestPipeline;

    while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
           PipelineInstrs[BeginSyncGroupIdx].size() == 0)
      ++BeginSyncGroupIdx;

    if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
      return;
  }
};

void PipelineSolver::reset() {
  for (auto &SyncPipeline : CurrPipeline) {
    for (auto &SG : SyncPipeline) {
      SmallVector<SUnit *, 32> TempCollection = SG.Collection;
      SG.Collection.clear();
      auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
        return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
      });
      if (SchedBarr != TempCollection.end())
        SG.Collection.push_back(*SchedBarr);
    }
  }

  CurrSyncGroupIdx = BeginSyncGroupIdx;
  CurrConflInstNo = 0;
  CurrCost = 0;
}

void PipelineSolver::convertSyncMapsToArrays() {
  for (auto &SyncPipe : SyncedSchedGroups) {
    BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
  }

  int PipelineIDx = SyncedInstrs.size() - 1;
  PipelineInstrs.resize(SyncedInstrs.size());
  for (auto &SyncInstrMap : SyncedInstrs) {
    for (auto &SUsToCandSGs : SyncInstrMap.second) {
      if (PipelineInstrs[PipelineIDx].size() == 0) {
        PipelineInstrs[PipelineIDx].push_back(
            std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
        continue;
      }
      auto SortPosition = PipelineInstrs[PipelineIDx].begin();
      // Insert them in sorted order -- this allows for good parsing order in
      // the greedy algorithm
      while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
             SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
        ++SortPosition;
      PipelineInstrs[PipelineIDx].insert(
          SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
    }
    --PipelineIDx;
  }
}

template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
  for (; I != E; ++I) {
    auto &GroupA = *I;
    for (auto J = std::next(I); J != E; ++J) {
      auto &GroupB = *J;
      GroupA.link(GroupB);
    }
  }
}

void PipelineSolver::makePipeline() {
  // Preserve the order of the barriers for subsequent SchedGroupBarrier
  // mutations.
  for (auto &SyncPipeline : BestPipeline) {
    LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
    for (auto &SG : SyncPipeline) {
      LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
                        << " has: \n");
      SUnit *SGBarr = nullptr;
      for (auto &SU : SG.Collection) {
        if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
          SGBarr = SU;
        LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
      }
      // Command-line-requested IGroupLP doesn't have a SGBarr.
      if (!SGBarr)
        continue;
      resetEdges(*SGBarr, DAG);
      SG.link(*SGBarr, false);
    }
  }

  for (auto &SyncPipeline : BestPipeline) {
    IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
               : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
  }
}

template <typename T>
int PipelineSolver::linkSUnit(
    SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
    T I, T E) {
  bool MakePred = false;
  int AddedCost = 0;
  for (; I < E; ++I) {
    if (I->getSGID() == SGID) {
      MakePred = true;
      continue;
    }
    auto Group = *I;
    AddedCost += Group.link(*SU, MakePred, AddedEdges);
    assert(AddedCost >= 0);
  }
  return AddedCost;
}

int PipelineSolver::addEdges(
    SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
    std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
  // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
  // instructions that are the ultimate successors in the resultant mutation.
  // Therefore, in such a configuration, the SchedGroups occurring before the
  // candidate SGID are successors of the candidate SchedGroup, thus the current
  // SU should be linked as a predecessor to SUs in those SchedGroups. The
  // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
  // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
  // IsBottomUp (in reverse).
  return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
                                SyncPipeline.rend())
                    : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
                                SyncPipeline.end());
}
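
// Illustrative sketch (hypothetical pipeline): with IsBottomUp and a
// SyncPipeline of [SG0, SG1, SG2] (SG0 holding the ultimate successors),
// fitting an SU into SG1 adds edges from SG2's SUnits to the SU and from the
// SU to SG0's SUnits, so the final order is SG2 ... SU ... SG0.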

void PipelineSolver::removeEdges(
    const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
  // Only remove the edges that we have added when testing
  // the fit.
  for (auto &PredSuccPair : EdgesToRemove) {
    SUnit *Pred = PredSuccPair.first;
    SUnit *Succ = PredSuccPair.second;

    auto Match = llvm::find_if(
        Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
    if (Match != Succ->Preds.end()) {
      assert(Match->isArtificial());
      Succ->removePred(*Match);
    }
  }
}

void PipelineSolver::advancePosition() {
  ++CurrConflInstNo;

  if (static_cast<size_t>(CurrConflInstNo) >=
      PipelineInstrs[CurrSyncGroupIdx].size()) {
    CurrConflInstNo = 0;
    ++CurrSyncGroupIdx;
    // Advance to next non-trivial pipeline
    while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
           PipelineInstrs[CurrSyncGroupIdx].size() == 0)
      ++CurrSyncGroupIdx;
  }
}

void PipelineSolver::retreatPosition() {
  assert(CurrConflInstNo >= 0);
  assert(CurrSyncGroupIdx >= 0);

  if (CurrConflInstNo > 0) {
    --CurrConflInstNo;
    return;
  }

  if (CurrConflInstNo == 0) {
    // If we return to the starting position, we have explored
    // the entire tree
    if (CurrSyncGroupIdx == BeginSyncGroupIdx)
      return;

    --CurrSyncGroupIdx;
    // Go to previous non-trivial pipeline
    while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
      --CurrSyncGroupIdx;

    CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
  }
}

bool PipelineSolver::checkOptimal() {
  if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
    if (BestCost == -1 || CurrCost < BestCost) {
      BestPipeline = CurrPipeline;
      BestCost = CurrCost;
      LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
    }
    assert(BestCost >= 0);
  }

  bool DoneExploring = false;
  if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
    DoneExploring = true;

  return (DoneExploring || BestCost == 0);
}

template <typename T>
void PipelineSolver::populateReadyList(
    SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
  assert(CurrSU.second.size() >= 1);

  for (; I != E; ++I) {
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    int CandSGID = *I;
    SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
      return SG.getSGID() == CandSGID;
    });
    assert(Match);

    if (UseCostHeur) {
      if (Match->isFull()) {
        ReadyList.push_back(std::pair(*I, MissPenalty));
        continue;
      }

      int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
      ReadyList.push_back(std::pair(*I, TempCost));
      removeEdges(AddedEdges);
    } else
      ReadyList.push_back(std::pair(*I, -1));
  }

  if (UseCostHeur) {
    std::sort(ReadyList.begin(), ReadyList.end(),
              [](std::pair<int, int> A, std::pair<int, int> B) {
                return A.second < B.second;
              });
  }

  assert(ReadyList.size() == CurrSU.second.size());
}

bool PipelineSolver::solveExact() {
  if (checkOptimal())
    return true;

  if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
    return false;

  assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
  assert(static_cast<size_t>(CurrConflInstNo) <
         PipelineInstrs[CurrSyncGroupIdx].size());
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
                    << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

  // SchedGroup -> Cost pairs
  SmallVector<std::pair<int, int>, 4> ReadyList;
  // Prioritize the candidate sched groups in terms of lowest cost first
  IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
                                 CurrSU.second.rend())
             : populateReadyList(ReadyList, CurrSU.second.begin(),
                                 CurrSU.second.end());

  auto I = ReadyList.begin();
  auto E = ReadyList.end();
  for (; I != E; ++I) {
    // If we are trying SGs in least cost order, and the current SG is cost
    // infeasible, then all subsequent SGs will also be cost infeasible, so we
    // can prune.
    if (BestCost != -1 && (CurrCost + I->second > BestCost))
      return false;

    int CandSGID = I->first;
    int AddedCost = 0;
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
    SchedGroup *Match = nullptr;
    for (auto &SG : SyncPipeline) {
      if (SG.getSGID() == CandSGID)
        Match = &SG;
    }
    assert(Match);

    if (Match->isFull())
      continue;

    if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
      continue;

    LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
                      << (int)Match->getMask() << " and ID " << CandSGID
                      << "\n");
    Match->add(*CurrSU.first);
    AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
    CurrCost += AddedCost;
    advancePosition();
    ++BranchesExplored;
    bool FinishedExploring = false;
    // If the Cost after adding edges is greater than a known solution,
    // backtrack
    if (CurrCost < BestCost || BestCost == -1) {
      if (solveExact()) {
        FinishedExploring = BestCost != 0;
        if (!FinishedExploring)
          return true;
      }
    }

    retreatPosition();
    CurrCost -= AddedCost;
    removeEdges(AddedEdges);
    Match->pop();
    CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
    if (FinishedExploring)
      return true;
  }

  // Try the pipeline where the current instruction is omitted.
  // Potentially if we omit a problematic instruction from the pipeline,
  // all the other instructions can nicely fit.
  CurrCost += MissPenalty;
  advancePosition();

  LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");

  bool FinishedExploring = false;
  if (CurrCost < BestCost || BestCost == -1) {
    if (solveExact()) {
      FinishedExploring = BestCost != 0;
      if (!FinishedExploring)
        return true;
    }
  }

  retreatPosition();
  CurrCost -= MissPenalty;
  return FinishedExploring;
}

template <typename T>
void PipelineSolver::greedyFind(
    std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  int BestNodeCost = -1;
  int TempCost;
  SchedGroup *BestGroup = nullptr;
  int BestGroupID = -1;
  auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
  LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
                    << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

  // Since we have added the potential SchedGroups from bottom up, but
  // traversed the DAG from top down, parse over the groups from last to
  // first. If we fail to do this for the greedy algorithm, the solution will
  // likely not be good in more complex cases.
  for (; I != E; ++I) {
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    int CandSGID = *I;
    SchedGroup *Match = llvm::find_if(SyncPipeline, [CandSGID](SchedGroup &SG) {
      return SG.getSGID() == CandSGID;
    });
    assert(Match);

    LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
                      << (int)Match->getMask() << "\n");

    if (Match->isFull()) {
      LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
      continue;
    }
    if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
      LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
      continue;
    }
    TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
    if (TempCost < BestNodeCost || BestNodeCost == -1) {
      BestGroup = Match;
      BestNodeCost = TempCost;
      BestGroupID = CandSGID;
    }
    removeEdges(AddedEdges);
    if (BestNodeCost == 0)
      break;
  }

  if (BestGroupID != -1) {
    BestGroup->add(*CurrSU.first);
    addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
                      << (int)BestGroup->getMask() << "\n");
    BestCost += BestNodeCost;
  } else
    BestCost += MissPenalty;

  CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
}

bool PipelineSolver::solveGreedy() {
  BestCost = 0;
  std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

  while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
    SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
    IsBottomUp
        ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
        : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
    advancePosition();
  }
  BestPipeline = CurrPipeline;
  removeEdges(AddedEdges);
  return false;
}

unsigned PipelineSolver::computeProblemSize() {
  unsigned ProblemSize = 0;
  for (auto &PipeConflicts : PipelineInstrs) {
    ProblemSize += PipeConflicts.size();
  }

  return ProblemSize;
}
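
// Illustrative arithmetic (hypothetical numbers): two non-trivial pipelines
// with 3 and 2 conflicted SUs give ProblemSize = 5 and, below, MissPenalty =
// (5 / 2) + 1 = 3. The exact solver's search tree, by contrast, can branch
// up to (candidate count + 1) ways per conflicted SU, since each SU may also
// be omitted from the pipeline.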

void PipelineSolver::solve() {
  if (!NeedsSolver)
    return;

  unsigned ProblemSize = computeProblemSize();
  assert(ProblemSize > 0);

  bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
  MissPenalty = (ProblemSize / 2) + 1;

  LLVM_DEBUG(DAG->dump());
  if (EnableExactSolver || BelowCutoff) {
    LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
    solveGreedy();
    reset();
    LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
    if (BestCost > 0) {
      LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
      solveExact();
      LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
    }
  } else { // Use the Greedy Algorithm by default
    LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
    solveGreedy();
  }

  makePipeline();
  LLVM_DEBUG(dbgs() << "After applying mutation\n");
  LLVM_DEBUG(DAG->dump());
}

enum IGLPStrategyID : int {
  MFMASmallGemmOptID = 0,
  MFMASmallGemmSingleWaveOptID = 1,
  MFMAExpInterleave = 2
};
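
// Illustrative sketch (an assumption about the frontend builtin that feeds
// these IDs): a kernel selects a strategy by emitting IGLP_OPT with the
// corresponding immediate, e.g.
//   __builtin_amdgcn_iglp_opt(/*StrategyID=*/0); // MFMASmallGemmOptID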

// Implement an IGLP scheduling strategy.
class IGLPStrategy {
protected:
  ScheduleDAGInstrs *DAG;

  const SIInstrInfo *TII;

public:
  /// Add SchedGroups to \p SyncedSchedGroups to implement this Strategy.
  virtual bool applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
      AMDGPU::SchedulingPhase Phase) = 0;

  // Returns true if this strategy should be applied to a ScheduleDAG.
  virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
                                   AMDGPU::SchedulingPhase Phase) = 0;

  bool IsBottomUp = true;

  IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : DAG(DAG), TII(TII) {}

  virtual ~IGLPStrategy() = default;
};

class MFMASmallGemmOpt final : public IGLPStrategy {
public:
  bool applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
      AMDGPU::SchedulingPhase Phase) override;

  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
                           AMDGPU::SchedulingPhase Phase) override {
    return true;
  }

  MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {
    IsBottomUp = true;
  }
};

bool MFMASmallGemmOpt::applyIGLPStrategy(
    DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
    DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
    AMDGPU::SchedulingPhase Phase) {
  // Count the number of MFMA instructions.
  unsigned MFMACount = 0;
  for (const MachineInstr &I : *DAG)
    if (TII->isMFMAorWMMA(I))
      ++MFMACount;

  const unsigned PipelineSyncID = 0;
  SchedGroup *SG = nullptr;
  for (unsigned I = 0; I < MFMACount * 3; ++I) {
    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
  }

  return true;
}

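// Illustrative result (hypothetical schedule): for MFMACount = 2 the loop
// above requests the repeating pattern
//   DS, DS, MFMA, DS, DS, MFMA, ...
// for 3 * MFMACount = 6 DS/MFMA pairs, i.e., up to two DS instructions
// interleaved before each MFMA, hiding LDS latency behind the MFMAs in
// small GEMM kernels.
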
class MFMAExpInterleaveOpt final : public IGLPStrategy {
private:
  // The count of TRANS SUs involved in the interleaved pipeline
  static unsigned TransPipeCount;
  // The count of MFMA SUs involved in the interleaved pipeline
  static unsigned MFMAPipeCount;
  // The count of Add SUs involved in the interleaved pipeline
  static unsigned AddPipeCount;
  // The number of transitive MFMA successors for each TRANS SU
  static unsigned MFMAEnablement;
  // The number of transitive TRANS predecessors for each MFMA SU
  static unsigned ExpRequirement;
  // The count of independent "chains" of MFMA instructions in the pipeline
  static unsigned MFMAChains;
  // The length of each independent "chain" of MFMA instructions
  static unsigned MFMAChainLength;
  // Whether or not the pipeline has V_CVT instructions
  static bool HasCvt;
  // Whether or not there are instructions between the TRANS instruction and
  // V_CVT
  static bool HasChainBetweenCvt;
  // The first occurring DS_READ which feeds an MFMA chain
  static std::optional<unsigned> FirstPipeDSR;
  // The MFMAPipe SUs with no MFMA predecessors
  SmallVector<SUnit *, 4> MFMAChainSeeds;
  // Compute the heuristics for the pipeline, returning whether or not the DAG
  // is well formatted for the mutation
  bool analyzeDAG(const SIInstrInfo *TII);

  /// Whether or not the instruction is a transitive predecessor of an MFMA
  /// instruction
  class IsPipeExp final : public InstructionRule {
  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      auto DAG = SyncPipe[0].DAG;

      if (Cache->empty()) {
        auto I = DAG->SUnits.rbegin();
        auto E = DAG->SUnits.rend();
        for (; I != E; ++I) {
          if (TII->isMFMAorWMMA(*I->getInstr()))
            Cache->push_back(&*I);
        }
        if (Cache->empty())
          return false;
      }

      auto Reaches = std::any_of(
          Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *TargetSU) {
            return DAG->IsReachable(TargetSU, const_cast<SUnit *>(SU));
          });

      return Reaches;
    }
    IsPipeExp(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache) {}
  };

  /// Whether or not the instruction is a transitive predecessor of the
  /// \p Number th MFMA of the MFMAs occurring after a TRANS instruction
  class EnablesNthMFMA final : public InstructionRule {
  private:
    unsigned Number = 1;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      bool FoundTrans = false;
      unsigned Counter = 1;
      auto DAG = SyncPipe[0].DAG;

      if (Cache->empty()) {
        auto I = DAG->SUnits.begin();
        auto E = DAG->SUnits.end();
        for (; I != E; ++I) {
          if (FoundTrans && TII->isMFMAorWMMA(*I->getInstr())) {
            if (Counter == Number) {
              Cache->push_back(&*I);
              break;
            }
            ++Counter;
          }
          if (!FoundTrans && TII->isTRANS(I->getInstr()->getOpcode()))
            FoundTrans = true;
        }
        if (Cache->empty())
          return false;
      }

      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
    }

    EnablesNthMFMA(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
                   bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
  };

  /// Whether or not the instruction enables the exact MFMA that is the \p
  /// Number th MFMA in the chain starting with \p ChainSeed
  class EnablesNthMFMAInChain final : public InstructionRule {
  private:
    unsigned Number = 1;
    SUnit *ChainSeed;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      auto DAG = SyncPipe[0].DAG;

      if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
        return false;

      if (Cache->empty()) {
        auto TempSU = ChainSeed;
        auto Depth = Number;
        while (Depth > 0) {
          --Depth;
          bool Found = false;
          for (auto &Succ : TempSU->Succs) {
            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
              TempSU = Succ.getSUnit();
              Found = true;
              break;
            }
          }
          if (!Found)
            return false;
        }

        Cache->push_back(TempSU);
      }
      // If we failed to find the instruction to be placed into the cache, we
      // would have already exited.
      assert(!Cache->empty());

      return DAG->IsReachable((*Cache)[0], const_cast<SUnit *>(SU));
    }

    EnablesNthMFMAInChain(unsigned Number, SUnit *ChainSeed,
                          const SIInstrInfo *TII, unsigned SGID,
                          bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
          ChainSeed(ChainSeed) {}
  };

  /// Whether or not the instruction has fewer than \p Size immediate
  /// successors. If \p HasIntermediary is true, this also tests whether all
  /// successors of the SUnit have fewer than \p Size successors.
  class LessThanNSuccs final : public InstructionRule {
  private:
    unsigned Size = 1;
    bool HasIntermediary = false;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      if (!SyncPipe.size())
        return false;

      auto SuccSize = std::count_if(
          SU->Succs.begin(), SU->Succs.end(),
          [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
      if (SuccSize >= Size)
        return false;

      if (HasIntermediary) {
        for (auto Succ : SU->Succs) {
          auto SuccSize = std::count_if(
              Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
              [](const SDep &SuccSucc) {
                return SuccSucc.getKind() == SDep::Data;
              });
          if (SuccSize >= Size)
            return false;
        }
      }

      return true;
    }
    LessThanNSuccs(unsigned Size, const SIInstrInfo *TII, unsigned SGID,
                   bool HasIntermediary = false, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
          HasIntermediary(HasIntermediary) {}
  };

  /// Whether or not the instruction has greater than or equal to \p Size
  /// immediate successors. If \p HasIntermediary is true, this also tests
  /// whether any successor of the SUnit has greater than or equal to \p Size
  /// successors.
  class GreaterThanOrEqualToNSuccs final : public InstructionRule {
  private:
    unsigned Size = 1;
    bool HasIntermediary = false;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      if (!SyncPipe.size())
        return false;

      auto SuccSize = std::count_if(
          SU->Succs.begin(), SU->Succs.end(),
          [](const SDep &Succ) { return Succ.getKind() == SDep::Data; });
      if (SuccSize >= Size)
        return true;

      if (HasIntermediary) {
        for (auto Succ : SU->Succs) {
          auto SuccSize = std::count_if(
              Succ.getSUnit()->Succs.begin(), Succ.getSUnit()->Succs.end(),
              [](const SDep &SuccSucc) {
                return SuccSucc.getKind() == SDep::Data;
              });
          if (SuccSize >= Size)
            return true;
        }
      }

      return false;
    }
    GreaterThanOrEqualToNSuccs(unsigned Size, const SIInstrInfo *TII,
                               unsigned SGID, bool HasIntermediary = false,
                               bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Size(Size),
          HasIntermediary(HasIntermediary) {}
  };

  // Whether or not the instruction is a relevant V_CVT instruction.
  class IsCvt final : public InstructionRule {
  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      auto Opc = SU->getInstr()->getOpcode();
      return Opc == AMDGPU::V_CVT_F16_F32_e32 ||
             Opc == AMDGPU::V_CVT_I32_F32_e32;
    }
    IsCvt(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache) {}
  };

  // Whether or not the instruction is FMA_F32.
  class IsFMA final : public InstructionRule {
  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      return SU->getInstr()->getOpcode() == AMDGPU::V_FMA_F32_e64 ||
             SU->getInstr()->getOpcode() == AMDGPU::V_PK_FMA_F32;
    }
    IsFMA(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache) {}
  };

  // Whether or not the instruction is a V_ADD_F32 instruction.
  class IsPipeAdd final : public InstructionRule {
  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      return SU->getInstr()->getOpcode() == AMDGPU::V_ADD_F32_e32;
    }
    IsPipeAdd(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache) {}
  };

  /// Whether or not the instruction is an immediate RAW successor
  /// of the SchedGroup \p Distance steps before.
  class IsSuccOfPrevNthGroup final : public InstructionRule {
  private:
    unsigned Distance = 1;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      SchedGroup *OtherGroup = nullptr;
      if (!SyncPipe.size())
        return false;

      for (auto &PipeSG : SyncPipe) {
        if ((unsigned)PipeSG.getSGID() == SGID - Distance)
          OtherGroup = &PipeSG;
      }

      if (!OtherGroup)
        return false;
      if (!OtherGroup->Collection.size())
        return true;

      for (auto &OtherEle : OtherGroup->Collection) {
        for (auto &Succ : OtherEle->Succs) {
          if (Succ.getSUnit() == SU && Succ.getKind() == SDep::Data)
            return true;
        }
      }

      return false;
    }
    IsSuccOfPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
                         unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
  };

  /// Whether or not the instruction is a transitive successor of any
  /// instruction in the SchedGroup \p Distance steps before.
  class IsReachableFromPrevNthGroup final : public InstructionRule {
  private:
    unsigned Distance = 1;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      SchedGroup *OtherGroup = nullptr;
      if (!SyncPipe.size())
        return false;

      for (auto &PipeSG : SyncPipe) {
        if ((unsigned)PipeSG.getSGID() == SGID - Distance)
          OtherGroup = &PipeSG;
      }

      if (!OtherGroup)
        return false;
      if (!OtherGroup->Collection.size())
        return true;

      auto DAG = SyncPipe[0].DAG;

      for (auto &OtherEle : OtherGroup->Collection)
        if (DAG->IsReachable(const_cast<SUnit *>(SU), OtherEle))
          return true;

      return false;
    }
    IsReachableFromPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
                                unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
  };

  /// Whether or not the instruction occurs at or after the SU with NodeNum
  /// \p Number
  class OccursAtOrAfterNode final : public InstructionRule {
  private:
    unsigned Number = 1;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      return SU->NodeNum >= Number;
    }
    OccursAtOrAfterNode(unsigned Number, const SIInstrInfo *TII, unsigned SGID,
                        bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Number(Number) {}
  };

  /// Whether or not the SU is exactly the \p Number th MFMA in the chain
  /// starting with \p ChainSeed
  class IsExactMFMA final : public InstructionRule {
  private:
    unsigned Number = 1;
    SUnit *ChainSeed;

  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      if (!SU || !TII->isMFMAorWMMA(*ChainSeed->getInstr()))
        return false;

      if (Cache->empty()) {
        auto TempSU = ChainSeed;
        auto Depth = Number;
        while (Depth > 0) {
          --Depth;
          bool Found = false;
          for (auto &Succ : TempSU->Succs) {
            if (TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr())) {
              TempSU = Succ.getSUnit();
              Found = true;
              break;
            }
          }
          if (!Found) {
            return false;
          }
        }
        Cache->push_back(TempSU);
      }
      // If we failed to find the instruction to be placed into the cache, we
      // would have already exited.
      assert(!Cache->empty());

      return (*Cache)[0] == SU;
    }

    IsExactMFMA(unsigned Number, SUnit *ChainSeed, const SIInstrInfo *TII,
                unsigned SGID, bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache), Number(Number),
          ChainSeed(ChainSeed) {}
  };

  // Whether the instruction occurs after the first TRANS instruction. This
  // implies the instruction cannot be a predecessor of the first TRANS
  // instruction.
  class OccursAfterExp final : public InstructionRule {
  public:
    bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
               SmallVectorImpl<SchedGroup> &SyncPipe) override {
      auto DAG = SyncPipe[0].DAG;
      if (Cache->empty()) {
        for (auto &SU : DAG->SUnits)
          if (TII->isTRANS(SU.getInstr()->getOpcode())) {
            Cache->push_back(&SU);
            break;
          }
        if (Cache->empty())
          return false;
      }

      return SU->NodeNum > (*Cache)[0]->NodeNum;
    }

    OccursAfterExp(const SIInstrInfo *TII, unsigned SGID,
                   bool NeedsCache = false)
        : InstructionRule(TII, SGID, NeedsCache) {}
  };

public:
  bool applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
      AMDGPU::SchedulingPhase Phase) override;

  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
                           AMDGPU::SchedulingPhase Phase) override;

  MFMAExpInterleaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {
    IsBottomUp = false;
  }
};

unsigned MFMAExpInterleaveOpt::TransPipeCount = 0;
unsigned MFMAExpInterleaveOpt::MFMAPipeCount = 0;
unsigned MFMAExpInterleaveOpt::AddPipeCount = 0;
unsigned MFMAExpInterleaveOpt::MFMAEnablement = 0;
unsigned MFMAExpInterleaveOpt::ExpRequirement = 0;
unsigned MFMAExpInterleaveOpt::MFMAChains = 0;
unsigned MFMAExpInterleaveOpt::MFMAChainLength = 0;
bool MFMAExpInterleaveOpt::HasCvt = false;
bool MFMAExpInterleaveOpt::HasChainBetweenCvt = false;
std::optional<unsigned> MFMAExpInterleaveOpt::FirstPipeDSR = std::nullopt;

bool MFMAExpInterleaveOpt::analyzeDAG(const SIInstrInfo *TII) {
  SmallVector<SUnit *, 10> ExpPipeCands;
  SmallVector<SUnit *, 10> MFMAPipeCands;
  SmallVector<SUnit *, 10> MFMAPipeSUs;
  SmallVector<SUnit *, 10> PackSUs;
  SmallVector<SUnit *, 10> CvtSUs;

  auto isBitPack = [](unsigned Opc) {
    return Opc == AMDGPU::V_PACK_B32_F16_e64 || Opc == AMDGPU::V_PERM_B32_e64;
  };

  auto isCvt = [](unsigned Opc) {
    return Opc == AMDGPU::V_CVT_F16_F32_e32 || Opc == AMDGPU::V_CVT_I32_F32_e32;
  };

  auto isAdd = [](unsigned Opc) { return Opc == AMDGPU::V_ADD_F32_e32; };

  AddPipeCount = 0;
  for (SUnit &SU : DAG->SUnits) {
    auto Opc = SU.getInstr()->getOpcode();
    if (TII->isTRANS(Opc)) {
      // Avoid counting a potential bonus V_EXP which all the MFMAs depend on.
      if (SU.Succs.size() >= 7)
        continue;
      if (llvm::any_of(SU.Succs, [](const SDep &Succ) {
            return Succ.getSUnit()->Succs.size() >= 7;
          }))
        continue;
      ExpPipeCands.push_back(&SU);
    }

    if (TII->isMFMAorWMMA(*SU.getInstr()))
      MFMAPipeCands.push_back(&SU);

    if (isBitPack(Opc))
      PackSUs.push_back(&SU);

    if (isCvt(Opc))
      CvtSUs.push_back(&SU);

    if (isAdd(Opc))
      ++AddPipeCount;
  }

  if (!(PackSUs.size() && MFMAPipeCands.size() && ExpPipeCands.size()))
    return false;

  TransPipeCount = 0;

  std::optional<SUnit *> TempMFMA;
  std::optional<SUnit *> TempExp;
  // Count the number of EXPs that reach an MFMA
  for (auto &PredSU : ExpPipeCands) {
    for (auto &SuccSU : MFMAPipeCands) {
      if (DAG->IsReachable(SuccSU, PredSU)) {
        if (!TempExp) {
          TempExp = PredSU;
          TempMFMA = SuccSU;
        }
        MFMAPipeSUs.push_back(SuccSU);
        ++TransPipeCount;
        break;
      }
    }
  }

  if (!(TempExp && TempMFMA))
    return false;

  HasChainBetweenCvt =
      std::find_if((*TempExp)->Succs.begin(), (*TempExp)->Succs.end(),
                   [&isCvt](SDep &Succ) {
                     return isCvt(Succ.getSUnit()->getInstr()->getOpcode());
                   }) == (*TempExp)->Succs.end();

  // Count the number of MFMAs that are reached by an EXP
  for (auto &SuccSU : MFMAPipeCands) {
    if (MFMAPipeSUs.size() &&
        std::find_if(MFMAPipeSUs.begin(), MFMAPipeSUs.end(),
                     [&SuccSU](SUnit *PotentialMatch) {
                       return PotentialMatch->NodeNum == SuccSU->NodeNum;
                     }) != MFMAPipeSUs.end())
      continue;

    for (auto &PredSU : ExpPipeCands) {
      if (DAG->IsReachable(SuccSU, PredSU)) {
        MFMAPipeSUs.push_back(SuccSU);
        break;
      }
    }
  }

  MFMAPipeCount = MFMAPipeSUs.size();

  assert(TempExp && TempMFMA);
  assert(MFMAPipeCount > 0);

1465   std::optional<SUnit *> TempCvt;
1466   for (auto &SuccSU : CvtSUs) {
1467     if (DAG->IsReachable(SuccSU, *TempExp)) {
1468       TempCvt = SuccSU;
1469       break;
1470     }
1471   }
1472 
1473   HasCvt = false;
1474   if (TempCvt.has_value()) {
1475     for (auto &SuccSU : MFMAPipeSUs) {
1476       if (DAG->IsReachable(SuccSU, *TempCvt)) {
1477         HasCvt = true;
1478         break;
1479       }
1480     }
1481   }
1482 
1483   MFMAChains = 0;
1484   for (auto &MFMAPipeSU : MFMAPipeSUs) {
1485     if (MFMAChainSeeds.size() &&
1486         std::find(MFMAChainSeeds.begin(), MFMAChainSeeds.end(), MFMAPipeSU) !=
1487             MFMAChainSeeds.end())
1488       continue;
1489     if (!std::any_of(MFMAPipeSU->Preds.begin(), MFMAPipeSU->Preds.end(),
1490                      [&TII](SDep &Succ) {
1491                        return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1492                      })) {
1493       MFMAChainSeeds.push_back(MFMAPipeSU);
1494       ++MFMAChains;
1495     }
1496   }
1497 
1498   if (!MFMAChains)
1499     return false;
1500 
1501   for (auto Pred : MFMAChainSeeds[0]->Preds) {
1502     if (TII->isDS(Pred.getSUnit()->getInstr()->getOpcode()) &&
1503         Pred.getSUnit()->getInstr()->mayLoad())
1504       FirstPipeDSR = Pred.getSUnit()->NodeNum;
1505   }
1506 
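       // E.g., 32 pipelined MFMAs arranged in 4 independent chains give a
       // chain length of 8.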
1507   MFMAChainLength = MFMAPipeCount / MFMAChains;
1508 
1509   // The number of bit pack operations that depend on a single V_EXP
1510   unsigned PackSuccCount = std::count_if(
1511       PackSUs.begin(), PackSUs.end(), [this, &TempExp](SUnit *VPack) {
1512         return DAG->IsReachable(VPack, *TempExp);
1513       });
1514 
1515   // The number of bit pack operations an MFMA depends on
1516   unsigned PackPredCount =
1517       std::count_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1518                     [&isBitPack](SDep &Pred) {
1519                       auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1520                       return isBitPack(Opc);
1521                     });
1522 
1523   auto PackPred =
1524       std::find_if((*TempMFMA)->Preds.begin(), (*TempMFMA)->Preds.end(),
1525                    [&isBitPack](SDep &Pred) {
1526                      auto Opc = Pred.getSUnit()->getInstr()->getOpcode();
1527                      return isBitPack(Opc);
1528                    });
1529 
1530   if (PackPred == (*TempMFMA)->Preds.end())
1531     return false;
1532 
1533   MFMAEnablement = 0;
1534   ExpRequirement = 0;
1535   // How many MFMAs depend on a single bit pack operation
1536   MFMAEnablement =
1537       std::count_if(PackPred->getSUnit()->Succs.begin(),
1538                     PackPred->getSUnit()->Succs.end(), [&TII](SDep &Succ) {
1539                       return TII->isMFMAorWMMA(*Succ.getSUnit()->getInstr());
1540                     });
1541 
1542   // The number of MFMAs that depend on a single V_EXP
1543   MFMAEnablement *= PackSuccCount;
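       // E.g., if each bit pack feeds two MFMAs and each V_EXP feeds two bit
       // packs (PackSuccCount == 2), then MFMAEnablement = 2 * 2 = 4: a single
       // V_EXP ultimately enables four MFMAs.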
1544 
1545   // The number of V_EXPs required to resolve all dependencies for an MFMA
1546   ExpRequirement =
1547       std::count_if(ExpPipeCands.begin(), ExpPipeCands.end(),
1548                     [this, &PackPred](SUnit *ExpBase) {
1549                       return DAG->IsReachable(PackPred->getSUnit(), ExpBase);
1550                     });
1551 
1552   ExpRequirement *= PackPredCount;
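       // Likewise, if the MFMA consumes two bit packs (PackPredCount == 2) and
       // each pack is reachable from two V_EXPs, ExpRequirement = 2 * 2 = 4
       // V_EXPs must complete before one MFMA has all of its operands.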
1553   return true;
1554 }
1555 
1556 bool MFMAExpInterleaveOpt::shouldApplyStrategy(ScheduleDAGInstrs *DAG,
1557                                                AMDGPU::SchedulingPhase Phase) {
1558   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1559   const SIInstrInfo *TII = ST.getInstrInfo();
1560 
1561   if (Phase != AMDGPU::SchedulingPhase::PostRA)
1562     MFMAChainSeeds.clear();
1563   if (Phase != AMDGPU::SchedulingPhase::PostRA && !analyzeDAG(TII))
1564     return false;
1565 
1566   return true;
1567 }
1568 
1569 bool MFMAExpInterleaveOpt::applyIGLPStrategy(
1570     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1571     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
1572     AMDGPU::SchedulingPhase Phase) {
1573 
1574   bool IsSmallKernelType =
1575       MFMAEnablement == 2 && ExpRequirement == 4 && TransPipeCount == 32;
1576   bool IsLargeKernelType =
1577       MFMAEnablement == 4 && ExpRequirement == 4 && TransPipeCount == 64;
1578 
1579   if (!(IsSmallKernelType || IsLargeKernelType))
1580     return false;
1581 
1582   const GCNSubtarget &ST = DAG->MF.getSubtarget<GCNSubtarget>();
1583   const SIInstrInfo *TII = ST.getInstrInfo();
1584 
1585   unsigned PipelineSyncID = 0;
1586   SchedGroup *SG = nullptr;
1587 
1588   unsigned MFMAChain = 0;
1589   unsigned PositionInChain = 0;
1590   unsigned CurrMFMAForTransPosition = 0;
1591 
1592   auto incrementTransPosition = [&MFMAChain, &PositionInChain,
1593                                  &CurrMFMAForTransPosition]() {
1594     CurrMFMAForTransPosition += MFMAEnablement;
1595     PositionInChain = (CurrMFMAForTransPosition / MFMAChains);
1596     MFMAChain = CurrMFMAForTransPosition % MFMAChains;
1597   };
1598 
1599   auto getNextTransPositionInChain = [&CurrMFMAForTransPosition]() {
1600     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1601     return (TempMFMAForTrans / MFMAChains);
1602   };
1603 
1604   auto getNextTransMFMAChain = [&CurrMFMAForTransPosition]() {
1605     auto TempMFMAForTrans = CurrMFMAForTransPosition + MFMAEnablement;
1606     return TempMFMAForTrans % MFMAChains;
1607   };
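       // E.g., with MFMAChains == 2 and MFMAEnablement == 4, each call to
       // incrementTransPosition() advances CurrMFMAForTransPosition by 4, so
       // (PositionInChain, MFMAChain) steps (0, 0) -> (2, 0) -> (4, 0): the
       // TRANS ops walk down the chains in MFMAEnablement-sized strides.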
1608 
1609   unsigned CurrMFMAPosition = 0;
1610   unsigned MFMAChainForMFMA = 0;
1611   unsigned PositionInChainForMFMA = 0;
1612 
1613   auto incrementMFMAPosition = [&CurrMFMAPosition, &MFMAChainForMFMA,
1614                                 &PositionInChainForMFMA]() {
1615     ++CurrMFMAPosition;
1616     MFMAChainForMFMA = CurrMFMAPosition % MFMAChains;
1617     PositionInChainForMFMA = CurrMFMAPosition / MFMAChains;
1618   };
1619 
1620   bool IsPostRA = Phase == AMDGPU::SchedulingPhase::PostRA;
1621   assert(IsPostRA || MFMAChainSeeds.size() == MFMAChains);
1622 
1623   bool UsesFMA = IsSmallKernelType || !IsPostRA;
1624   bool UsesDSRead = IsLargeKernelType && !IsPostRA && FirstPipeDSR;
1625   bool UsesCvt = HasCvt && (IsSmallKernelType || !IsPostRA);
1626   bool UsesVALU = IsSmallKernelType;
1627 
1628   // PHASE 1: "Prefetch"
1629   if (UsesFMA) {
1630     // First Round FMA
1631     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1632         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1633     if (!IsPostRA && MFMAChains) {
1634       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1635           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1636           true));
1637     } else
1638       SG->addRule(
1639           std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1640     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1641     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1642 
1643     // Second Round FMA
1644     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1645         SchedGroupMask::VALU, ExpRequirement, PipelineSyncID, DAG, TII);
1646     if (!IsPostRA && MFMAChains) {
1647       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1648           getNextTransPositionInChain(),
1649           MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1650     } else
1651       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1652                                                    SG->getSGID(), true));
1653     SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1654     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1655   }
1656 
1657   if (UsesDSRead) {
1658     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1659         SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1660     SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1661                                                       SG->getSGID()));
1662     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1663   }
1664 
1665   // First Round EXP
1666   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1667       SchedGroupMask::TRANS, ExpRequirement, PipelineSyncID, DAG, TII);
1668   if (!IsPostRA && MFMAChains)
1669     SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1670         PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(), true));
1671   else
1672     SG->addRule(std::make_shared<EnablesNthMFMA>(1, TII, SG->getSGID(), true));
1673   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1674   SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1675                                                HasChainBetweenCvt));
1676   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1677 
1678   incrementTransPosition();
1679 
1680   // First Round CVT, Third Round FMA, Second Round EXP; interleaved
1681   for (unsigned I = 0; I < ExpRequirement; I++) {
1682     // First Round CVT
1683     if (UsesCvt) {
1684       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1685           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1686       SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1687       if (HasChainBetweenCvt)
1688         SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1689             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1690       else
1691         SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(
1692             1 + (2 + UsesFMA) * I, TII, SG->getSGID()));
1693       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1694     }
1695 
1696     // Third Round FMA
1697     if (UsesFMA) {
1698       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1699           SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1700       if (!IsPostRA && MFMAChains) {
1701         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1702             getNextTransPositionInChain(),
1703             MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(), true));
1704       } else
1705         SG->addRule(std::make_shared<EnablesNthMFMA>(2 * MFMAEnablement + 1,
1706                                                      TII, SG->getSGID(), true));
1707       SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1708       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1709     }
1710 
1711     // Second Round EXP
1712     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1713         SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1714     if (!IsPostRA && MFMAChains)
1715       SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1716           PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1717           true));
1718     else
1719       SG->addRule(std::make_shared<EnablesNthMFMA>(MFMAEnablement + 1, TII,
1720                                                    SG->getSGID(), true));
1721     SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1722     SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1723                                                  HasChainBetweenCvt));
1724     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1725   }
1726 
1727   // The "extra" EXP which enables all the MFMAs
1728   // TODO: UsesExtraExp
1729   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1730       SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1731   SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1732   SG->addRule(std::make_shared<GreaterThanOrEqualToNSuccs>(
1733       8, TII, SG->getSGID(), HasChainBetweenCvt));
1734   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1735 
1736   // PHASE 2: Main Interleave Loop
1737 
1738   // The number of MFMAs per iteration
1739   unsigned MFMARatio =
1740       MFMAEnablement > ExpRequirement ? MFMAEnablement / ExpRequirement : 1;
1741   // The number of Exps per iteration
1742   unsigned ExpRatio =
1743       MFMAEnablement > ExpRequirement ? 1 : ExpRequirement / MFMAEnablement;
1744   // The remaining EXPs
1745   unsigned RemainingExp = TransPipeCount > (2 * ExpRequirement)
1746                               ? TransPipeCount - (2 * ExpRequirement)
1747                               : 0;
1748   unsigned ExpLoopCount = RemainingExp / ExpRatio;
1749   // The number of MFMAs issued inside the interleave loop
1750   unsigned MFMAInLoop = MFMAPipeCount > (MFMAEnablement * 2)
1751                             ? MFMAPipeCount - (MFMAEnablement * 2)
1752                             : 0;
1753   unsigned MFMALoopCount = MFMAInLoop / MFMARatio;
1754   unsigned VALUOps =
1755       AddPipeCount < MFMAPipeCount ? 1 : AddPipeCount / MFMAPipeCount;
1756   unsigned LoopSize = std::min(ExpLoopCount, MFMALoopCount);
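       // Worked example for the large kernel type (MFMAEnablement = 4,
       // ExpRequirement = 4, TransPipeCount = 64): MFMARatio = ExpRatio = 1 and
       // RemainingExp = 64 - 2 * 4 = 56, so the loop runs up to 56 iterations,
       // each scheduling one MFMA group and one EXP group (bounded by the
       // MFMAs remaining after the prologue and epilogue phases).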
1757 
1758   for (unsigned I = 0; I < LoopSize; I++) {
1759     if (!(I * ExpRatio % ExpRequirement))
1760       incrementTransPosition();
1761 
1762     // Round N MFMA
1763     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1764         SchedGroupMask::MFMA, MFMARatio, PipelineSyncID, DAG, TII);
1765     if (!IsPostRA && MFMAChains)
1766       SG->addRule(std::make_shared<IsExactMFMA>(
1767           PositionInChainForMFMA, MFMAChainSeeds[MFMAChainForMFMA], TII,
1768           SG->getSGID(), true));
1769     else
1770       SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1771     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1772     incrementMFMAPosition();
1773 
1774     if (UsesVALU) {
1775       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1776           SchedGroupMask::VALU, VALUOps, PipelineSyncID, DAG, TII);
1777       SG->addRule(std::make_shared<IsPipeAdd>(TII, SG->getSGID()));
1778       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1779     }
1780 
1781     if (UsesDSRead && !(I % 4)) {
1782       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1783           SchedGroupMask::DS_READ, 2, PipelineSyncID, DAG, TII);
1784       SG->addRule(std::make_shared<OccursAtOrAfterNode>(*FirstPipeDSR, TII,
1785                                                         SG->getSGID()));
1786       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1787     }
1788 
1789     // CVT, EXP, FMA Interleaving
1790     for (unsigned J = 0; J < ExpRatio; J++) {
1791       auto MFMAOffset = (1 + UsesVALU) * MFMARatio * (I + 1);
1792       auto MaxMFMAOffset =
1793           (1 + UsesVALU) * ExpRequirement * MFMARatio / ExpRatio;
1794 
1795       // Round N + 1 CVT
1796       if (UsesCvt) {
1797         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1798             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1799         SG->addRule(std::make_shared<IsCvt>(TII, SG->getSGID()));
1800         auto BaseDiff = (2 + UsesFMA) * (ExpRequirement - 1) + 1;
1801         auto DSROffset = I / 4 + 1;
1802         auto MaxDSROffset = MaxMFMAOffset / 4;
1803         // TODO: UsesExtraExp
1804         auto ExpOffset = I * ExpRatio + J >= ExpRequirement ? 0 : 1;
1805         auto CurrentOffset = UsesDSRead * std::min(MaxDSROffset, DSROffset) +
1806                              std::min(MaxMFMAOffset, MFMAOffset) + BaseDiff +
1807                              ExpOffset;
1808         if (HasChainBetweenCvt)
1809           SG->addRule(std::make_shared<IsReachableFromPrevNthGroup>(
1810               CurrentOffset, TII, SG->getSGID()));
1811         else
1812           SG->addRule(std::make_shared<IsSuccOfPrevNthGroup>(CurrentOffset, TII,
1813                                                              SG->getSGID()));
1814         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1815       }
1816 
1817       // Round N + 3 FMA
1818       if (UsesFMA) {
1819         SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1820             SchedGroupMask::VALU, 1, PipelineSyncID, DAG, TII);
1821         if (!IsPostRA && MFMAChains)
1822           SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1823               getNextTransPositionInChain(),
1824               MFMAChainSeeds[getNextTransMFMAChain()], TII, SG->getSGID(),
1825               true));
1826         else
1827           SG->addRule(std::make_shared<EnablesNthMFMA>(
1828               (((I * ExpRatio + J) / ExpRequirement) + 3) * MFMAEnablement + 1,
1829               TII, SG->getSGID(), true));
1830         SG->addRule(std::make_shared<IsFMA>(TII, SG->getSGID()));
1831         SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1832       }
1833 
1834       // Round N + 2 Exp
1835       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1836           SchedGroupMask::TRANS, 1, PipelineSyncID, DAG, TII);
1837       if (!IsPostRA && MFMAChains)
1838         SG->addRule(std::make_shared<EnablesNthMFMAInChain>(
1839             PositionInChain, MFMAChainSeeds[MFMAChain], TII, SG->getSGID(),
1840             true));
1841       else
1842         SG->addRule(std::make_shared<EnablesNthMFMA>(
1843             (((I * ExpRatio + J) / ExpRequirement) + 2) * MFMAEnablement + 1,
1844             TII, SG->getSGID(), true));
1845       SG->addRule(std::make_shared<IsPipeExp>(TII, SG->getSGID(), true));
1846       SG->addRule(std::make_shared<LessThanNSuccs>(8, TII, SG->getSGID(),
1847                                                    HasChainBetweenCvt));
1848       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1849     }
1850   }
1851 
1852   // PHASE 3: Remaining MFMAs
1853   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1854       SchedGroupMask::MFMA, MFMAEnablement * 2, PipelineSyncID, DAG, TII);
1855   SG->addRule(std::make_shared<OccursAfterExp>(TII, SG->getSGID(), true));
1856   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1857   return true;
1858 }
1859 
1860 class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
1861 private:
1862   // Whether the DS_READ is a predecessor of the first four MFMAs in the region
1863   class EnablesInitialMFMA final : public InstructionRule {
1864   public:
1865     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1866                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1867       if (!SyncPipe.size())
1868         return false;
1869       int MFMAsFound = 0;
1870       if (!Cache->size()) {
1871         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
1872           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
1873             ++MFMAsFound;
1874             if (MFMAsFound > 4)
1875               break;
1876             Cache->push_back(&Elt);
1877           }
1878         }
1879       }
1880 
1881       assert(Cache->size());
1882       auto DAG = SyncPipe[0].DAG;
1883       for (auto &Elt : *Cache) {
1884         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
1885           return true;
1886       }
1887       return false;
1888     }
1889 
1890     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
1891                        bool NeedsCache = false)
1892         : InstructionRule(TII, SGID, NeedsCache) {}
1893   };
1894 
1895   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
1896   class IsPermForDSW final : public InstructionRule {
1897   public:
1898     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1899                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1900       auto MI = SU->getInstr();
1901       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
1902         return false;
1903 
1904       bool FitsInGroup = false;
1905       // Does the VALU have a DS_WRITE successor
1906       if (!Collection.size()) {
1907         for (auto &Succ : SU->Succs) {
1908           SUnit *SuccUnit = Succ.getSUnit();
1909           if (TII->isDS(*SuccUnit->getInstr()) &&
1910               SuccUnit->getInstr()->mayStore()) {
1911             Cache->push_back(SuccUnit);
1912             FitsInGroup = true;
1913           }
1914         }
1915         return FitsInGroup;
1916       }
1917 
1918       assert(Cache->size());
1919 
1920       // Does the VALU have a DS_WRITE successor that it shares with another
1921       // VALU already in the group? The V_PERMs will all share one DS_WRITE succ.
1922       return llvm::any_of(*Cache, [&SU](SUnit *Elt) {
1923         return llvm::any_of(SU->Succs, [&Elt](const SDep &ThisSucc) {
1924           return ThisSucc.getSUnit() == Elt;
1925         });
1926       });
1927     }
1928 
1929     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1930         : InstructionRule(TII, SGID, NeedsCache) {}
1931   };
1932 
1933   // Whether the SU is a successor of any element in the previous SchedGroup
1934   class IsSuccOfPrevGroup final : public InstructionRule {
1935   public:
1936     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1937                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1938       SchedGroup *OtherGroup = nullptr;
1939       for (auto &PipeSG : SyncPipe) {
1940         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
1941           OtherGroup = &PipeSG;
1942         }
1943       }
1944 
1945       if (!OtherGroup)
1946         return false;
1947       if (!OtherGroup->Collection.size())
1948         return true;
1949 
1950       // Does the previous VALU group have this DS_WRITE as a successor?
1951       return (std::any_of(OtherGroup->Collection.begin(),
1952                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
1953                             return std::any_of(Elt->Succs.begin(),
1954                                                Elt->Succs.end(),
1955                                                [&SU](SDep &Succ) {
1956                                                  return Succ.getSUnit() == SU;
1957                                                });
1958                           }));
1959     }
1960     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1961                       bool NeedsCache = false)
1962         : InstructionRule(TII, SGID, NeedsCache) {}
1963   };
1964 
1965   // Whether the combined load width of the group stays within 128 bits
1966   class VMEMSize final : public InstructionRule {
1967   public:
1968     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1969                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1970       auto MI = SU->getInstr();
1971       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1972         return false;
1973       if (!Collection.size())
1974         return true;
1975 
1976       int NumBits = 0;
1977 
1978       auto TRI = TII->getRegisterInfo();
1979       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1980       for (auto &Elt : Collection) {
1981         auto Op = Elt->getInstr()->getOperand(0);
1982         auto Size =
1983             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1984         NumBits += Size;
1985       }
1986 
1987       if (NumBits < 128) {
1988         assert(TII->isVMEM(*MI) && MI->mayLoad());
1989         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1990                           MRI, MI->getOperand(0))) <=
1991             128)
1992           return true;
1993       }
1994 
1995       return false;
1996     }
1997 
1998     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1999         : InstructionRule(TII, SGID, NeedsCache) {}
2000   };
2001 
2002   /// Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
2003   /// that is \p Distance steps away
2004   class SharesPredWithPrevNthGroup final : public InstructionRule {
2005   private:
2006     unsigned Distance = 1;
2007 
2008   public:
2009     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
2010                SmallVectorImpl<SchedGroup> &SyncPipe) override {
2011       SchedGroup *OtherGroup = nullptr;
2012       if (!SyncPipe.size())
2013         return false;
2014 
2015       if (!Cache->size()) {
2016 
2017         for (auto &PipeSG : SyncPipe) {
2018           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
2019             OtherGroup = &PipeSG;
2020           }
2021         }
2022 
2023         if (!OtherGroup)
2024           return false;
2025         if (!OtherGroup->Collection.size())
2026           return true;
2027 
2028         for (auto &OtherEle : OtherGroup->Collection) {
2029           for (auto &Pred : OtherEle->Preds) {
2030             if (Pred.getSUnit()->getInstr()->getOpcode() ==
2031                 AMDGPU::V_PERM_B32_e64)
2032               Cache->push_back(Pred.getSUnit());
2033           }
2034         }
2035 
2036         // If the other group has no PERM preds, then this group won't share any
2037         if (!Cache->size())
2038           return false;
2039       }
2040 
2041       auto DAG = SyncPipe[0].DAG;
2042       // Does the previous DS_WRITE share a V_PERM predecessor with this
2043       // VMEM_READ
2044       return llvm::any_of(*Cache, [&SU, &DAG](SUnit *Elt) {
2045         return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
2046       });
2047     }
2048     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
2049                                unsigned SGID, bool NeedsCache = false)
2050         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
2051   };
2052 
2053 public:
2054   bool applyIGLPStrategy(
2055       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2056       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2057       AMDGPU::SchedulingPhase Phase) override;
2058 
2059   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG,
2060                            AMDGPU::SchedulingPhase Phase) override {
2061     return true;
2062   }
2063 
2064   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
2065       : IGLPStrategy(DAG, TII) {
2066     IsBottomUp = false;
2067   }
2068 };
2069 
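     // These counters are computed during the initial pre-RA scheduling phase;
     // being static, they persist so that later scheduling phases of the same
     // region reuse the values rather than recomputing them.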
2070 static unsigned DSWCount = 0;
2071 static unsigned DSWWithPermCount = 0;
2072 static unsigned DSWWithSharedVMEMCount = 0;
2073 
2074 bool MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
2075     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
2076     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
2077     AMDGPU::SchedulingPhase Phase) {
2078   unsigned MFMACount = 0;
2079   unsigned DSRCount = 0;
2080 
2081   bool IsInitial = Phase == AMDGPU::SchedulingPhase::Initial;
2082 
2083   assert((!IsInitial || (DSWCount == 0 && DSWWithPermCount == 0 &&
2084                          DSWWithSharedVMEMCount == 0)) &&
2085          "DSWCounters should be zero in pre-RA scheduling!");
2086   SmallVector<SUnit *, 6> DSWithPerms;
2087   for (auto &SU : DAG->SUnits) {
2088     auto I = SU.getInstr();
2089     if (TII->isMFMAorWMMA(*I))
2090       ++MFMACount;
2091     else if (TII->isDS(*I)) {
2092       if (I->mayLoad())
2093         ++DSRCount;
2094       else if (I->mayStore() && IsInitial) {
2095         ++DSWCount;
2096         for (auto Pred : SU.Preds) {
2097           if (Pred.getSUnit()->getInstr()->getOpcode() ==
2098               AMDGPU::V_PERM_B32_e64) {
2099             DSWithPerms.push_back(&SU);
2100             break;
2101           }
2102         }
2103       }
2104     }
2105   }
2106 
2107   if (IsInitial) {
2108     DSWWithPermCount = DSWithPerms.size();
2109     auto I = DSWithPerms.begin();
2110     auto E = DSWithPerms.end();
2111 
2112     // Get the count of DS_WRITES with V_PERM predecessors which
2113     // have loop-carried dependencies (WAR) on the same VMEM_READs.
2114     // We consider partial overlap as a miss -- in other words,
2115     // for a given DS_W, we only consider another DS_W as matching
2116     // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
2117     // for every V_PERM pred of this DS_W.
2118     DenseMap<MachineInstr *, SUnit *> VMEMLookup;
2119     SmallVector<SUnit *, 6> Counted;
2120     for (; I != E; I++) {
2121       SUnit *Cand = nullptr;
2122       bool MissedAny = false;
2123       for (auto &Pred : (*I)->Preds) {
2124         if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
2125           continue;
2126 
2127         if (Cand && llvm::is_contained(Counted, Cand))
2128           break;
2129 
2130         for (auto &Succ : Pred.getSUnit()->Succs) {
2131           auto MI = Succ.getSUnit()->getInstr();
2132           if (!TII->isVMEM(*MI) || !MI->mayLoad())
2133             continue;
2134 
2135           if (MissedAny || !VMEMLookup.size()) {
2136             MissedAny = true;
2137             VMEMLookup[MI] = *I;
2138             continue;
2139           }
2140 
2141           if (!VMEMLookup.contains(MI)) {
2142             MissedAny = true;
2143             VMEMLookup[MI] = *I;
2144             continue;
2145           }
2146 
2147           Cand = VMEMLookup[MI];
2148           if (llvm::is_contained(Counted, Cand)) {
2149             MissedAny = true;
2150             break;
2151           }
2152         }
2153       }
2154       if (!MissedAny && Cand) {
2155         DSWWithSharedVMEMCount += 2;
2156         Counted.push_back(Cand);
2157         Counted.push_back(*I);
2158       }
2159     }
2160   }
2161 
2162   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
2163   SchedGroup *SG;
2164   unsigned PipelineSyncID = 0;
2165   // For kernels with V_PERM, there are enough VALUs to mix in between the MFMAs
2166   if (DSWWithPermCount) {
2167     for (unsigned I = 0; I < MFMACount; I++) {
2168       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2169           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2170       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2171 
2172       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2173           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
2174       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2175     }
2176   }
2177 
2178   PipelineSyncID = 1;
2179   // Phase 1: Break up DS_READ and MFMA clusters.
2180   // First, DS_READs to make the initial MFMAs ready; then interleave MFMA with
2181   // DS_READ prefetch.
2182 
2183   // Make the initial MFMAs ready
2184   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2185       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
2186   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
2187   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2188 
2189   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2190       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2191   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2192 
2193   // Interleave MFMA with DS_READ prefetch
2194   for (unsigned I = 0; I < DSRCount - 4; ++I) {
2195     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2196         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
2197     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2198 
2199     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2200         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2201     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2202   }
2203 
2204   // Phase 2a: Loop-carried dependency with V_PERM
2205   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2206   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2207   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
2208     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2209         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2210     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2211     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2212 
2213     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2214         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2215     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2216     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2217 
2218     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2219         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2220     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2221         1, TII, SG->getSGID(), true));
2222     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2223     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2224 
2225     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2226         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2227     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2228 
2229     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2230         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2231     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2232         3, TII, SG->getSGID(), true));
2233     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2234     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2235 
2236     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2237         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2238     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2239   }
2240 
2241   // Phase 2b: Loop-carried dependency without V_PERM
2242   // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
2243   // Interleave MFMA to keep XDL unit busy throughout.
2244   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
2245     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2246         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2247     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2248 
2249     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2250         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2251     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2252     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2253 
2254     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2255         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2256     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2257   }
2258 
2259   // Phase 2c: Loop-carried dependency with V_PERM; the VMEM_READs are
2260   // ultimately used by two DS_WRITEs.
2261   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
2262   // depend on. Interleave MFMA to keep XDL unit busy throughout.
2263 
2264   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
2265     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2266         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2267     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2268     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2269 
2270     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2271         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2272     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2273     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2274 
2275     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2276         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2277     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2278 
2279     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2280         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
2281     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
2282     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2283 
2284     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2285         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
2286     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID()));
2287     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2288 
2289     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2290         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2291     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2292 
2293     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2294         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2295     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2296         2, TII, SG->getSGID(), true));
2297     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2298     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2299 
2300     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2301         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2302     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2303 
2304     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2305         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
2306     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
2307         4, TII, SG->getSGID(), true));
2308     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID()));
2309     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2310 
2311     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
2312         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
2313     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
2314   }
2315 
2316   return true;
2317 }
2318 
2319 static std::unique_ptr<IGLPStrategy>
2320 createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
2321                    const SIInstrInfo *TII) {
2322   switch (ID) {
2323   case MFMASmallGemmOptID:
2324     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
2325   case MFMASmallGemmSingleWaveOptID:
2326     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
2327   case MFMAExpInterleave:
2328     return std::make_unique<MFMAExpInterleaveOpt>(DAG, TII);
2329   }
2330 
2331   llvm_unreachable("Unknown IGLPStrategyID");
2332 }
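     // For reference, a kernel selects one of these strategies through the
     // llvm.amdgcn.iglp.opt intrinsic, whose immediate operand initIGLPOpt
     // below interprets as an IGLPStrategyID, e.g. (assuming the enum
     // numbering in AMDGPUIGroupLP.h):
     //   call void @llvm.amdgcn.iglp.opt(i32 1) ; MFMASmallGemmSingleWaveOpt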
2333 
2334 class IGroupLPDAGMutation : public ScheduleDAGMutation {
2335 private:
2336   const SIInstrInfo *TII;
2337 
2338   ScheduleDAGMI *DAG;
2339 
2340   // Organize lists of SchedGroups by their SyncID. SchedGroups /
2341   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
2342   // between them.
2343   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
2344 
2345   // Used to track instructions that can be mapped to multiple sched groups
2346   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
2347 
2348   // Add DAG edges that enforce SCHED_BARRIER ordering.
2349   void addSchedBarrierEdges(SUnit &SU);
2350 
2351   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
2352   // not be reordered across the SCHED_BARRIER. This is used for the base
2353   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
2354   // SCHED_BARRIER will always block all instructions that can be classified
2355   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
2356   // and may only synchronize with some SchedGroups. Returns the inverse of
2357   // Mask. SCHED_BARRIER's mask describes which instruction types should be
2358   // allowed to be scheduled across it. Invert the mask to get the
2359   // SchedGroupMask of instructions that should be barred.
2360   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
2361 
2362   // Create SchedGroups for a SCHED_GROUP_BARRIER.
2363   void initSchedGroupBarrierPipelineStage(
2364       std::vector<SUnit>::reverse_iterator RIter);
2365 
2366   bool initIGLPOpt(SUnit &SU);
2367 
2368 public:
2369   void apply(ScheduleDAGInstrs *DAGInstrs) override;
2370 
2371   // The order in which the PipelineSolver should process the candidate
2372   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
2373   // created SchedGroup first, and will consider that as the ultimate
2374   // predecessor group when linking. TOP_DOWN instead links and processes the
2375   // first created SchedGroup first.
2376   bool IsBottomUp = true;
2377 
2378   // The scheduling phase this application of IGLP corresponds with.
2379   AMDGPU::SchedulingPhase Phase = AMDGPU::SchedulingPhase::Initial;
2380 
2381   IGroupLPDAGMutation() = default;
2382   IGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) : Phase(Phase) {}
2383 };
2384 
2385 unsigned SchedGroup::NumSchedGroups = 0;
2386 
2387 bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
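     // Add an artificial edge from A to B when legal, making A a scheduling
     // predecessor of B in the final instruction order.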
2388   if (A != B && DAG->canAddEdge(B, A)) {
2389     DAG->addEdge(B, SDep(A, SDep::Artificial));
2390     return true;
2391   }
2392   return false;
2393 }
2394 
2395 bool SchedGroup::canAddMI(const MachineInstr &MI) const {
2396   bool Result = false;
2397   if (MI.isMetaInstruction())
2398     Result = false;
2399 
2400   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
2401            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI) ||
2402             TII->isTRANS(MI)))
2403     Result = true;
2404 
2405   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
2406            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI) && !TII->isTRANS(MI))
2407     Result = true;
2408 
2409   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
2410            TII->isSALU(MI))
2411     Result = true;
2412 
2413   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
2414            TII->isMFMAorWMMA(MI))
2415     Result = true;
2416 
2417   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
2418            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2419     Result = true;
2420 
2421   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
2422            MI.mayLoad() &&
2423            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2424     Result = true;
2425 
2426   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
2427            MI.mayStore() &&
2428            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
2429     Result = true;
2430 
2431   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
2432            TII->isDS(MI))
2433     Result = true;
2434 
2435   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
2436            MI.mayLoad() && TII->isDS(MI))
2437     Result = true;
2438 
2439   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
2440            MI.mayStore() && TII->isDS(MI))
2441     Result = true;
2442 
2443   else if (((SGMask & SchedGroupMask::TRANS) != SchedGroupMask::NONE) &&
2444            TII->isTRANS(MI))
2445     Result = true;
2446 
2447   LLVM_DEBUG(
2448       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
2449              << (Result ? " could classify " : " unable to classify ") << MI);
2450 
2451   return Result;
2452 }
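     // For example, under the masks above a buffer load can be classified into
     // both the VMEM and VMEM_READ groups, a V_ADD_F32 into ALU and VALU, and
     // an MFMA into ALU and MFMA -- but never into VALU, which explicitly
     // excludes MFMA and TRANS instructions.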
2453 
2454 int SchedGroup::link(SUnit &SU, bool MakePred,
2455                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
2456   int MissedEdges = 0;
2457   for (auto *A : Collection) {
2458     SUnit *B = &SU;
2459     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2460       continue;
2461     if (MakePred)
2462       std::swap(A, B);
2463 
2464     if (DAG->IsReachable(B, A))
2465       continue;
2466 
2467     // tryAddEdge returns false if there is a dependency that makes adding
2468     // the A->B edge impossible; otherwise it returns true.
2469     bool Added = tryAddEdge(A, B);
2470     if (Added)
2471       AddedEdges.push_back(std::pair(A, B));
2472     else
2473       ++MissedEdges;
2474   }
2475 
2476   return MissedEdges;
2477 }
2478 
2479 void SchedGroup::link(SUnit &SU, bool MakePred) {
2480   for (auto *A : Collection) {
2481     SUnit *B = &SU;
2482     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
2483       continue;
2484     if (MakePred)
2485       std::swap(A, B);
2486 
2487     tryAddEdge(A, B);
2488   }
2489 }
2490 
2491 void SchedGroup::link(SUnit &SU,
2492                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
2493   for (auto *A : Collection) {
2494     SUnit *B = &SU;
2495     if (P(A, B))
2496       std::swap(A, B);
2497 
2498     tryAddEdge(A, B);
2499   }
2500 }
2501 
2502 void SchedGroup::link(SchedGroup &OtherGroup) {
2503   for (auto *B : OtherGroup.Collection)
2504     link(*B);
2505 }
2506 
2507 bool SchedGroup::canAddSU(SUnit &SU) const {
2508   MachineInstr &MI = *SU.getInstr();
2509   if (MI.getOpcode() != TargetOpcode::BUNDLE)
2510     return canAddMI(MI);
2511 
2512   // Special case for bundled MIs.
2513   const MachineBasicBlock *MBB = MI.getParent();
2514   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
2515   while (E != MBB->end() && E->isBundledWithPred())
2516     ++E;
2517 
2518   // Return true if all of the bundled MIs can be added to this group.
2519   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
2520 }
2521 
2522 void SchedGroup::initSchedGroup() {
2523   for (auto &SU : DAG->SUnits) {
2524     if (isFull())
2525       break;
2526 
2527     if (canAddSU(SU))
2528       add(SU);
2529   }
2530 }
2531 
2532 void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
2533                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
2534   SUnit &InitSU = *RIter;
2535   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
2536     auto &SU = *RIter;
2537     if (isFull())
2538       break;
2539 
2540     if (canAddSU(SU))
2541       SyncedInstrs[&SU].push_back(SGID);
2542   }
2543 
2544   add(InitSU);
2545   assert(MaxSize);
2546   (*MaxSize)++;
2547 }
2548 
2549 void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
2550   auto I = DAG->SUnits.rbegin();
2551   auto E = DAG->SUnits.rend();
2552   for (; I != E; ++I) {
2553     auto &SU = *I;
2554     if (isFull())
2555       break;
2556     if (canAddSU(SU))
2557       SyncedInstrs[&SU].push_back(SGID);
2558   }
2559 }
2560 
2561 void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
2562   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
2563   if (!TSchedModel || DAGInstrs->SUnits.empty())
2564     return;
2565 
2566   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
2567   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
2568   TII = ST.getInstrInfo();
2569   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
2570   SyncedSchedGroups.clear();
2571   SyncedInstrs.clear();
2572   bool FoundSB = false;
2573   bool FoundIGLP = false;
2574   bool ShouldApplyIGLP = false;
2575   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
2576     unsigned Opc = R->getInstr()->getOpcode();
2577     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
2578     if (Opc == AMDGPU::SCHED_BARRIER) {
2579       addSchedBarrierEdges(*R);
2580       FoundSB = true;
2581     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
2582       initSchedGroupBarrierPipelineStage(R);
2583       FoundSB = true;
2584     } else if (Opc == AMDGPU::IGLP_OPT) {
2585       resetEdges(*R, DAG);
2586       if (!FoundSB && !FoundIGLP) {
2587         FoundIGLP = true;
2588         ShouldApplyIGLP = initIGLPOpt(*R);
2589       }
2590     }
2591   }
2592 
2593   if (FoundSB || (FoundIGLP && ShouldApplyIGLP)) {
2594     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
2595     // PipelineSolver performs the mutation by adding the edges it
2596     // determined as the best
2597     PS.solve();
2598     return;
2599   }
2600 }
2601 
2602 void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
2603   MachineInstr &MI = *SchedBarrier.getInstr();
2604   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
2605   // Remove all existing edges from the SCHED_BARRIER that were added due to the
2606   // instruction having side effects.
2607   resetEdges(SchedBarrier, DAG);
2608   LLVM_DEBUG(dbgs() << "Building SchedGroup for SchedBarrier with Mask: "
2609                     << MI.getOperand(0).getImm() << "\n");
2610   auto InvertedMask =
2611       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
2612   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
2613   SG.initSchedGroup();
2614 
2615   // Preserve original instruction ordering relative to the SCHED_BARRIER.
2616   SG.link(
2617       SchedBarrier,
2618       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
2619           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
2620 }
2621 
2622 SchedGroupMask
2623 IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
2624   // Invert mask and erase bits for types of instructions that are implied to be
2625   // allowed past the SCHED_BARRIER.
2626   SchedGroupMask InvertedMask = ~Mask;
2627 
2628   // ALU implies VALU, SALU, MFMA, TRANS.
2629   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
2630     InvertedMask &= ~SchedGroupMask::VALU & ~SchedGroupMask::SALU &
2631                     ~SchedGroupMask::MFMA & ~SchedGroupMask::TRANS;
2632   // VALU, SALU, MFMA, TRANS implies ALU.
2633   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
2634            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
2635            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE ||
2636            (InvertedMask & SchedGroupMask::TRANS) == SchedGroupMask::NONE)
2637     InvertedMask &= ~SchedGroupMask::ALU;
2638 
2639   // VMEM implies VMEM_READ, VMEM_WRITE.
2640   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
2641     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
2642   // VMEM_READ, VMEM_WRITE implies VMEM.
2643   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
2644            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
2645     InvertedMask &= ~SchedGroupMask::VMEM;
2646 
2647   // DS implies DS_READ, DS_WRITE.
2648   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
2649     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
2650   // DS_READ, DS_WRITE implies DS.
2651   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
2652            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
2653     InvertedMask &= ~SchedGroupMask::DS;
2654 
2655   LLVM_DEBUG(dbgs() << "After Inverting, SchedGroup Mask: " << (int)InvertedMask
2656                     << "\n");
2657 
2658   return InvertedMask;
2659 }
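     // E.g., a SCHED_BARRIER with mask ALU (0x1) inverts to a mask barring only
     // the memory classes: since ALU instructions may cross, the implied VALU,
     // SALU, MFMA and TRANS bits are erased from the inverse as well.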
2660 
2661 void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
2662     std::vector<SUnit>::reverse_iterator RIter) {
2663   // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
2664   // to the instruction having side effects.
2665   resetEdges(*RIter, DAG);
2666   MachineInstr &SGB = *RIter->getInstr();
2667   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
2668   int32_t SGMask = SGB.getOperand(0).getImm();
2669   int32_t Size = SGB.getOperand(1).getImm();
2670   int32_t SyncID = SGB.getOperand(2).getImm();
2671 
2672   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
2673                                                     Size, SyncID, DAG, TII);
2674 
2675   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
2676 }
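     // For reference, a SCHED_GROUP_BARRIER carries (mask, size, sync id) as
     // immediates; e.g. __builtin_amdgcn_sched_group_barrier(8, 1, 0) lowers to
     // SCHED_GROUP_BARRIER 8, 1, 0, pinning one MFMA (mask 1 << 3) at this
     // point within sync group 0.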
2677 
2678 bool IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
2679   IGLPStrategyID StrategyID =
2680       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
2681   auto S = createIGLPStrategy(StrategyID, DAG, TII);
2682   if (!S->shouldApplyStrategy(DAG, Phase))
2683     return false;
2684 
2685   IsBottomUp = S->IsBottomUp;
2686   return S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups, Phase);
2687 }
2688 
2689 } // namespace
2690 
2691 namespace llvm {
2692 
2693 /// \p Phase specifies whether this is a reentry into the
2694 /// IGroupLPDAGMutation. Since there may be multiple scheduling passes on the
2695 /// same scheduling region (e.g. pre and post-RA scheduling / multiple
2696 /// scheduling "phases"), we can reenter this mutation framework more than once
2697 /// for a given region.
2698 std::unique_ptr<ScheduleDAGMutation>
2699 createIGroupLPDAGMutation(AMDGPU::SchedulingPhase Phase) {
2700   return std::make_unique<IGroupLPDAGMutation>(Phase);
2701 }
2702 
2703 } // end namespace llvm
2704