//===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file This file defines a set of schedule DAG mutations that can be used to
// override default scheduler behavior to enforce specific scheduling patterns.
// They should be used in cases where runtime performance considerations, such
// as inter-wavefront interactions, mean that compile-time heuristics cannot
// predict the optimal instruction ordering, or in kernels where optimum
// instruction scheduling is important enough to warrant manual intervention.
//
//===----------------------------------------------------------------------===//
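
// A sketch of the user-facing side, assuming the clang builtins
// __builtin_amdgcn_sched_barrier, __builtin_amdgcn_sched_group_barrier, and
// __builtin_amdgcn_iglp_opt (which lower to the SCHED_BARRIER,
// SCHED_GROUP_BARRIER, and IGLP_OPT pseudo instructions handled below):
//
//   // Request a pipeline of two DS instructions (mask 0x80) followed by one
//   // MFMA (mask 0x8), then two more DS, all in sync group 0:
//   __builtin_amdgcn_sched_group_barrier(0x80, 2, 0);
//   __builtin_amdgcn_sched_group_barrier(0x8, 1, 0);
//   __builtin_amdgcn_sched_group_barrier(0x80, 2, 0);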

#include "AMDGPUIGroupLP.h"
#include "AMDGPUTargetMachine.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "SIInstrInfo.h"
#include "SIMachineFunctionInfo.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineScheduler.h"
#include "llvm/CodeGen/TargetOpcodes.h"

using namespace llvm;

#define DEBUG_TYPE "igrouplp"

namespace {

static cl::opt<bool> EnableExactSolver(
    "amdgpu-igrouplp-exact-solver", cl::Hidden,
    cl::desc("Whether to use the exponential time solver to fit "
             "the instructions to the pipeline as closely as "
             "possible."),
    cl::init(false));

static cl::opt<unsigned> CutoffForExact(
    "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
    cl::desc("The maximum number of scheduling group conflicts "
             "which we attempt to solve with the exponential time "
             "exact solver. Problem sizes greater than this will "
             "be solved by the less accurate greedy algorithm. Selecting "
             "solver by size is superseded by manually selecting "
             "the solver (e.g. by amdgpu-igrouplp-exact-solver)."));

static cl::opt<uint64_t> MaxBranchesExplored(
    "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
    cl::desc("The number of branches that we are willing to explore with "
             "the exact algorithm before giving up."));

static cl::opt<bool> UseCostHeur(
    "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
    cl::desc("Whether to use the cost heuristic to make choices as we "
             "traverse the search space using the exact solver. Defaulted "
             "to on, and if turned off, we will use the node order -- "
             "attempting to put the later nodes in the later sched groups. "
             "Experimentally, results are mixed, so this should be set on a "
             "case-by-case basis."));
// Components of the mask that determines which instruction types may be
// classified into a SchedGroup.
enum class SchedGroupMask {
  NONE = 0u,
  ALU = 1u << 0,
  VALU = 1u << 1,
  SALU = 1u << 2,
  MFMA = 1u << 3,
  VMEM = 1u << 4,
  VMEM_READ = 1u << 5,
  VMEM_WRITE = 1u << 6,
  DS = 1u << 7,
  DS_READ = 1u << 8,
  DS_WRITE = 1u << 9,
  ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
        DS_READ | DS_WRITE,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
};
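
// For example, a mask of 0x220 (VMEM_READ | DS_WRITE) names both VMEM loads
// and DS stores; canAddMI() below tests the mask bits against an
// instruction's classification.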

typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;

// Classify instructions into groups to enable fine-tuned control over the
// scheduler. These groups may be more specific than current SchedModel
// instruction classes.
class SchedGroup {
private:
  // Mask that defines which instruction types can be classified into this
  // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
  // and SCHED_GROUP_BARRIER.
  SchedGroupMask SGMask;

  // Maximum number of SUnits that can be added to this group.
  std::optional<unsigned> MaxSize;

  // SchedGroups will only synchronize with other SchedGroups that have the same
  // SyncID.
  int SyncID = 0;

  // SGID is used to map instructions to candidate SchedGroups.
  unsigned SGID;

  // Count of the number of created SchedGroups, used to initialize SGID.
  static unsigned NumSchedGroups;

  ScheduleDAGInstrs *DAG;

  const SIInstrInfo *TII;

  // Try to add an edge from SU A to SU B.
  bool tryAddEdge(SUnit *A, SUnit *B);

  // Use SGMask to determine whether we can classify MI as a member of this
  // SchedGroup object.
  bool canAddMI(const MachineInstr &MI) const;

public:
  // Collection of SUnits that are classified as members of this group.
  SmallVector<SUnit *, 32> Collection;

  // Returns true if SU can be added to this SchedGroup.
  bool canAddSU(SUnit &SU) const;

  // Add DAG dependencies between all SUnits in this SchedGroup and this SU. If
  // MakePred is true, SU will be a predecessor of the SUnits in this
  // SchedGroup, otherwise SU will be a successor.
  void link(SUnit &SU, bool MakePred = false);

  // Add DAG dependencies and track which edges are added, and the count of
  // missed edges.
  int link(SUnit &SU, bool MakePred,
           std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);

  // Add DAG dependencies between all SUnits in this SchedGroup and this SU.
  // Use the predicate to determine whether SU should be a predecessor (P =
  // true) or a successor (P = false) of this SchedGroup.
  void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);

  // Add DAG dependencies such that SUnits in this group shall be ordered
  // before SUnits in OtherGroup.
  void link(SchedGroup &OtherGroup);

  // Returns true if no more instructions may be added to this group.
  bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }

  // Add SU to the SchedGroup.
  void add(SUnit &SU) {
    LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
                      << format_hex((int)SGMask, 10, true) << " adding "
                      << *SU.getInstr());
    Collection.push_back(&SU);
  }

  // Remove the last element in the SchedGroup.
  void pop() { Collection.pop_back(); }

  // Identify and add all relevant SUs from the DAG to this SchedGroup.
  void initSchedGroup();

  // Add instructions to the SchedGroup bottom up starting from RIter.
  // PipelineInstrs is a set of instructions that should not be added to the
  // SchedGroup even when the other conditions for adding it are satisfied.
  // RIter will be added to the SchedGroup as well, and dependencies will be
  // added so that RIter will always be scheduled at the end of the group.
  void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
                      SUnitsToCandidateSGsMap &SyncedInstrs);

  void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);

  int getSyncID() { return SyncID; }

  int getSGID() { return SGID; }

  SchedGroupMask getMask() { return SGMask; }

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }

  SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
             ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
    SGID = NumSchedGroups++;
  }
};

// Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
  assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);

  while (!SU.Preds.empty())
    for (auto &P : SU.Preds)
      SU.removePred(P);

  while (!SU.Succs.empty())
    for (auto &S : SU.Succs)
      for (auto &SP : S.getSUnit()->Preds)
        if (SP.getSUnit() == &SU)
          S.getSUnit()->removePred(SP);
}

typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;

// The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
// in non-trivial cases. For example, if the requested pipeline is
// {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
// in the DAG, then we will have an instruction that cannot be trivially
// assigned to a SchedGroup. The PipelineSolver class implements two algorithms
// to find a good solution to the pipeline -- a greedy algorithm and an exact
// algorithm. The exact algorithm has an exponential time complexity and should
// only be used for small problems, or medium sized problems where an exact
// solution is highly desired.
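// As a sketch of the search problem: with the pipeline above, a VMEM_READ SU
// matches both SchedGroups 0 and 3 and is recorded as a conflict with
// candidate groups {0, 3}. The greedy algorithm commits each conflicted SU to
// whichever candidate currently misses the fewest edges; the exact algorithm
// backtracks over all candidate assignments (including leaving the SU
// unassigned, at a cost of MissPenalty) and keeps the cheapest complete
// assignment it finds.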
class PipelineSolver {
  ScheduleDAGMI *DAG;

  // Instructions that can be assigned to multiple SchedGroups
  DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
  SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
  DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
  // The current working pipeline
  SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
  // The pipeline that has the best solution found so far
  SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;

  // Whether or not we actually have any SyncedInstrs to try to solve.
  bool NeedsSolver = false;

  // Compute an estimate of the size of the search tree -- the true size is
  // the product of each conflictedInst.Matches.size() across all SyncPipelines
  unsigned computeProblemSize();

  // The cost penalty of not assigning a SU to a SchedGroup
  int MissPenalty = 0;

  // Costs in terms of the number of edges we are unable to add
  int BestCost = -1;
  int CurrCost = 0;

  // Index pointing to the conflicting instruction that is currently being
  // fitted
  int CurrConflInstNo = 0;
  // Index to the pipeline that is currently being fitted
  int CurrSyncGroupIdx = 0;
  // The first non-trivial pipeline
  int BeginSyncGroupIdx = 0;

  // How many branches we have explored
  uint64_t BranchesExplored = 0;

  // The direction in which we process the candidate SchedGroups per SU
  bool IsBottomUp = true;

  // Update indices to fit next conflicting instruction
  void advancePosition();
  // Recede indices to attempt to find better fit for previous conflicting
  // instruction
  void retreatPosition();

  // The exponential time algorithm which finds the provably best fit
  bool solveExact();
  // The polynomial time algorithm which attempts to find a good fit
  bool solveGreedy();
  // Find the best SchedGroup for the current SU using the heuristic given all
  // current information. One step in the greedy algorithm. Templated against
  // the SchedGroup iterator (either reverse or forward).
  template <typename T>
  void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
                  T E);
  // Whether or not the current solution is optimal
  bool checkOptimal();
  // Populate the ready list, prioritizing fewest missed edges first.
  // Templated against the SchedGroup iterator (either reverse or forward).
  template <typename T>
  void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
                         T E);
  // Add edges corresponding to the SchedGroups as assigned by solver
  void makePipeline();
  // Link the SchedGroups in the best found pipeline.
  // Templated against the SchedGroup iterator (either reverse or forward).
  template <typename T> void linkSchedGroups(T I, T E);
  // Add the edges from the SU to the other SchedGroups in pipeline, and
  // return the number of edges missed.
  int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
               std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
  // Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
  // returns the cost (in terms of missed pipeline edges), and tracks the edges
  // added in \p AddedEdges
  template <typename T>
  int linkSUnit(SUnit *SU, int SGID,
                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
  // Remove the edges passed via \p AddedEdges
  void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
  // Convert the passed in maps to arrays for bidirectional iterators
  void convertSyncMapsToArrays();

  void reset();

public:
  // Invoke the solver to map instructions to instruction groups. A heuristic
  // and the command-line options determine whether the exact or the greedy
  // algorithm is used.
  void solve();

  PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
                 DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
                 ScheduleDAGMI *DAG, bool IsBottomUp = true)
      : DAG(DAG), SyncedInstrs(SyncedInstrs),
        SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {

    for (auto &PipelineInstrs : SyncedInstrs) {
      if (PipelineInstrs.second.size() > 0) {
        NeedsSolver = true;
        break;
      }
    }

    if (!NeedsSolver)
      return;

    convertSyncMapsToArrays();

    CurrPipeline = BestPipeline;

    while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
           PipelineInstrs[BeginSyncGroupIdx].size() == 0)
      ++BeginSyncGroupIdx;

    if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
      return;
  }
};

void PipelineSolver::reset() {

  for (auto &SyncPipeline : CurrPipeline) {
    for (auto &SG : SyncPipeline) {
      SmallVector<SUnit *, 32> TempCollection = SG.Collection;
      SG.Collection.clear();
      auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
        return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
      });
      if (SchedBarr != TempCollection.end())
        SG.Collection.push_back(*SchedBarr);
    }
  }

  CurrSyncGroupIdx = BeginSyncGroupIdx;
  CurrConflInstNo = 0;
  CurrCost = 0;
}

void PipelineSolver::convertSyncMapsToArrays() {
  for (auto &SyncPipe : SyncedSchedGroups) {
    BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
  }

  int PipelineIDx = SyncedInstrs.size() - 1;
  PipelineInstrs.resize(SyncedInstrs.size());
  for (auto &SyncInstrMap : SyncedInstrs) {
    for (auto &SUsToCandSGs : SyncInstrMap.second) {
      if (PipelineInstrs[PipelineIDx].size() == 0) {
        PipelineInstrs[PipelineIDx].push_back(
            std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
        continue;
      }
      auto SortPosition = PipelineInstrs[PipelineIDx].begin();
      // Insert them in sorted order -- this allows for good parsing order in
      // the greedy algorithm
      while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
             SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
        ++SortPosition;
      PipelineInstrs[PipelineIDx].insert(
          SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
    }
    --PipelineIDx;
  }
}

template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
  for (; I != E; ++I) {
    auto &GroupA = *I;
    for (auto J = std::next(I); J != E; ++J) {
      auto &GroupB = *J;
      GroupA.link(GroupB);
    }
  }
}

void PipelineSolver::makePipeline() {
  // Preserve the order of barriers for subsequent SchedGroupBarrier mutations
  for (auto &SyncPipeline : BestPipeline) {
    for (auto &SG : SyncPipeline) {
      LLVM_DEBUG(dbgs() << "Printing SchedGroups\nSchedGroup with SGID "
                        << SG.getSGID() << " has: \n");
      SUnit *SGBarr = nullptr;
      for (auto &SU : SG.Collection) {
        if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
          SGBarr = SU;
        LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
      }
      // SchedGroups requested from the command line (IGroupLP) have no
      // SCHED_GROUP_BARRIER.
      if (!SGBarr)
        continue;
      resetEdges(*SGBarr, DAG);
      SG.link(*SGBarr, false);
    }
  }

  for (auto &SyncPipeline : BestPipeline) {
    IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
               : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
  }
}

template <typename T>
int PipelineSolver::linkSUnit(
    SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
    T I, T E) {
  bool MakePred = false;
  int AddedCost = 0;
  for (; I < E; ++I) {
    if (I->getSGID() == SGID) {
      MakePred = true;
      continue;
    }
    auto Group = *I;
    AddedCost += Group.link(*SU, MakePred, AddedEdges);
    assert(AddedCost >= 0);
  }
  return AddedCost;
}

int PipelineSolver::addEdges(
    SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
    std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {

  // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
  // instructions that are the ultimate successors in the resultant mutation.
  // Therefore, in such a configuration, the SchedGroups occurring before the
  // candidate SGID are successors of the candidate SchedGroup, thus the current
  // SU should be linked as a predecessor to SUs in those SchedGroups. The
  // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
  // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
  // IsBottomUp (in reverse).
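  //
  // For example, with IsBottomUp and SyncPipeline = [SG0, SG1, SG2], linking
  // SU into SG1 makes SU a successor of the SUs already placed in SG2 and a
  // predecessor of the SUs in SG0.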
  return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
                                SyncPipeline.rend())
                    : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
                                SyncPipeline.end());
}

void PipelineSolver::removeEdges(
    const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
  // Only remove the edges that we have added when testing
  // the fit.
  for (auto &PredSuccPair : EdgesToRemove) {
    SUnit *Pred = PredSuccPair.first;
    SUnit *Succ = PredSuccPair.second;

    auto Match = llvm::find_if(
        Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
    if (Match != Succ->Preds.end()) {
      assert(Match->isArtificial());
      Succ->removePred(*Match);
    }
  }
}

void PipelineSolver::advancePosition() {
  ++CurrConflInstNo;

  if (static_cast<size_t>(CurrConflInstNo) >=
      PipelineInstrs[CurrSyncGroupIdx].size()) {
    CurrConflInstNo = 0;
    ++CurrSyncGroupIdx;
    // Advance to next non-trivial pipeline
    while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
           PipelineInstrs[CurrSyncGroupIdx].size() == 0)
      ++CurrSyncGroupIdx;
  }
}

void PipelineSolver::retreatPosition() {
  assert(CurrConflInstNo >= 0);
  assert(CurrSyncGroupIdx >= 0);

  if (CurrConflInstNo > 0) {
    --CurrConflInstNo;
    return;
  }

  if (CurrConflInstNo == 0) {
    // If we return to the starting position, we have explored
    // the entire tree
    if (CurrSyncGroupIdx == BeginSyncGroupIdx)
      return;

    --CurrSyncGroupIdx;
    // Go to previous non-trivial pipeline
    while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
      --CurrSyncGroupIdx;

    CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
  }
}

bool PipelineSolver::checkOptimal() {
  if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
    if (BestCost == -1 || CurrCost < BestCost) {
      BestPipeline = CurrPipeline;
      BestCost = CurrCost;
      LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
    }
    assert(BestCost >= 0);
  }

  bool DoneExploring = false;
  if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
    DoneExploring = true;

  return (DoneExploring || BestCost == 0);
}

template <typename T>
void PipelineSolver::populateReadyList(
    SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
  assert(CurrSU.second.size() >= 1);

  for (; I != E; ++I) {
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    int CandSGID = *I;
    SchedGroup *Match = nullptr;
    for (auto &SG : SyncPipeline) {
      if (SG.getSGID() == CandSGID)
        Match = &SG;
    }

    if (UseCostHeur) {
      if (Match->isFull()) {
        ReadyList.push_back(std::pair(*I, MissPenalty));
        continue;
      }

      int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
      ReadyList.push_back(std::pair(*I, TempCost));
      removeEdges(AddedEdges);
    } else
      ReadyList.push_back(std::pair(*I, -1));
  }

  if (UseCostHeur) {
    std::sort(ReadyList.begin(), ReadyList.end(),
              [](std::pair<int, int> A, std::pair<int, int> B) {
                return A.second < B.second;
              });
  }

  assert(ReadyList.size() == CurrSU.second.size());
}

bool PipelineSolver::solveExact() {
  if (checkOptimal())
    return true;

  if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
    return false;

  assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
  assert(static_cast<size_t>(CurrConflInstNo) <
         PipelineInstrs[CurrSyncGroupIdx].size());
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
                    << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

  // SchedGroup -> Cost pairs
  SmallVector<std::pair<int, int>, 4> ReadyList;
  // Prioritize the candidate sched groups in terms of lowest cost first
  IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
                                 CurrSU.second.rend())
             : populateReadyList(ReadyList, CurrSU.second.begin(),
                                 CurrSU.second.end());

  auto I = ReadyList.begin();
  auto E = ReadyList.end();
  for (; I != E; ++I) {
    // If we are trying SGs in least cost order, and the current SG is cost
    // infeasible, then all subsequent SGs will also be cost infeasible, so we
    // can prune.
    if (BestCost != -1 && (CurrCost + I->second > BestCost))
      return false;

    int CandSGID = I->first;
    int AddedCost = 0;
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
    SchedGroup *Match = nullptr;
    for (auto &SG : SyncPipeline) {
      if (SG.getSGID() == CandSGID)
        Match = &SG;
    }

    if (Match->isFull())
      continue;

    LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
                      << (int)Match->getMask() << " and ID " << CandSGID
                      << "\n");
    Match->add(*CurrSU.first);
    AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
    CurrCost += AddedCost;
    advancePosition();
    ++BranchesExplored;
    bool FinishedExploring = false;
    // If the Cost after adding edges is greater than a known solution,
    // backtrack
    if (CurrCost < BestCost || BestCost == -1) {
      if (solveExact()) {
        FinishedExploring = BestCost != 0;
        if (!FinishedExploring)
          return true;
      }
    }

    retreatPosition();
    CurrCost -= AddedCost;
    removeEdges(AddedEdges);
    Match->pop();
    CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
    if (FinishedExploring)
      return true;
  }

  // Try the pipeline where the current instruction is omitted.
  // Potentially, if we omit a problematic instruction from the pipeline,
  // all the other instructions can nicely fit.
  CurrCost += MissPenalty;
  advancePosition();

  LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");

  bool FinishedExploring = false;
  if (CurrCost < BestCost || BestCost == -1) {
    if (solveExact()) {
      FinishedExploring = BestCost != 0;
      if (!FinishedExploring)
        return true;
    }
  }

  retreatPosition();
  CurrCost -= MissPenalty;
  return FinishedExploring;
}

template <typename T>
void PipelineSolver::greedyFind(
    std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
  SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
  int BestNodeCost = -1;
  int TempCost;
  SchedGroup *BestGroup = nullptr;
  int BestGroupID = -1;
  auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
  LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
                    << ") in Pipeline # " << CurrSyncGroupIdx << "\n");

  // Since we have added the potential SchedGroups from bottom up, but
  // traversed the DAG from top down, parse over the groups from last to
  // first. If we fail to do this for the greedy algorithm, the solution will
  // likely not be good in more complex cases.
  for (; I != E; ++I) {
    std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
    int CandSGID = *I;
    SchedGroup *Match = nullptr;
    for (auto &SG : SyncPipeline) {
      if (SG.getSGID() == CandSGID)
        Match = &SG;
    }

    LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
                      << (int)Match->getMask() << "\n");

    if (Match->isFull()) {
      LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
      continue;
    }
    TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
    if (TempCost < BestNodeCost || BestNodeCost == -1) {
      BestGroup = Match;
      BestNodeCost = TempCost;
      BestGroupID = CandSGID;
    }
    removeEdges(AddedEdges);
    if (BestNodeCost == 0)
      break;
  }

  if (BestGroupID != -1) {
    BestGroup->add(*CurrSU.first);
    addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
    LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask "
                      << (int)BestGroup->getMask() << "\n");
    BestCost += BestNodeCost;
  } else
    BestCost += MissPenalty;

  CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
}

bool PipelineSolver::solveGreedy() {
  BestCost = 0;
  std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;

  while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
    SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
    IsBottomUp
        ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
        : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
    advancePosition();
  }
  BestPipeline = CurrPipeline;
  removeEdges(AddedEdges);
  return false;
}

unsigned PipelineSolver::computeProblemSize() {
  unsigned ProblemSize = 0;
  for (auto &PipeConflicts : PipelineInstrs) {
    ProblemSize += PipeConflicts.size();
  }

  return ProblemSize;
}

void PipelineSolver::solve() {
  if (!NeedsSolver)
    return;

  unsigned ProblemSize = computeProblemSize();
  assert(ProblemSize > 0);

  bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
  MissPenalty = (ProblemSize / 2) + 1;

  LLVM_DEBUG(DAG->dump());
  if (EnableExactSolver || BelowCutoff) {
    LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
    solveGreedy();
    reset();
    LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
    if (BestCost > 0) {
      LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
      solveExact();
      LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
    }
  } else { // Use the Greedy Algorithm by default
    LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
    solveGreedy();
  }

  makePipeline();
  LLVM_DEBUG(dbgs() << "After applying mutation\n");
  LLVM_DEBUG(DAG->dump());
}

enum IGLPStrategyID : int { MFMASmallGemmOptID = 0, DemoOptID = 1 };

// Implement an IGLP scheduling strategy.
class IGLPStrategy {
protected:
  ScheduleDAGInstrs *DAG;

  const SIInstrInfo *TII;

public:
  // Add SchedGroups to \p Pipeline to implement this Strategy.
  virtual void applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) = 0;

  // Returns true if this strategy should be applied to a ScheduleDAG.
  virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0;

  bool IsBottomUp = true;

  IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : DAG(DAG), TII(TII) {}

  virtual ~IGLPStrategy() = default;
};

class MFMASmallGemmOpt final : public IGLPStrategy {
private:
public:
  void applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override;

  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }

  MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {
    IsBottomUp = true;
  }
};

void MFMASmallGemmOpt::applyIGLPStrategy(
    DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
    DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) {
  // Count the number of MFMA instructions.
  unsigned MFMACount = 0;
  for (const MachineInstr &I : *DAG)
    if (TII->isMFMAorWMMA(I))
      ++MFMACount;

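  // For illustration: with MFMACount == 2, the loop below requests six
  // repetitions of {DS: 2, MFMA: 1} in sync group 0, i.e. two LDS
  // instructions followed by one MFMA, interleaved across the block.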
  const unsigned PipelineSyncID = 0;
  SchedGroup *SG = nullptr;
  for (unsigned I = 0; I < MFMACount * 3; ++I) {
    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
  }
}

class DemoOpt final : public IGLPStrategy {
private:
public:
  void applyIGLPStrategy(
      DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
      DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override;

  bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }

  DemoOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
      : IGLPStrategy(DAG, TII) {
    IsBottomUp = false;
  }
};

void DemoOpt::applyIGLPStrategy(
    DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
    DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) {
  // Count the number of MFMA instructions.
  unsigned MFMACount = 0;
  for (const MachineInstr &I : *DAG)
    if (TII->isMFMAorWMMA(I))
      ++MFMACount;

  const unsigned PipelineSyncID = 0;
  SchedGroup *SG = nullptr;
  for (unsigned I = 0; I < MFMACount * 3; ++I) {
    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);

    SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
        SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
    SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
  }
}

static std::unique_ptr<IGLPStrategy>
createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
                   const SIInstrInfo *TII) {
  switch (ID) {
  case MFMASmallGemmOptID:
    return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
  case DemoOptID:
    return std::make_unique<DemoOpt>(DAG, TII);
  }

  llvm_unreachable("Unknown IGLPStrategyID");
}

class IGroupLPDAGMutation : public ScheduleDAGMutation {
private:
  const SIInstrInfo *TII;

  ScheduleDAGMI *DAG;

  // Organize lists of SchedGroups by their SyncID. SchedGroups /
  // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
  // between them.
  DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;

  // Used to track instructions that can be mapped to multiple sched groups
  DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;

  // Add DAG edges that enforce SCHED_BARRIER ordering.
  void addSchedBarrierEdges(SUnit &SU);

  // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
  // not be reordered across the SCHED_BARRIER. This is used for the base
  // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
  // SCHED_BARRIER will always block all instructions that can be classified
  // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
  // and may only synchronize with some SchedGroups. Returns the inverse of
  // Mask. SCHED_BARRIER's mask describes which instruction types should be
  // allowed to be scheduled across it. Invert the mask to get the
  // SchedGroupMask of instructions that should be barred.
  SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;

  // Create SchedGroups for a SCHED_GROUP_BARRIER.
  void initSchedGroupBarrierPipelineStage(
      std::vector<SUnit>::reverse_iterator RIter);

  void initIGLPOpt(SUnit &SU);

public:
  void apply(ScheduleDAGInstrs *DAGInstrs) override;

  // The order in which the PipelineSolver should process the candidate
  // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
  // created SchedGroup first, and will consider that as the ultimate
  // predecessor group when linking. TOP_DOWN instead links and processes the
  // first created SchedGroup first.
  bool IsBottomUp = true;

  IGroupLPDAGMutation() = default;
};

unsigned SchedGroup::NumSchedGroups = 0;

bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
  if (A != B && DAG->canAddEdge(B, A)) {
    DAG->addEdge(B, SDep(A, SDep::Artificial));
    return true;
  }
  return false;
}

bool SchedGroup::canAddMI(const MachineInstr &MI) const {
  bool Result = false;
  if (MI.isMetaInstruction())
    Result = false;

  else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
           (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI)))
    Result = true;

  else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
           TII->isVALU(MI) && !TII->isMFMAorWMMA(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
           TII->isSALU(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
           TII->isMFMAorWMMA(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
           MI.mayLoad() &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
           MI.mayStore() &&
           (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
           TII->isDS(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
           MI.mayLoad() && TII->isDS(MI))
    Result = true;

  else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
           MI.mayStore() && TII->isDS(MI))
    Result = true;

  LLVM_DEBUG(
      dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
             << (Result ? " could classify " : " unable to classify ") << MI);

  return Result;
}

int SchedGroup::link(SUnit &SU, bool MakePred,
                     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
  int MissedEdges = 0;
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
      continue;
    if (MakePred)
      std::swap(A, B);

    if (DAG->IsReachable(B, A))
      continue;

    // tryAddEdge returns false if there is a dependency that makes adding
    // the A->B edge impossible; otherwise it returns true.
    bool Added = tryAddEdge(A, B);
    if (Added)
      AddedEdges.push_back(std::pair(A, B));
    else
      ++MissedEdges;
  }

  return MissedEdges;
}

void SchedGroup::link(SUnit &SU, bool MakePred) {
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
      continue;
    if (MakePred)
      std::swap(A, B);

    tryAddEdge(A, B);
  }
}

void SchedGroup::link(SUnit &SU,
                      function_ref<bool(const SUnit *A, const SUnit *B)> P) {
  for (auto *A : Collection) {
    SUnit *B = &SU;
    if (P(A, B))
      std::swap(A, B);

    tryAddEdge(A, B);
  }
}

void SchedGroup::link(SchedGroup &OtherGroup) {
  for (auto *B : OtherGroup.Collection)
    link(*B);
}

bool SchedGroup::canAddSU(SUnit &SU) const {
  MachineInstr &MI = *SU.getInstr();
  if (MI.getOpcode() != TargetOpcode::BUNDLE)
    return canAddMI(MI);

  // Special case for bundled MIs.
  const MachineBasicBlock *MBB = MI.getParent();
  MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
  while (E != MBB->end() && E->isBundledWithPred())
    ++E;

  // Return true if all of the bundled MIs can be added to this group.
  return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
}

void SchedGroup::initSchedGroup() {
  for (auto &SU : DAG->SUnits) {
    if (isFull())
      break;

    if (canAddSU(SU))
      add(SU);
  }
}

void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
                                SUnitsToCandidateSGsMap &SyncedInstrs) {
  SUnit &InitSU = *RIter;
  for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
    auto &SU = *RIter;
    if (isFull())
      break;

    if (canAddSU(SU))
      SyncedInstrs[&SU].push_back(SGID);
  }

  add(InitSU);
  assert(MaxSize);
  (*MaxSize)++;
}

void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
  auto I = DAG->SUnits.rbegin();
  auto E = DAG->SUnits.rend();
  for (; I != E; ++I) {
    auto &SU = *I;
    if (isFull())
      break;

    if (canAddSU(SU))
      SyncedInstrs[&SU].push_back(SGID);
  }
}

void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
  const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
  if (!TSchedModel || DAGInstrs->SUnits.empty())
    return;

  LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
  const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
  TII = ST.getInstrInfo();
  DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
  SyncedSchedGroups.clear();
  SyncedInstrs.clear();
  bool foundSB = false;
  bool foundIGLP = false;
  for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
    unsigned Opc = R->getInstr()->getOpcode();
    // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
    if (Opc == AMDGPU::SCHED_BARRIER) {
      addSchedBarrierEdges(*R);
      foundSB = true;
    } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
      initSchedGroupBarrierPipelineStage(R);
      foundSB = true;
    } else if (Opc == AMDGPU::IGLP_OPT) {
      resetEdges(*R, DAG);
      if (!foundSB && !foundIGLP)
        initIGLPOpt(*R);
      foundIGLP = true;
    }
  }

  if (foundSB || foundIGLP) {
    PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
    // PipelineSolver performs the mutation by adding the edges it
    // determined as the best
    PS.solve();
  }
}

void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
  MachineInstr &MI = *SchedBarrier.getInstr();
  assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
  // Remove all existing edges from the SCHED_BARRIER that were added due to the
  // instruction having side effects.
  resetEdges(SchedBarrier, DAG);
  auto InvertedMask =
      invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
  SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
  SG.initSchedGroup();
  // Preserve original instruction ordering relative to the SCHED_BARRIER.
  SG.link(
      SchedBarrier,
      (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
          const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
}

SchedGroupMask
IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
  // Invert mask and erase bits for types of instructions that are implied to be
  // allowed past the SCHED_BARRIER.
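  //
  // Worked example: a SCHED_BARRIER mask of 0x1 says that only ALU
  // instructions may cross the barrier. Within this bitmask, ~0x1 is 0x3FE;
  // the ALU bit is now clear, so the implied VALU/SALU/MFMA bits are cleared
  // as well, leaving 0x3F0: all VMEM and DS instruction types, which must
  // not be moved across the barrier.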
  SchedGroupMask InvertedMask = ~Mask;

  // ALU implies VALU, SALU, MFMA.
  if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
    InvertedMask &=
        ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA;
  // VALU, SALU, MFMA implies ALU.
  else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::ALU;

  // VMEM implies VMEM_READ, VMEM_WRITE.
  if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
  // VMEM_READ, VMEM_WRITE implies VMEM.
  else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::VMEM;

  // DS implies DS_READ, DS_WRITE.
  if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
  // DS_READ, DS_WRITE implies DS.
  else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
           (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
    InvertedMask &= ~SchedGroupMask::DS;

  return InvertedMask;
}

void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
    std::vector<SUnit>::reverse_iterator RIter) {
  // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
  // to the instruction having side effects.
  resetEdges(*RIter, DAG);
  MachineInstr &SGB = *RIter->getInstr();
  assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
  int32_t SGMask = SGB.getOperand(0).getImm();
  int32_t Size = SGB.getOperand(1).getImm();
  int32_t SyncID = SGB.getOperand(2).getImm();

  auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
                                                    Size, SyncID, DAG, TII);

  SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
}

void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
  IGLPStrategyID StrategyID =
      (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
  auto S = createIGLPStrategy(StrategyID, DAG, TII);
  if (S->shouldApplyStrategy(DAG)) {
    IsBottomUp = S->IsBottomUp;
    S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups);
  }
}

} // namespace

namespace llvm {

std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation() {
  return std::make_unique<IGroupLPDAGMutation>();
}

} // end namespace llvm