xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp (revision 06c3fb2749bda94cb5201f81ffdb8fa6c3161b2e)
181ad6265SDimitry Andric //===--- AMDGPUIGroupLP.cpp - AMDGPU IGroupLP  ------------===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // \file This file defines a set of schedule DAG mutations that can be used to
1081ad6265SDimitry Andric // override default scheduler behavior to enforce specific scheduling patterns.
1181ad6265SDimitry Andric // They should be used in cases where runtime performance considerations such as
1281ad6265SDimitry Andric // inter-wavefront interactions, mean that compile-time heuristics cannot
1381ad6265SDimitry Andric // predict the optimal instruction ordering, or in kernels where optimum
1481ad6265SDimitry Andric // instruction scheduling is important enough to warrant manual intervention.
1581ad6265SDimitry Andric //
1681ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1781ad6265SDimitry Andric 
1881ad6265SDimitry Andric #include "AMDGPUIGroupLP.h"
1981ad6265SDimitry Andric #include "AMDGPUTargetMachine.h"
2081ad6265SDimitry Andric #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
2181ad6265SDimitry Andric #include "SIInstrInfo.h"
2281ad6265SDimitry Andric #include "SIMachineFunctionInfo.h"
2381ad6265SDimitry Andric #include "llvm/ADT/BitmaskEnum.h"
24bdd1243dSDimitry Andric #include "llvm/ADT/DenseMap.h"
2581ad6265SDimitry Andric #include "llvm/CodeGen/MachineScheduler.h"
2681ad6265SDimitry Andric #include "llvm/CodeGen/TargetOpcodes.h"
2781ad6265SDimitry Andric 
2881ad6265SDimitry Andric using namespace llvm;
2981ad6265SDimitry Andric 
30bdd1243dSDimitry Andric #define DEBUG_TYPE "igrouplp"
3181ad6265SDimitry Andric 
3281ad6265SDimitry Andric namespace {
3381ad6265SDimitry Andric 
// Opt-in switch for the exponential-time exact pipeline solver; when off
// (the default) the greedy algorithm is used (see PipelineSolver::solve).
static cl::opt<bool> EnableExactSolver(
    "amdgpu-igrouplp-exact-solver", cl::Hidden,
    cl::desc("Whether to use the exponential time solver to fit "
             "the instructions to the pipeline as closely as "
             "possible."),
    cl::init(false));
4081ad6265SDimitry Andric 
41bdd1243dSDimitry Andric static cl::opt<unsigned> CutoffForExact(
42bdd1243dSDimitry Andric     "amdgpu-igrouplp-exact-solver-cutoff", cl::init(0), cl::Hidden,
43bdd1243dSDimitry Andric     cl::desc("The maximum number of scheduling group conflicts "
44bdd1243dSDimitry Andric              "which we attempt to solve with the exponential time "
45bdd1243dSDimitry Andric              "exact solver. Problem sizes greater than this will"
46bdd1243dSDimitry Andric              "be solved by the less accurate greedy algorithm. Selecting "
47bdd1243dSDimitry Andric              "solver by size is superseded by manually selecting "
48bdd1243dSDimitry Andric              "the solver (e.g. by amdgpu-igrouplp-exact-solver"));
4981ad6265SDimitry Andric 
50bdd1243dSDimitry Andric static cl::opt<uint64_t> MaxBranchesExplored(
51bdd1243dSDimitry Andric     "amdgpu-igrouplp-exact-solver-max-branches", cl::init(0), cl::Hidden,
52bdd1243dSDimitry Andric     cl::desc("The amount of branches that we are willing to explore with"
53bdd1243dSDimitry Andric              "the exact algorithm before giving up."));
5481ad6265SDimitry Andric 
// Choice of tie-breaking strategy while the exact solver traverses the search
// space: cost heuristic (default, on) vs. plain node order.
static cl::opt<bool> UseCostHeur(
    "amdgpu-igrouplp-exact-solver-cost-heur", cl::init(true), cl::Hidden,
    cl::desc("Whether to use the cost heuristic to make choices as we "
             "traverse the search space using the exact solver. Defaulted "
             "to on, and if turned off, we will use the node order -- "
             "attempting to put the later nodes in the later sched groups. "
             "Experimentally, results are mixed, so this should be set on a "
             "case-by-case basis."));
6381ad6265SDimitry Andric 
// Components of the mask that determines which instruction types may be
// classified into a SchedGroup. The bits correspond to the mask operand of
// SCHED_BARRIER / SCHED_GROUP_BARRIER (see SchedGroup::SGMask below).
enum class SchedGroupMask {
  NONE = 0u,
  ALU = 1u << 0,
  VALU = 1u << 1,
  SALU = 1u << 2,
  MFMA = 1u << 3,
  VMEM = 1u << 4,
  VMEM_READ = 1u << 5,
  VMEM_WRITE = 1u << 6,
  DS = 1u << 7,
  DS_READ = 1u << 8,
  DS_WRITE = 1u << 9,
  // Union of every category above.
  ALL = ALU | VALU | SALU | MFMA | VMEM | VMEM_READ | VMEM_WRITE | DS |
        DS_READ | DS_WRITE,
  LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
};
8281ad6265SDimitry Andric 
class SchedGroup;

// InstructionRule class is used to enact a filter which determines whether or
// not an SU maps to a given SchedGroup. It contains complementary data
// structures (e.g Cache) to help those filters.
class InstructionRule {
protected:
  const SIInstrInfo *TII;
  // ID of the SchedGroup this rule is attached to.
  unsigned SGID;
  // A cache made available to the Filter to store SUnits for subsequent
  // invocations of the Filter
  std::optional<SmallVector<SUnit *, 4>> Cache;

public:
  // Decide whether the candidate SU may join the SchedGroup. Invoked (via
  // SchedGroup::allowedByRules) with the candidate SU, the SUnits already in
  // the group, and the other SchedGroups of the same sync pipeline. The base
  // implementation accepts everything; concrete rules override this.
  virtual bool
  apply(const SUnit *, const ArrayRef<SUnit *>,
        SmallVectorImpl<SchedGroup> &) {
    return true;
  };

  InstructionRule(const SIInstrInfo *TII, unsigned SGID,
                  bool NeedsCache = false)
      : TII(TII), SGID(SGID) {
    // Only materialize the cache when a concrete rule requests it.
    if (NeedsCache) {
      Cache = SmallVector<SUnit *, 4>();
    }
  }

  virtual ~InstructionRule() = default;
};
113*06c3fb27SDimitry Andric 
114bdd1243dSDimitry Andric typedef DenseMap<SUnit *, SmallVector<int, 4>> SUnitsToCandidateSGsMap;
11581ad6265SDimitry Andric 
116bdd1243dSDimitry Andric // Classify instructions into groups to enable fine tuned control over the
117bdd1243dSDimitry Andric // scheduler. These groups may be more specific than current SchedModel
118bdd1243dSDimitry Andric // instruction classes.
119bdd1243dSDimitry Andric class SchedGroup {
120bdd1243dSDimitry Andric private:
121bdd1243dSDimitry Andric   // Mask that defines which instruction types can be classified into this
122bdd1243dSDimitry Andric   // SchedGroup. The instruction types correspond to the mask from SCHED_BARRIER
123bdd1243dSDimitry Andric   // and SCHED_GROUP_BARRIER.
124bdd1243dSDimitry Andric   SchedGroupMask SGMask;
12581ad6265SDimitry Andric 
126bdd1243dSDimitry Andric   // Maximum number of SUnits that can be added to this group.
127bdd1243dSDimitry Andric   std::optional<unsigned> MaxSize;
12881ad6265SDimitry Andric 
129bdd1243dSDimitry Andric   // SchedGroups will only synchronize with other SchedGroups that have the same
130bdd1243dSDimitry Andric   // SyncID.
131bdd1243dSDimitry Andric   int SyncID = 0;
13281ad6265SDimitry Andric 
133bdd1243dSDimitry Andric   // SGID is used to map instructions to candidate SchedGroups
134bdd1243dSDimitry Andric   unsigned SGID;
135bdd1243dSDimitry Andric 
136*06c3fb27SDimitry Andric   // The different rules each instruction in this SchedGroup must conform to
137*06c3fb27SDimitry Andric   SmallVector<std::shared_ptr<InstructionRule>, 4> Rules;
138*06c3fb27SDimitry Andric 
139bdd1243dSDimitry Andric   // Count of the number of created SchedGroups, used to initialize SGID.
140bdd1243dSDimitry Andric   static unsigned NumSchedGroups;
141bdd1243dSDimitry Andric 
142bdd1243dSDimitry Andric   const SIInstrInfo *TII;
143bdd1243dSDimitry Andric 
144bdd1243dSDimitry Andric   // Try to add and edge from SU A to SU B.
145bdd1243dSDimitry Andric   bool tryAddEdge(SUnit *A, SUnit *B);
146bdd1243dSDimitry Andric 
147bdd1243dSDimitry Andric   // Use SGMask to determine whether we can classify MI as a member of this
148bdd1243dSDimitry Andric   // SchedGroup object.
149bdd1243dSDimitry Andric   bool canAddMI(const MachineInstr &MI) const;
15081ad6265SDimitry Andric 
15181ad6265SDimitry Andric public:
152bdd1243dSDimitry Andric   // Collection of SUnits that are classified as members of this group.
153bdd1243dSDimitry Andric   SmallVector<SUnit *, 32> Collection;
15481ad6265SDimitry Andric 
155*06c3fb27SDimitry Andric   ScheduleDAGInstrs *DAG;
156*06c3fb27SDimitry Andric 
157bdd1243dSDimitry Andric   // Returns true if SU can be added to this SchedGroup.
158bdd1243dSDimitry Andric   bool canAddSU(SUnit &SU) const;
15981ad6265SDimitry Andric 
160bdd1243dSDimitry Andric   // Add DAG dependencies from all SUnits in this SchedGroup and this SU. If
161bdd1243dSDimitry Andric   // MakePred is true, SU will be a predecessor of the SUnits in this
162bdd1243dSDimitry Andric   // SchedGroup, otherwise SU will be a successor.
163bdd1243dSDimitry Andric   void link(SUnit &SU, bool MakePred = false);
16481ad6265SDimitry Andric 
165bdd1243dSDimitry Andric   // Add DAG dependencies and track which edges are added, and the count of
166bdd1243dSDimitry Andric   // missed edges
167bdd1243dSDimitry Andric   int link(SUnit &SU, bool MakePred,
168bdd1243dSDimitry Andric            std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
16981ad6265SDimitry Andric 
170bdd1243dSDimitry Andric   // Add DAG dependencies from all SUnits in this SchedGroup and this SU.
171bdd1243dSDimitry Andric   // Use the predicate to determine whether SU should be a predecessor (P =
172bdd1243dSDimitry Andric   // true) or a successor (P = false) of this SchedGroup.
173bdd1243dSDimitry Andric   void link(SUnit &SU, function_ref<bool(const SUnit *A, const SUnit *B)> P);
17481ad6265SDimitry Andric 
175bdd1243dSDimitry Andric   // Add DAG dependencies such that SUnits in this group shall be ordered
176bdd1243dSDimitry Andric   // before SUnits in OtherGroup.
177bdd1243dSDimitry Andric   void link(SchedGroup &OtherGroup);
178bdd1243dSDimitry Andric 
179bdd1243dSDimitry Andric   // Returns true if no more instructions may be added to this group.
180bdd1243dSDimitry Andric   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
181bdd1243dSDimitry Andric 
182*06c3fb27SDimitry Andric   // Append a constraint that SUs must meet in order to fit into this
183*06c3fb27SDimitry Andric   // SchedGroup. Since many rules involve the relationship between a SchedGroup
184*06c3fb27SDimitry Andric   // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
185*06c3fb27SDimitry Andric   // time (rather than SchedGroup init time.)
186*06c3fb27SDimitry Andric   void addRule(std::shared_ptr<InstructionRule> NewRule) {
187*06c3fb27SDimitry Andric     Rules.push_back(NewRule);
188*06c3fb27SDimitry Andric   }
189*06c3fb27SDimitry Andric 
190*06c3fb27SDimitry Andric   // Returns true if the SU matches all rules
191*06c3fb27SDimitry Andric   bool allowedByRules(const SUnit *SU,
192*06c3fb27SDimitry Andric                       SmallVectorImpl<SchedGroup> &SyncPipe) const {
193*06c3fb27SDimitry Andric     if (Rules.empty())
194*06c3fb27SDimitry Andric       return true;
195*06c3fb27SDimitry Andric     for (size_t I = 0; I < Rules.size(); I++) {
196*06c3fb27SDimitry Andric       auto TheRule = Rules[I].get();
197*06c3fb27SDimitry Andric       if (!TheRule->apply(SU, Collection, SyncPipe)) {
198*06c3fb27SDimitry Andric         return false;
199*06c3fb27SDimitry Andric       }
200*06c3fb27SDimitry Andric     }
201*06c3fb27SDimitry Andric     return true;
202*06c3fb27SDimitry Andric   }
203*06c3fb27SDimitry Andric 
204bdd1243dSDimitry Andric   // Add SU to the SchedGroup.
205bdd1243dSDimitry Andric   void add(SUnit &SU) {
206bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
207bdd1243dSDimitry Andric                       << format_hex((int)SGMask, 10, true) << " adding "
208bdd1243dSDimitry Andric                       << *SU.getInstr());
209bdd1243dSDimitry Andric     Collection.push_back(&SU);
21081ad6265SDimitry Andric   }
21181ad6265SDimitry Andric 
212bdd1243dSDimitry Andric   // Remove last element in the SchedGroup
213bdd1243dSDimitry Andric   void pop() { Collection.pop_back(); }
214bdd1243dSDimitry Andric 
215bdd1243dSDimitry Andric   // Identify and add all relevant SUs from the DAG to this SchedGroup.
216bdd1243dSDimitry Andric   void initSchedGroup();
217bdd1243dSDimitry Andric 
218bdd1243dSDimitry Andric   // Add instructions to the SchedGroup bottom up starting from RIter.
219bdd1243dSDimitry Andric   // PipelineInstrs is a set of instructions that should not be added to the
220bdd1243dSDimitry Andric   // SchedGroup even when the other conditions for adding it are satisfied.
221bdd1243dSDimitry Andric   // RIter will be added to the SchedGroup as well, and dependencies will be
222bdd1243dSDimitry Andric   // added so that RIter will always be scheduled at the end of the group.
223bdd1243dSDimitry Andric   void initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
224bdd1243dSDimitry Andric                       SUnitsToCandidateSGsMap &SyncedInstrs);
225bdd1243dSDimitry Andric 
226bdd1243dSDimitry Andric   void initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs);
227bdd1243dSDimitry Andric 
228bdd1243dSDimitry Andric   int getSyncID() { return SyncID; }
229bdd1243dSDimitry Andric 
230bdd1243dSDimitry Andric   int getSGID() { return SGID; }
231bdd1243dSDimitry Andric 
232bdd1243dSDimitry Andric   SchedGroupMask getMask() { return SGMask; }
233bdd1243dSDimitry Andric 
234bdd1243dSDimitry Andric   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize,
235bdd1243dSDimitry Andric              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
236*06c3fb27SDimitry Andric       : SGMask(SGMask), MaxSize(MaxSize), TII(TII), DAG(DAG) {
237bdd1243dSDimitry Andric     SGID = NumSchedGroups++;
238bdd1243dSDimitry Andric   }
239bdd1243dSDimitry Andric 
240bdd1243dSDimitry Andric   SchedGroup(SchedGroupMask SGMask, std::optional<unsigned> MaxSize, int SyncID,
241bdd1243dSDimitry Andric              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
242*06c3fb27SDimitry Andric       : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), TII(TII), DAG(DAG) {
243bdd1243dSDimitry Andric     SGID = NumSchedGroups++;
244bdd1243dSDimitry Andric   }
245bdd1243dSDimitry Andric };
246bdd1243dSDimitry Andric 
// Remove all existing edges from a SCHED_BARRIER or SCHED_GROUP_BARRIER.
static void resetEdges(SUnit &SU, ScheduleDAGInstrs *DAG) {
  assert(SU.getInstr()->getOpcode() == AMDGPU::SCHED_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER ||
         SU.getInstr()->getOpcode() == AMDGPU::IGLP_OPT);

  // removePred mutates SU.Preds while the range-for is iterating over it;
  // the outer while retries the scan until the list is fully drained.
  while (!SU.Preds.empty())
    for (auto &P : SU.Preds)
      SU.removePred(P);

  // A successor edge is recorded as a predecessor entry on the successor
  // node, so drop each back-reference to SU from every successor's Preds
  // list (same retry pattern as above).
  while (!SU.Succs.empty())
    for (auto &S : SU.Succs)
      for (auto &SP : S.getSUnit()->Preds)
        if (SP.getSUnit() == &SU)
          S.getSUnit()->removePred(SP);
}
263bdd1243dSDimitry Andric 
264bdd1243dSDimitry Andric typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair;
265bdd1243dSDimitry Andric typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec;
266bdd1243dSDimitry Andric 
267bdd1243dSDimitry Andric // The PipelineSolver is used to assign SUnits to SchedGroups in a pipeline
268bdd1243dSDimitry Andric // in non-trivial cases. For example, if the requested pipeline is
269bdd1243dSDimitry Andric // {VMEM_READ, VALU, MFMA, VMEM_READ} and we encounter a VMEM_READ instruction
270bdd1243dSDimitry Andric // in the DAG, then we will have an instruction that can not be trivially
271bdd1243dSDimitry Andric // assigned to a SchedGroup. The PipelineSolver class implements two algorithms
272bdd1243dSDimitry Andric // to find a good solution to the pipeline -- a greedy algorithm and an exact
273bdd1243dSDimitry Andric // algorithm. The exact algorithm has an exponential time complexity and should
274bdd1243dSDimitry Andric // only be used for small sized problems or medium sized problems where an exact
275bdd1243dSDimitry Andric // solution is highly desired.
276bdd1243dSDimitry Andric class PipelineSolver {
277bdd1243dSDimitry Andric   ScheduleDAGMI *DAG;
278bdd1243dSDimitry Andric 
279bdd1243dSDimitry Andric   // Instructions that can be assigned to multiple SchedGroups
280bdd1243dSDimitry Andric   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
281bdd1243dSDimitry Andric   SmallVector<SUsToCandSGsVec, 4> PipelineInstrs;
282bdd1243dSDimitry Andric   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
283bdd1243dSDimitry Andric   // The current working pipeline
284bdd1243dSDimitry Andric   SmallVector<SmallVector<SchedGroup, 4>, 4> CurrPipeline;
285bdd1243dSDimitry Andric   // The pipeline that has the best solution found so far
286bdd1243dSDimitry Andric   SmallVector<SmallVector<SchedGroup, 4>, 4> BestPipeline;
287bdd1243dSDimitry Andric 
288bdd1243dSDimitry Andric   // Whether or not we actually have any SyncedInstrs to try to solve.
289bdd1243dSDimitry Andric   bool NeedsSolver = false;
290bdd1243dSDimitry Andric 
291bdd1243dSDimitry Andric   // Compute an estimate of the size of search tree -- the true size is
292bdd1243dSDimitry Andric   // the product of each conflictedInst.Matches.size() across all SyncPipelines
293bdd1243dSDimitry Andric   unsigned computeProblemSize();
294bdd1243dSDimitry Andric 
295bdd1243dSDimitry Andric   // The cost penalty of not assigning a SU to a SchedGroup
296bdd1243dSDimitry Andric   int MissPenalty = 0;
297bdd1243dSDimitry Andric 
298bdd1243dSDimitry Andric   // Costs in terms of the number of edges we are unable to add
299bdd1243dSDimitry Andric   int BestCost = -1;
300bdd1243dSDimitry Andric   int CurrCost = 0;
301bdd1243dSDimitry Andric 
302bdd1243dSDimitry Andric   // Index pointing to the conflicting instruction that is currently being
303bdd1243dSDimitry Andric   // fitted
304bdd1243dSDimitry Andric   int CurrConflInstNo = 0;
305bdd1243dSDimitry Andric   // Index to the pipeline that is currently being fitted
306bdd1243dSDimitry Andric   int CurrSyncGroupIdx = 0;
307bdd1243dSDimitry Andric   // The first non trivial pipeline
308bdd1243dSDimitry Andric   int BeginSyncGroupIdx = 0;
309bdd1243dSDimitry Andric 
310bdd1243dSDimitry Andric   // How many branches we have explored
311bdd1243dSDimitry Andric   uint64_t BranchesExplored = 0;
312bdd1243dSDimitry Andric 
313*06c3fb27SDimitry Andric   // The direction in which we process the candidate SchedGroups per SU
314*06c3fb27SDimitry Andric   bool IsBottomUp = 1;
315*06c3fb27SDimitry Andric 
316bdd1243dSDimitry Andric   // Update indices to fit next conflicting instruction
317bdd1243dSDimitry Andric   void advancePosition();
318bdd1243dSDimitry Andric   // Recede indices to attempt to find better fit for previous conflicting
319bdd1243dSDimitry Andric   // instruction
320bdd1243dSDimitry Andric   void retreatPosition();
321bdd1243dSDimitry Andric 
322bdd1243dSDimitry Andric   // The exponential time algorithm which finds the provably best fit
323bdd1243dSDimitry Andric   bool solveExact();
324bdd1243dSDimitry Andric   // The polynomial time algorithm which attempts to find a good fit
325bdd1243dSDimitry Andric   bool solveGreedy();
326*06c3fb27SDimitry Andric   // Find the best SchedGroup for the current SU using the heuristic given all
327*06c3fb27SDimitry Andric   // current information. One step in the greedy algorithm. Templated against
328*06c3fb27SDimitry Andric   // the SchedGroup iterator (either reverse or forward).
329*06c3fb27SDimitry Andric   template <typename T>
330*06c3fb27SDimitry Andric   void greedyFind(std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I,
331*06c3fb27SDimitry Andric                   T E);
332bdd1243dSDimitry Andric   // Whether or not the current solution is optimal
333bdd1243dSDimitry Andric   bool checkOptimal();
334bdd1243dSDimitry Andric   // Populate the ready list, prioiritizing fewest missed edges first
335*06c3fb27SDimitry Andric   // Templated against the SchedGroup iterator (either reverse or forward).
336*06c3fb27SDimitry Andric   template <typename T>
337*06c3fb27SDimitry Andric   void populateReadyList(SmallVectorImpl<std::pair<int, int>> &ReadyList, T I,
338*06c3fb27SDimitry Andric                          T E);
339bdd1243dSDimitry Andric   // Add edges corresponding to the SchedGroups as assigned by solver
340bdd1243dSDimitry Andric   void makePipeline();
341*06c3fb27SDimitry Andric   // Link the SchedGroups in the best found pipeline.
342*06c3fb27SDimitry Andric   // Tmplated against the SchedGroup iterator (either reverse or forward).
343*06c3fb27SDimitry Andric   template <typename T> void linkSchedGroups(T I, T E);
344bdd1243dSDimitry Andric   // Add the edges from the SU to the other SchedGroups in pipeline, and
345bdd1243dSDimitry Andric   // return the number of edges missed.
346bdd1243dSDimitry Andric   int addEdges(SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
347bdd1243dSDimitry Andric                std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
348*06c3fb27SDimitry Andric   // Link the pipeline as if \p SU was in the SchedGroup with ID \p SGID. It
349*06c3fb27SDimitry Andric   // returns the cost (in terms of missed pipeline edges), and tracks the edges
350*06c3fb27SDimitry Andric   // added in \p AddedEdges
351*06c3fb27SDimitry Andric   template <typename T>
352*06c3fb27SDimitry Andric   int linkSUnit(SUnit *SU, int SGID,
353*06c3fb27SDimitry Andric                 std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E);
354*06c3fb27SDimitry Andric   // Remove the edges passed via \p AddedEdges
355bdd1243dSDimitry Andric   void removeEdges(const std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges);
356bdd1243dSDimitry Andric   // Convert the passed in maps to arrays for bidirectional iterators
357bdd1243dSDimitry Andric   void convertSyncMapsToArrays();
358bdd1243dSDimitry Andric 
359bdd1243dSDimitry Andric   void reset();
360bdd1243dSDimitry Andric 
361bdd1243dSDimitry Andric public:
362bdd1243dSDimitry Andric   // Invoke the solver to map instructions to instruction groups. Heuristic &&
363bdd1243dSDimitry Andric   // command-line-option determines to use exact or greedy algorithm.
364bdd1243dSDimitry Andric   void solve();
365bdd1243dSDimitry Andric 
366bdd1243dSDimitry Andric   PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups,
367bdd1243dSDimitry Andric                  DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
368*06c3fb27SDimitry Andric                  ScheduleDAGMI *DAG, bool IsBottomUp = 1)
369bdd1243dSDimitry Andric       : DAG(DAG), SyncedInstrs(SyncedInstrs),
370*06c3fb27SDimitry Andric         SyncedSchedGroups(SyncedSchedGroups), IsBottomUp(IsBottomUp) {
371bdd1243dSDimitry Andric 
372bdd1243dSDimitry Andric     for (auto &PipelineInstrs : SyncedInstrs) {
373bdd1243dSDimitry Andric       if (PipelineInstrs.second.size() > 0) {
374bdd1243dSDimitry Andric         NeedsSolver = true;
375bdd1243dSDimitry Andric         break;
376bdd1243dSDimitry Andric       }
377bdd1243dSDimitry Andric     }
378bdd1243dSDimitry Andric 
379bdd1243dSDimitry Andric     if (!NeedsSolver)
380bdd1243dSDimitry Andric       return;
381bdd1243dSDimitry Andric 
382bdd1243dSDimitry Andric     convertSyncMapsToArrays();
383bdd1243dSDimitry Andric 
384bdd1243dSDimitry Andric     CurrPipeline = BestPipeline;
385bdd1243dSDimitry Andric 
386bdd1243dSDimitry Andric     while (static_cast<size_t>(BeginSyncGroupIdx) < PipelineInstrs.size() &&
387bdd1243dSDimitry Andric            PipelineInstrs[BeginSyncGroupIdx].size() == 0)
388bdd1243dSDimitry Andric       ++BeginSyncGroupIdx;
389bdd1243dSDimitry Andric 
390bdd1243dSDimitry Andric     if (static_cast<size_t>(BeginSyncGroupIdx) >= PipelineInstrs.size())
391bdd1243dSDimitry Andric       return;
392bdd1243dSDimitry Andric   }
393bdd1243dSDimitry Andric };
394bdd1243dSDimitry Andric 
395bdd1243dSDimitry Andric void PipelineSolver::reset() {
396bdd1243dSDimitry Andric 
397bdd1243dSDimitry Andric   for (auto &SyncPipeline : CurrPipeline) {
398bdd1243dSDimitry Andric     for (auto &SG : SyncPipeline) {
399bdd1243dSDimitry Andric       SmallVector<SUnit *, 32> TempCollection = SG.Collection;
400bdd1243dSDimitry Andric       SG.Collection.clear();
401bdd1243dSDimitry Andric       auto SchedBarr = llvm::find_if(TempCollection, [](SUnit *SU) {
402bdd1243dSDimitry Andric         return SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER;
403bdd1243dSDimitry Andric       });
404bdd1243dSDimitry Andric       if (SchedBarr != TempCollection.end())
405bdd1243dSDimitry Andric         SG.Collection.push_back(*SchedBarr);
406bdd1243dSDimitry Andric     }
407bdd1243dSDimitry Andric   }
408bdd1243dSDimitry Andric 
409bdd1243dSDimitry Andric   CurrSyncGroupIdx = BeginSyncGroupIdx;
410bdd1243dSDimitry Andric   CurrConflInstNo = 0;
411bdd1243dSDimitry Andric   CurrCost = 0;
412bdd1243dSDimitry Andric }
413bdd1243dSDimitry Andric 
// Flatten the SyncID-keyed maps into indexable arrays (BestPipeline and
// PipelineInstrs) so the solver can walk them with bidirectional iterators.
void PipelineSolver::convertSyncMapsToArrays() {
  // Each sync pipeline is inserted at the front, so BestPipeline ends up in
  // the reverse of the map's iteration order.
  for (auto &SyncPipe : SyncedSchedGroups) {
    BestPipeline.insert(BestPipeline.begin(), SyncPipe.second);
  }

  // Mirror that reversed ordering for the conflicted instructions:
  // PipelineIDx counts down while SyncedInstrs is iterated forward.
  int PipelineIDx = SyncedInstrs.size() - 1;
  PipelineInstrs.resize(SyncedInstrs.size());
  for (auto &SyncInstrMap : SyncedInstrs) {
    for (auto &SUsToCandSGs : SyncInstrMap.second) {
      // First entry of a pipeline needs no ordering.
      if (PipelineInstrs[PipelineIDx].size() == 0) {
        PipelineInstrs[PipelineIDx].push_back(
            std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
        continue;
      }
      auto SortPosition = PipelineInstrs[PipelineIDx].begin();
      // Insert them in sorted order -- this allows for good parsing order in
      // the greedy algorithm. Linear scan keeps entries ascending by NodeNum.
      while (SortPosition != PipelineInstrs[PipelineIDx].end() &&
             SUsToCandSGs.first->NodeNum > SortPosition->first->NodeNum)
        ++SortPosition;
      PipelineInstrs[PipelineIDx].insert(
          SortPosition, std::pair(SUsToCandSGs.first, SUsToCandSGs.second));
    }
    --PipelineIDx;
  }
}
440bdd1243dSDimitry Andric 
441*06c3fb27SDimitry Andric template <typename T> void PipelineSolver::linkSchedGroups(T I, T E) {
442*06c3fb27SDimitry Andric   for (; I != E; ++I) {
443*06c3fb27SDimitry Andric     auto &GroupA = *I;
444*06c3fb27SDimitry Andric     for (auto J = std::next(I); J != E; ++J) {
445*06c3fb27SDimitry Andric       auto &GroupB = *J;
446*06c3fb27SDimitry Andric       GroupA.link(GroupB);
447*06c3fb27SDimitry Andric     }
448*06c3fb27SDimitry Andric   }
449*06c3fb27SDimitry Andric }
450*06c3fb27SDimitry Andric 
// Materialize the solved pipeline by adding the corresponding DAG edges.
void PipelineSolver::makePipeline() {
  // Preserve the order of barrier for subsequent SchedGroupBarrier mutations
  for (auto &SyncPipeline : BestPipeline) {
    LLVM_DEBUG(dbgs() << "Printing SchedGroups\n");
    for (auto &SG : SyncPipeline) {
      LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID()
                        << " has: \n");
      SUnit *SGBarr = nullptr;
      for (auto &SU : SG.Collection) {
        if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
          SGBarr = SU;
        LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n");
      }
      // Command line requested IGroupLP doesn't have SGBarr
      if (!SGBarr)
        continue;
      // Re-link the barrier after its group: clear its stale edges, then make
      // it a successor of every SU in the group (MakePred == false).
      resetEdges(*SGBarr, DAG);
      SG.link(*SGBarr, false);
    }
  }

  for (auto &SyncPipeline : BestPipeline) {
    // Establish the inter-group ordering; direction follows IsBottomUp.
    IsBottomUp ? linkSchedGroups(SyncPipeline.rbegin(), SyncPipeline.rend())
               : linkSchedGroups(SyncPipeline.begin(), SyncPipeline.end());
  }
}
477*06c3fb27SDimitry Andric 
// Link \p SU against every SchedGroup in [I, E) except the one with ID
// \p SGID. Groups visited before the SGID group get SU as a successor
// (MakePred == false); once the SGID group is passed, SU becomes a
// predecessor of the remaining groups. Returns the accumulated cost of
// edges that could not be added (tracked edges go into \p AddedEdges).
template <typename T>
int PipelineSolver::linkSUnit(
    SUnit *SU, int SGID, std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges,
    T I, T E) {
  bool MakePred = false;
  int AddedCost = 0;
  for (; I < E; ++I) {
    // The candidate group itself is skipped; it flips the link direction for
    // all groups after it.
    if (I->getSGID() == SGID) {
      MakePred = true;
      continue;
    }
    // NOTE(review): this copies the SchedGroup before linking; a reference
    // looks sufficient since link() works off the Collection pointers --
    // confirm link() const-ness/side effects before changing.
    auto Group = *I;
    AddedCost += Group.link(*SU, MakePred, AddedEdges);
    assert(AddedCost >= 0);
  }
  return AddedCost;
}
49581ad6265SDimitry Andric 
496bdd1243dSDimitry Andric int PipelineSolver::addEdges(
497bdd1243dSDimitry Andric     SmallVectorImpl<SchedGroup> &SyncPipeline, SUnit *SU, int SGID,
498bdd1243dSDimitry Andric     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
499bdd1243dSDimitry Andric 
500*06c3fb27SDimitry Andric   // For IsBottomUp, the first SchedGroup in SyncPipeline contains the
501*06c3fb27SDimitry Andric   // instructions that are the ultimate successors in the resultant mutation.
502*06c3fb27SDimitry Andric   // Therefore, in such a configuration, the SchedGroups occurring before the
503*06c3fb27SDimitry Andric   // candidate SGID are successors of the candidate SchedGroup, thus the current
504*06c3fb27SDimitry Andric   // SU should be linked as a predecessor to SUs in those SchedGroups. The
505*06c3fb27SDimitry Andric   // opposite is true if !IsBottomUp. IsBottomUp occurs in the case of multiple
506*06c3fb27SDimitry Andric   // SCHED_GROUP_BARRIERS, or if a user specifies IGLP_OPT SchedGroups using
507*06c3fb27SDimitry Andric   // IsBottomUp (in reverse).
508*06c3fb27SDimitry Andric   return IsBottomUp ? linkSUnit(SU, SGID, AddedEdges, SyncPipeline.rbegin(),
509*06c3fb27SDimitry Andric                                 SyncPipeline.rend())
510*06c3fb27SDimitry Andric                     : linkSUnit(SU, SGID, AddedEdges, SyncPipeline.begin(),
511*06c3fb27SDimitry Andric                                 SyncPipeline.end());
512bdd1243dSDimitry Andric }
513bdd1243dSDimitry Andric 
514bdd1243dSDimitry Andric void PipelineSolver::removeEdges(
515bdd1243dSDimitry Andric     const std::vector<std::pair<SUnit *, SUnit *>> &EdgesToRemove) {
516bdd1243dSDimitry Andric   // Only remove the edges that we have added when testing
517bdd1243dSDimitry Andric   // the fit.
518bdd1243dSDimitry Andric   for (auto &PredSuccPair : EdgesToRemove) {
519bdd1243dSDimitry Andric     SUnit *Pred = PredSuccPair.first;
520bdd1243dSDimitry Andric     SUnit *Succ = PredSuccPair.second;
521bdd1243dSDimitry Andric 
522bdd1243dSDimitry Andric     auto Match = llvm::find_if(
523bdd1243dSDimitry Andric         Succ->Preds, [&Pred](SDep &P) { return P.getSUnit() == Pred; });
524bdd1243dSDimitry Andric     if (Match != Succ->Preds.end()) {
525bdd1243dSDimitry Andric       assert(Match->isArtificial());
526bdd1243dSDimitry Andric       Succ->removePred(*Match);
527bdd1243dSDimitry Andric     }
528bdd1243dSDimitry Andric   }
529bdd1243dSDimitry Andric }
530bdd1243dSDimitry Andric 
531bdd1243dSDimitry Andric void PipelineSolver::advancePosition() {
532bdd1243dSDimitry Andric   ++CurrConflInstNo;
533bdd1243dSDimitry Andric 
534bdd1243dSDimitry Andric   if (static_cast<size_t>(CurrConflInstNo) >=
535bdd1243dSDimitry Andric       PipelineInstrs[CurrSyncGroupIdx].size()) {
536bdd1243dSDimitry Andric     CurrConflInstNo = 0;
537bdd1243dSDimitry Andric     ++CurrSyncGroupIdx;
538bdd1243dSDimitry Andric     // Advance to next non-trivial pipeline
539bdd1243dSDimitry Andric     while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size() &&
540bdd1243dSDimitry Andric            PipelineInstrs[CurrSyncGroupIdx].size() == 0)
541bdd1243dSDimitry Andric       ++CurrSyncGroupIdx;
542bdd1243dSDimitry Andric   }
543bdd1243dSDimitry Andric }
544bdd1243dSDimitry Andric 
545bdd1243dSDimitry Andric void PipelineSolver::retreatPosition() {
546bdd1243dSDimitry Andric   assert(CurrConflInstNo >= 0);
547bdd1243dSDimitry Andric   assert(CurrSyncGroupIdx >= 0);
548bdd1243dSDimitry Andric 
549bdd1243dSDimitry Andric   if (CurrConflInstNo > 0) {
550bdd1243dSDimitry Andric     --CurrConflInstNo;
551bdd1243dSDimitry Andric     return;
552bdd1243dSDimitry Andric   }
553bdd1243dSDimitry Andric 
554bdd1243dSDimitry Andric   if (CurrConflInstNo == 0) {
555bdd1243dSDimitry Andric     // If we return to the starting position, we have explored
556bdd1243dSDimitry Andric     // the entire tree
557bdd1243dSDimitry Andric     if (CurrSyncGroupIdx == BeginSyncGroupIdx)
558bdd1243dSDimitry Andric       return;
559bdd1243dSDimitry Andric 
560bdd1243dSDimitry Andric     --CurrSyncGroupIdx;
561bdd1243dSDimitry Andric     // Go to previous non-trivial pipeline
562bdd1243dSDimitry Andric     while (PipelineInstrs[CurrSyncGroupIdx].size() == 0)
563bdd1243dSDimitry Andric       --CurrSyncGroupIdx;
564bdd1243dSDimitry Andric 
565bdd1243dSDimitry Andric     CurrConflInstNo = PipelineInstrs[CurrSyncGroupIdx].size() - 1;
566bdd1243dSDimitry Andric   }
567bdd1243dSDimitry Andric }
568bdd1243dSDimitry Andric 
569bdd1243dSDimitry Andric bool PipelineSolver::checkOptimal() {
570bdd1243dSDimitry Andric   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size()) {
571bdd1243dSDimitry Andric     if (BestCost == -1 || CurrCost < BestCost) {
572bdd1243dSDimitry Andric       BestPipeline = CurrPipeline;
573bdd1243dSDimitry Andric       BestCost = CurrCost;
574bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "Found Fit with cost " << BestCost << "\n");
575bdd1243dSDimitry Andric     }
576bdd1243dSDimitry Andric     assert(BestCost >= 0);
577bdd1243dSDimitry Andric   }
578bdd1243dSDimitry Andric 
579bdd1243dSDimitry Andric   bool DoneExploring = false;
580bdd1243dSDimitry Andric   if (MaxBranchesExplored > 0 && BranchesExplored >= MaxBranchesExplored)
581bdd1243dSDimitry Andric     DoneExploring = true;
582bdd1243dSDimitry Andric 
583bdd1243dSDimitry Andric   return (DoneExploring || BestCost == 0);
584bdd1243dSDimitry Andric }
585bdd1243dSDimitry Andric 
586*06c3fb27SDimitry Andric template <typename T>
587bdd1243dSDimitry Andric void PipelineSolver::populateReadyList(
588*06c3fb27SDimitry Andric     SmallVectorImpl<std::pair<int, int>> &ReadyList, T I, T E) {
589*06c3fb27SDimitry Andric   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
590*06c3fb27SDimitry Andric   auto SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
591bdd1243dSDimitry Andric   assert(CurrSU.second.size() >= 1);
592*06c3fb27SDimitry Andric 
593bdd1243dSDimitry Andric   for (; I != E; ++I) {
594bdd1243dSDimitry Andric     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
595bdd1243dSDimitry Andric     int CandSGID = *I;
596bdd1243dSDimitry Andric     SchedGroup *Match;
597bdd1243dSDimitry Andric     for (auto &SG : SyncPipeline) {
598bdd1243dSDimitry Andric       if (SG.getSGID() == CandSGID)
599bdd1243dSDimitry Andric         Match = &SG;
600bdd1243dSDimitry Andric     }
601bdd1243dSDimitry Andric 
602bdd1243dSDimitry Andric     if (UseCostHeur) {
603bdd1243dSDimitry Andric       if (Match->isFull()) {
604bdd1243dSDimitry Andric         ReadyList.push_back(std::pair(*I, MissPenalty));
605bdd1243dSDimitry Andric         continue;
606bdd1243dSDimitry Andric       }
607bdd1243dSDimitry Andric 
608bdd1243dSDimitry Andric       int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
609bdd1243dSDimitry Andric       ReadyList.push_back(std::pair(*I, TempCost));
610bdd1243dSDimitry Andric       removeEdges(AddedEdges);
611bdd1243dSDimitry Andric     } else
612bdd1243dSDimitry Andric       ReadyList.push_back(std::pair(*I, -1));
613bdd1243dSDimitry Andric   }
614bdd1243dSDimitry Andric 
615bdd1243dSDimitry Andric   if (UseCostHeur) {
616bdd1243dSDimitry Andric     std::sort(ReadyList.begin(), ReadyList.end(),
617bdd1243dSDimitry Andric               [](std::pair<int, int> A, std::pair<int, int> B) {
618bdd1243dSDimitry Andric                 return A.second < B.second;
619bdd1243dSDimitry Andric               });
620bdd1243dSDimitry Andric   }
621bdd1243dSDimitry Andric 
622bdd1243dSDimitry Andric   assert(ReadyList.size() == CurrSU.second.size());
623bdd1243dSDimitry Andric }
624bdd1243dSDimitry Andric 
625bdd1243dSDimitry Andric bool PipelineSolver::solveExact() {
626bdd1243dSDimitry Andric   if (checkOptimal())
627bdd1243dSDimitry Andric     return true;
628bdd1243dSDimitry Andric 
629bdd1243dSDimitry Andric   if (static_cast<size_t>(CurrSyncGroupIdx) == PipelineInstrs.size())
630bdd1243dSDimitry Andric     return false;
631bdd1243dSDimitry Andric 
632bdd1243dSDimitry Andric   assert(static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size());
633bdd1243dSDimitry Andric   assert(static_cast<size_t>(CurrConflInstNo) <
634bdd1243dSDimitry Andric          PipelineInstrs[CurrSyncGroupIdx].size());
635bdd1243dSDimitry Andric   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
636bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
637bdd1243dSDimitry Andric                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
638bdd1243dSDimitry Andric 
639bdd1243dSDimitry Andric   // SchedGroup -> Cost pairs
640bdd1243dSDimitry Andric   SmallVector<std::pair<int, int>, 4> ReadyList;
641bdd1243dSDimitry Andric   // Prioritize the candidate sched groups in terms of lowest cost first
642*06c3fb27SDimitry Andric   IsBottomUp ? populateReadyList(ReadyList, CurrSU.second.rbegin(),
643*06c3fb27SDimitry Andric                                  CurrSU.second.rend())
644*06c3fb27SDimitry Andric              : populateReadyList(ReadyList, CurrSU.second.begin(),
645*06c3fb27SDimitry Andric                                  CurrSU.second.end());
646bdd1243dSDimitry Andric 
647bdd1243dSDimitry Andric   auto I = ReadyList.begin();
648bdd1243dSDimitry Andric   auto E = ReadyList.end();
649bdd1243dSDimitry Andric   for (; I != E; ++I) {
650bdd1243dSDimitry Andric     // If we are trying SGs in least cost order, and the current SG is cost
651bdd1243dSDimitry Andric     // infeasible, then all subsequent SGs will also be cost infeasible, so we
652bdd1243dSDimitry Andric     // can prune.
653bdd1243dSDimitry Andric     if (BestCost != -1 && (CurrCost + I->second > BestCost))
654bdd1243dSDimitry Andric       return false;
655bdd1243dSDimitry Andric 
656bdd1243dSDimitry Andric     int CandSGID = I->first;
657bdd1243dSDimitry Andric     int AddedCost = 0;
658bdd1243dSDimitry Andric     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
659bdd1243dSDimitry Andric     auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
660bdd1243dSDimitry Andric     SchedGroup *Match;
661bdd1243dSDimitry Andric     for (auto &SG : SyncPipeline) {
662bdd1243dSDimitry Andric       if (SG.getSGID() == CandSGID)
663bdd1243dSDimitry Andric         Match = &SG;
664bdd1243dSDimitry Andric     }
665bdd1243dSDimitry Andric 
666bdd1243dSDimitry Andric     if (Match->isFull())
667bdd1243dSDimitry Andric       continue;
668bdd1243dSDimitry Andric 
669*06c3fb27SDimitry Andric     if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
670*06c3fb27SDimitry Andric       continue;
671*06c3fb27SDimitry Andric 
672bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
673bdd1243dSDimitry Andric                       << (int)Match->getMask() << "and ID " << CandSGID
674bdd1243dSDimitry Andric                       << "\n");
675bdd1243dSDimitry Andric     Match->add(*CurrSU.first);
676bdd1243dSDimitry Andric     AddedCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
677bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Cost of Assignment: " << AddedCost << "\n");
678bdd1243dSDimitry Andric     CurrCost += AddedCost;
679bdd1243dSDimitry Andric     advancePosition();
680bdd1243dSDimitry Andric     ++BranchesExplored;
681bdd1243dSDimitry Andric     bool FinishedExploring = false;
682bdd1243dSDimitry Andric     // If the Cost after adding edges is greater than a known solution,
683bdd1243dSDimitry Andric     // backtrack
684bdd1243dSDimitry Andric     if (CurrCost < BestCost || BestCost == -1) {
685bdd1243dSDimitry Andric       if (solveExact()) {
686bdd1243dSDimitry Andric         FinishedExploring = BestCost != 0;
687bdd1243dSDimitry Andric         if (!FinishedExploring)
688bdd1243dSDimitry Andric           return true;
689bdd1243dSDimitry Andric       }
690bdd1243dSDimitry Andric     }
691bdd1243dSDimitry Andric 
692bdd1243dSDimitry Andric     retreatPosition();
693bdd1243dSDimitry Andric     CurrCost -= AddedCost;
694bdd1243dSDimitry Andric     removeEdges(AddedEdges);
695bdd1243dSDimitry Andric     Match->pop();
696bdd1243dSDimitry Andric     CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
697bdd1243dSDimitry Andric     if (FinishedExploring)
698bdd1243dSDimitry Andric       return true;
699bdd1243dSDimitry Andric   }
700bdd1243dSDimitry Andric 
701bdd1243dSDimitry Andric   // Try the pipeline where the current instruction is omitted
702bdd1243dSDimitry Andric   // Potentially if we omit a problematic instruction from the pipeline,
703bdd1243dSDimitry Andric   // all the other instructions can nicely fit.
704bdd1243dSDimitry Andric   CurrCost += MissPenalty;
705bdd1243dSDimitry Andric   advancePosition();
706bdd1243dSDimitry Andric 
707bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "NOT Assigned (" << CurrSU.first->NodeNum << ")\n");
708bdd1243dSDimitry Andric 
709bdd1243dSDimitry Andric   bool FinishedExploring = false;
710bdd1243dSDimitry Andric   if (CurrCost < BestCost || BestCost == -1) {
711bdd1243dSDimitry Andric     if (solveExact()) {
712bdd1243dSDimitry Andric       bool FinishedExploring = BestCost != 0;
713bdd1243dSDimitry Andric       if (!FinishedExploring)
714bdd1243dSDimitry Andric         return true;
715bdd1243dSDimitry Andric     }
716bdd1243dSDimitry Andric   }
717bdd1243dSDimitry Andric 
718bdd1243dSDimitry Andric   retreatPosition();
719bdd1243dSDimitry Andric   CurrCost -= MissPenalty;
720bdd1243dSDimitry Andric   return FinishedExploring;
721bdd1243dSDimitry Andric }
722bdd1243dSDimitry Andric 
723*06c3fb27SDimitry Andric template <typename T>
724*06c3fb27SDimitry Andric void PipelineSolver::greedyFind(
725*06c3fb27SDimitry Andric     std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges, T I, T E) {
726bdd1243dSDimitry Andric   SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
727bdd1243dSDimitry Andric   int BestNodeCost = -1;
728bdd1243dSDimitry Andric   int TempCost;
729bdd1243dSDimitry Andric   SchedGroup *BestGroup = nullptr;
730bdd1243dSDimitry Andric   int BestGroupID = -1;
731bdd1243dSDimitry Andric   auto &SyncPipeline = CurrPipeline[CurrSyncGroupIdx];
732bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum
733bdd1243dSDimitry Andric                     << ") in Pipeline # " << CurrSyncGroupIdx << "\n");
734bdd1243dSDimitry Andric 
735bdd1243dSDimitry Andric   // Since we have added the potential SchedGroups from bottom up, but
736bdd1243dSDimitry Andric   // traversed the DAG from top down, parse over the groups from last to
737bdd1243dSDimitry Andric   // first. If we fail to do this for the greedy algorithm, the solution will
738bdd1243dSDimitry Andric   // likely not be good in more complex cases.
739bdd1243dSDimitry Andric   for (; I != E; ++I) {
740bdd1243dSDimitry Andric     std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
741bdd1243dSDimitry Andric     int CandSGID = *I;
742bdd1243dSDimitry Andric     SchedGroup *Match;
743bdd1243dSDimitry Andric     for (auto &SG : SyncPipeline) {
744bdd1243dSDimitry Andric       if (SG.getSGID() == CandSGID)
745bdd1243dSDimitry Andric         Match = &SG;
746bdd1243dSDimitry Andric     }
747bdd1243dSDimitry Andric 
748bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Trying SGID # " << CandSGID << " with Mask "
749bdd1243dSDimitry Andric                       << (int)Match->getMask() << "\n");
750bdd1243dSDimitry Andric 
751bdd1243dSDimitry Andric     if (Match->isFull()) {
752bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
753bdd1243dSDimitry Andric       continue;
754bdd1243dSDimitry Andric     }
755*06c3fb27SDimitry Andric     if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
756*06c3fb27SDimitry Andric       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
757*06c3fb27SDimitry Andric       continue;
758*06c3fb27SDimitry Andric     }
759bdd1243dSDimitry Andric     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
760bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
761bdd1243dSDimitry Andric     if (TempCost < BestNodeCost || BestNodeCost == -1) {
762bdd1243dSDimitry Andric       BestGroup = Match;
763bdd1243dSDimitry Andric       BestNodeCost = TempCost;
764bdd1243dSDimitry Andric       BestGroupID = CandSGID;
765bdd1243dSDimitry Andric     }
766bdd1243dSDimitry Andric     removeEdges(AddedEdges);
767bdd1243dSDimitry Andric     if (BestNodeCost == 0)
768bdd1243dSDimitry Andric       break;
769bdd1243dSDimitry Andric   }
770bdd1243dSDimitry Andric 
771bdd1243dSDimitry Andric   if (BestGroupID != -1) {
772bdd1243dSDimitry Andric     BestGroup->add(*CurrSU.first);
773bdd1243dSDimitry Andric     addEdges(SyncPipeline, CurrSU.first, BestGroupID, AddedEdges);
774bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Best Group has ID: " << BestGroupID << " and Mask"
775bdd1243dSDimitry Andric                       << (int)BestGroup->getMask() << "\n");
776bdd1243dSDimitry Andric     BestCost += TempCost;
777bdd1243dSDimitry Andric   } else
778bdd1243dSDimitry Andric     BestCost += MissPenalty;
779bdd1243dSDimitry Andric 
780bdd1243dSDimitry Andric   CurrPipeline[CurrSyncGroupIdx] = SyncPipeline;
781*06c3fb27SDimitry Andric }
782*06c3fb27SDimitry Andric 
783*06c3fb27SDimitry Andric bool PipelineSolver::solveGreedy() {
784*06c3fb27SDimitry Andric   BestCost = 0;
785*06c3fb27SDimitry Andric   std::vector<std::pair<SUnit *, SUnit *>> AddedEdges;
786*06c3fb27SDimitry Andric 
787*06c3fb27SDimitry Andric   while (static_cast<size_t>(CurrSyncGroupIdx) < PipelineInstrs.size()) {
788*06c3fb27SDimitry Andric     SUToCandSGsPair CurrSU = PipelineInstrs[CurrSyncGroupIdx][CurrConflInstNo];
789*06c3fb27SDimitry Andric     IsBottomUp
790*06c3fb27SDimitry Andric         ? greedyFind(AddedEdges, CurrSU.second.rbegin(), CurrSU.second.rend())
791*06c3fb27SDimitry Andric         : greedyFind(AddedEdges, CurrSU.second.begin(), CurrSU.second.end());
792bdd1243dSDimitry Andric     advancePosition();
793bdd1243dSDimitry Andric   }
794bdd1243dSDimitry Andric   BestPipeline = CurrPipeline;
795bdd1243dSDimitry Andric   removeEdges(AddedEdges);
796bdd1243dSDimitry Andric   return false;
797bdd1243dSDimitry Andric }
798bdd1243dSDimitry Andric 
799bdd1243dSDimitry Andric unsigned PipelineSolver::computeProblemSize() {
800bdd1243dSDimitry Andric   unsigned ProblemSize = 0;
801bdd1243dSDimitry Andric   for (auto &PipeConflicts : PipelineInstrs) {
802bdd1243dSDimitry Andric     ProblemSize += PipeConflicts.size();
803bdd1243dSDimitry Andric   }
804bdd1243dSDimitry Andric 
805bdd1243dSDimitry Andric   return ProblemSize;
806bdd1243dSDimitry Andric }
807bdd1243dSDimitry Andric 
808bdd1243dSDimitry Andric void PipelineSolver::solve() {
809bdd1243dSDimitry Andric   if (!NeedsSolver)
810bdd1243dSDimitry Andric     return;
811bdd1243dSDimitry Andric 
812bdd1243dSDimitry Andric   unsigned ProblemSize = computeProblemSize();
813bdd1243dSDimitry Andric   assert(ProblemSize > 0);
814bdd1243dSDimitry Andric 
815bdd1243dSDimitry Andric   bool BelowCutoff = (CutoffForExact > 0) && ProblemSize <= CutoffForExact;
816bdd1243dSDimitry Andric   MissPenalty = (ProblemSize / 2) + 1;
817bdd1243dSDimitry Andric 
818bdd1243dSDimitry Andric   LLVM_DEBUG(DAG->dump());
819bdd1243dSDimitry Andric   if (EnableExactSolver || BelowCutoff) {
820bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Starting Greedy pipeline solver\n");
821bdd1243dSDimitry Andric     solveGreedy();
822bdd1243dSDimitry Andric     reset();
823bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Greedy produced best cost of " << BestCost << "\n");
824bdd1243dSDimitry Andric     if (BestCost > 0) {
825bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "Starting EXACT pipeline solver\n");
826bdd1243dSDimitry Andric       solveExact();
827bdd1243dSDimitry Andric       LLVM_DEBUG(dbgs() << "Exact produced best cost of " << BestCost << "\n");
828bdd1243dSDimitry Andric     }
829bdd1243dSDimitry Andric   } else { // Use the Greedy Algorithm by default
830bdd1243dSDimitry Andric     LLVM_DEBUG(dbgs() << "Starting GREEDY pipeline solver\n");
831bdd1243dSDimitry Andric     solveGreedy();
832bdd1243dSDimitry Andric   }
833bdd1243dSDimitry Andric 
834bdd1243dSDimitry Andric   makePipeline();
835*06c3fb27SDimitry Andric   LLVM_DEBUG(dbgs() << "After applying mutation\n");
836*06c3fb27SDimitry Andric   LLVM_DEBUG(DAG->dump());
837bdd1243dSDimitry Andric }
838bdd1243dSDimitry Andric 
839*06c3fb27SDimitry Andric enum IGLPStrategyID : int {
840*06c3fb27SDimitry Andric   MFMASmallGemmOptID = 0,
841*06c3fb27SDimitry Andric   MFMASmallGemmSingleWaveOptID = 1,
842*06c3fb27SDimitry Andric };
843bdd1243dSDimitry Andric 
844bdd1243dSDimitry Andric // Implement a IGLP scheduling strategy.
845bdd1243dSDimitry Andric class IGLPStrategy {
846bdd1243dSDimitry Andric protected:
847bdd1243dSDimitry Andric   ScheduleDAGInstrs *DAG;
848bdd1243dSDimitry Andric 
849bdd1243dSDimitry Andric   const SIInstrInfo *TII;
850bdd1243dSDimitry Andric 
851bdd1243dSDimitry Andric public:
852bdd1243dSDimitry Andric   // Add SchedGroups to \p Pipeline to implement this Strategy.
853bdd1243dSDimitry Andric   virtual void applyIGLPStrategy(
854bdd1243dSDimitry Andric       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
855bdd1243dSDimitry Andric       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) = 0;
856bdd1243dSDimitry Andric 
857bdd1243dSDimitry Andric   // Returns true if this strategy should be applied to a ScheduleDAG.
858bdd1243dSDimitry Andric   virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0;
859bdd1243dSDimitry Andric 
860*06c3fb27SDimitry Andric   bool IsBottomUp = 1;
861*06c3fb27SDimitry Andric 
862bdd1243dSDimitry Andric   IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
863bdd1243dSDimitry Andric       : DAG(DAG), TII(TII) {}
864bdd1243dSDimitry Andric 
865bdd1243dSDimitry Andric   virtual ~IGLPStrategy() = default;
866bdd1243dSDimitry Andric };
867bdd1243dSDimitry Andric 
868bdd1243dSDimitry Andric class MFMASmallGemmOpt final : public IGLPStrategy {
869*06c3fb27SDimitry Andric private:
870bdd1243dSDimitry Andric public:
871bdd1243dSDimitry Andric   void applyIGLPStrategy(
872bdd1243dSDimitry Andric       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
873bdd1243dSDimitry Andric       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override;
874bdd1243dSDimitry Andric 
875bdd1243dSDimitry Andric   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
876bdd1243dSDimitry Andric 
877bdd1243dSDimitry Andric   MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
878*06c3fb27SDimitry Andric       : IGLPStrategy(DAG, TII) {
879*06c3fb27SDimitry Andric     IsBottomUp = 1;
880*06c3fb27SDimitry Andric   }
881bdd1243dSDimitry Andric };
882bdd1243dSDimitry Andric 
883bdd1243dSDimitry Andric void MFMASmallGemmOpt::applyIGLPStrategy(
884bdd1243dSDimitry Andric     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
885bdd1243dSDimitry Andric     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) {
886bdd1243dSDimitry Andric   // Count the number of MFMA instructions.
887bdd1243dSDimitry Andric   unsigned MFMACount = 0;
888bdd1243dSDimitry Andric   for (const MachineInstr &I : *DAG)
889bdd1243dSDimitry Andric     if (TII->isMFMAorWMMA(I))
890bdd1243dSDimitry Andric       ++MFMACount;
891bdd1243dSDimitry Andric 
892bdd1243dSDimitry Andric   const unsigned PipelineSyncID = 0;
893bdd1243dSDimitry Andric   SchedGroup *SG = nullptr;
894bdd1243dSDimitry Andric   for (unsigned I = 0; I < MFMACount * 3; ++I) {
895bdd1243dSDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
896bdd1243dSDimitry Andric         SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
897bdd1243dSDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
898bdd1243dSDimitry Andric 
899bdd1243dSDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
900bdd1243dSDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
901bdd1243dSDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
902bdd1243dSDimitry Andric   }
903bdd1243dSDimitry Andric }
904bdd1243dSDimitry Andric 
905*06c3fb27SDimitry Andric class MFMASmallGemmSingleWaveOpt final : public IGLPStrategy {
906*06c3fb27SDimitry Andric private:
907*06c3fb27SDimitry Andric   // Whether the DS_READ is a predecessor of first four MFMA in region
908*06c3fb27SDimitry Andric   class EnablesInitialMFMA final : public InstructionRule {
909*06c3fb27SDimitry Andric   public:
910*06c3fb27SDimitry Andric     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
911*06c3fb27SDimitry Andric                SmallVectorImpl<SchedGroup> &SyncPipe) override {
912*06c3fb27SDimitry Andric       if (!SyncPipe.size())
913*06c3fb27SDimitry Andric         return false;
914*06c3fb27SDimitry Andric       int MFMAsFound = 0;
915*06c3fb27SDimitry Andric       if (!Cache->size()) {
916*06c3fb27SDimitry Andric         for (auto &Elt : SyncPipe[0].DAG->SUnits) {
917*06c3fb27SDimitry Andric           if (TII->isMFMAorWMMA(*Elt.getInstr())) {
918*06c3fb27SDimitry Andric             ++MFMAsFound;
919*06c3fb27SDimitry Andric             if (MFMAsFound > 4)
920*06c3fb27SDimitry Andric               break;
921*06c3fb27SDimitry Andric             Cache->push_back(&Elt);
922*06c3fb27SDimitry Andric           }
923*06c3fb27SDimitry Andric         }
924*06c3fb27SDimitry Andric       }
925*06c3fb27SDimitry Andric 
926*06c3fb27SDimitry Andric       assert(Cache->size());
927*06c3fb27SDimitry Andric       auto DAG = SyncPipe[0].DAG;
928*06c3fb27SDimitry Andric       for (auto &Elt : *Cache) {
929*06c3fb27SDimitry Andric         if (DAG->IsReachable(Elt, const_cast<SUnit *>(SU)))
930*06c3fb27SDimitry Andric           return true;
931*06c3fb27SDimitry Andric       }
932*06c3fb27SDimitry Andric       return false;
933*06c3fb27SDimitry Andric     }
934*06c3fb27SDimitry Andric 
935*06c3fb27SDimitry Andric     EnablesInitialMFMA(const SIInstrInfo *TII, unsigned SGID,
936*06c3fb27SDimitry Andric                        bool NeedsCache = false)
937*06c3fb27SDimitry Andric         : InstructionRule(TII, SGID, NeedsCache) {}
938*06c3fb27SDimitry Andric   };
939*06c3fb27SDimitry Andric 
940*06c3fb27SDimitry Andric   // Whether the MI is a V_PERM and is a predecessor of a common DS_WRITE
941*06c3fb27SDimitry Andric   class IsPermForDSW final : public InstructionRule {
942*06c3fb27SDimitry Andric   public:
943*06c3fb27SDimitry Andric     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
944*06c3fb27SDimitry Andric                SmallVectorImpl<SchedGroup> &SyncPipe) override {
945*06c3fb27SDimitry Andric       auto MI = SU->getInstr();
946*06c3fb27SDimitry Andric       if (MI->getOpcode() != AMDGPU::V_PERM_B32_e64)
947*06c3fb27SDimitry Andric         return false;
948*06c3fb27SDimitry Andric 
949*06c3fb27SDimitry Andric       bool FitsInGroup = false;
950*06c3fb27SDimitry Andric       // Does the VALU have a DS_WRITE successor
951*06c3fb27SDimitry Andric       if (!Collection.size()) {
952*06c3fb27SDimitry Andric         for (auto &Succ : SU->Succs) {
953*06c3fb27SDimitry Andric           SUnit *SuccUnit = Succ.getSUnit();
954*06c3fb27SDimitry Andric           if (TII->isDS(*SuccUnit->getInstr()) &&
955*06c3fb27SDimitry Andric               SuccUnit->getInstr()->mayStore()) {
956*06c3fb27SDimitry Andric             Cache->push_back(SuccUnit);
957*06c3fb27SDimitry Andric             FitsInGroup = true;
958*06c3fb27SDimitry Andric           }
959*06c3fb27SDimitry Andric         }
960*06c3fb27SDimitry Andric         return FitsInGroup;
961*06c3fb27SDimitry Andric       }
962*06c3fb27SDimitry Andric 
963*06c3fb27SDimitry Andric       assert(Cache->size());
964*06c3fb27SDimitry Andric 
965*06c3fb27SDimitry Andric       // Does the VALU have a DS_WRITE successor that is the same as other
966*06c3fb27SDimitry Andric       // VALU already in the group. The V_PERMs will all share 1 DS_W succ
967*06c3fb27SDimitry Andric       return std::any_of(Cache->begin(), Cache->end(), [&SU](SUnit *Elt) {
968*06c3fb27SDimitry Andric         return std::any_of(SU->Succs.begin(), SU->Succs.end(),
969*06c3fb27SDimitry Andric                            [&Elt](const SDep &ThisSucc) {
970*06c3fb27SDimitry Andric                              return ThisSucc.getSUnit() == Elt;
971*06c3fb27SDimitry Andric                            });
972*06c3fb27SDimitry Andric       });
973*06c3fb27SDimitry Andric     }
974*06c3fb27SDimitry Andric 
975*06c3fb27SDimitry Andric     IsPermForDSW(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
976*06c3fb27SDimitry Andric         : InstructionRule(TII, SGID, NeedsCache) {}
977*06c3fb27SDimitry Andric   };
978*06c3fb27SDimitry Andric 
979*06c3fb27SDimitry Andric   // Whether the SU is a successor of any element in previous SchedGroup
980*06c3fb27SDimitry Andric   class IsSuccOfPrevGroup final : public InstructionRule {
981*06c3fb27SDimitry Andric   public:
982*06c3fb27SDimitry Andric     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
983*06c3fb27SDimitry Andric                SmallVectorImpl<SchedGroup> &SyncPipe) override {
984*06c3fb27SDimitry Andric       SchedGroup *OtherGroup = nullptr;
985*06c3fb27SDimitry Andric       for (auto &PipeSG : SyncPipe) {
986*06c3fb27SDimitry Andric         if ((unsigned)PipeSG.getSGID() == SGID - 1) {
987*06c3fb27SDimitry Andric           OtherGroup = &PipeSG;
988*06c3fb27SDimitry Andric         }
989*06c3fb27SDimitry Andric       }
990*06c3fb27SDimitry Andric 
991*06c3fb27SDimitry Andric       if (!OtherGroup)
992*06c3fb27SDimitry Andric         return false;
993*06c3fb27SDimitry Andric       if (!OtherGroup->Collection.size())
994*06c3fb27SDimitry Andric         return true;
995*06c3fb27SDimitry Andric 
996*06c3fb27SDimitry Andric       // Does the previous VALU have this DS_Write as a successor
997*06c3fb27SDimitry Andric       return (std::any_of(OtherGroup->Collection.begin(),
998*06c3fb27SDimitry Andric                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
999*06c3fb27SDimitry Andric                             return std::any_of(Elt->Succs.begin(),
1000*06c3fb27SDimitry Andric                                                Elt->Succs.end(),
1001*06c3fb27SDimitry Andric                                                [&SU](SDep &Succ) {
1002*06c3fb27SDimitry Andric                                                  return Succ.getSUnit() == SU;
1003*06c3fb27SDimitry Andric                                                });
1004*06c3fb27SDimitry Andric                           }));
1005*06c3fb27SDimitry Andric     }
1006*06c3fb27SDimitry Andric     IsSuccOfPrevGroup(const SIInstrInfo *TII, unsigned SGID,
1007*06c3fb27SDimitry Andric                       bool NeedsCache = false)
1008*06c3fb27SDimitry Andric         : InstructionRule(TII, SGID, NeedsCache) {}
1009*06c3fb27SDimitry Andric   };
1010*06c3fb27SDimitry Andric 
1011*06c3fb27SDimitry Andric   // Whether the combined load width of group is 128 bits
1012*06c3fb27SDimitry Andric   class VMEMSize final : public InstructionRule {
1013*06c3fb27SDimitry Andric   public:
1014*06c3fb27SDimitry Andric     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1015*06c3fb27SDimitry Andric                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1016*06c3fb27SDimitry Andric       auto MI = SU->getInstr();
1017*06c3fb27SDimitry Andric       if (MI->getOpcode() == TargetOpcode::BUNDLE)
1018*06c3fb27SDimitry Andric         return false;
1019*06c3fb27SDimitry Andric       if (!Collection.size())
1020*06c3fb27SDimitry Andric         return true;
1021*06c3fb27SDimitry Andric 
1022*06c3fb27SDimitry Andric       int NumBits = 0;
1023*06c3fb27SDimitry Andric 
1024*06c3fb27SDimitry Andric       auto TRI = TII->getRegisterInfo();
1025*06c3fb27SDimitry Andric       auto &MRI = MI->getParent()->getParent()->getRegInfo();
1026*06c3fb27SDimitry Andric       for (auto &Elt : Collection) {
1027*06c3fb27SDimitry Andric         auto Op = Elt->getInstr()->getOperand(0);
1028*06c3fb27SDimitry Andric         auto Size =
1029*06c3fb27SDimitry Andric             TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(MRI, Op));
1030*06c3fb27SDimitry Andric         NumBits += Size;
1031*06c3fb27SDimitry Andric       }
1032*06c3fb27SDimitry Andric 
1033*06c3fb27SDimitry Andric       if (NumBits < 128) {
1034*06c3fb27SDimitry Andric         assert(TII->isVMEM(*MI) && MI->mayLoad());
1035*06c3fb27SDimitry Andric         if (NumBits + TRI.getRegSizeInBits(*TRI.getRegClassForOperandReg(
1036*06c3fb27SDimitry Andric                           MRI, MI->getOperand(0))) <=
1037*06c3fb27SDimitry Andric             128)
1038*06c3fb27SDimitry Andric           return true;
1039*06c3fb27SDimitry Andric       }
1040*06c3fb27SDimitry Andric 
1041*06c3fb27SDimitry Andric       return false;
1042*06c3fb27SDimitry Andric     }
1043*06c3fb27SDimitry Andric 
1044*06c3fb27SDimitry Andric     VMEMSize(const SIInstrInfo *TII, unsigned SGID, bool NeedsCache = false)
1045*06c3fb27SDimitry Andric         : InstructionRule(TII, SGID, NeedsCache) {}
1046*06c3fb27SDimitry Andric   };
1047*06c3fb27SDimitry Andric 
1048*06c3fb27SDimitry Andric   // Whether the SU shares a V_PERM predecessor with any SU in the SchedGroup
1049*06c3fb27SDimitry Andric   // that is /p Distance steps away
1050*06c3fb27SDimitry Andric   class SharesPredWithPrevNthGroup final : public InstructionRule {
1051*06c3fb27SDimitry Andric   private:
1052*06c3fb27SDimitry Andric     unsigned Distance = 1;
1053*06c3fb27SDimitry Andric 
1054*06c3fb27SDimitry Andric   public:
1055*06c3fb27SDimitry Andric     bool apply(const SUnit *SU, const ArrayRef<SUnit *> Collection,
1056*06c3fb27SDimitry Andric                SmallVectorImpl<SchedGroup> &SyncPipe) override {
1057*06c3fb27SDimitry Andric       SchedGroup *OtherGroup = nullptr;
1058*06c3fb27SDimitry Andric       if (!SyncPipe.size())
1059*06c3fb27SDimitry Andric         return false;
1060*06c3fb27SDimitry Andric 
1061*06c3fb27SDimitry Andric       if (!Cache->size()) {
1062*06c3fb27SDimitry Andric 
1063*06c3fb27SDimitry Andric         for (auto &PipeSG : SyncPipe) {
1064*06c3fb27SDimitry Andric           if ((unsigned)PipeSG.getSGID() == SGID - Distance) {
1065*06c3fb27SDimitry Andric             OtherGroup = &PipeSG;
1066*06c3fb27SDimitry Andric           }
1067*06c3fb27SDimitry Andric         }
1068*06c3fb27SDimitry Andric 
1069*06c3fb27SDimitry Andric         if (!OtherGroup)
1070*06c3fb27SDimitry Andric           return false;
1071*06c3fb27SDimitry Andric         if (!OtherGroup->Collection.size())
1072*06c3fb27SDimitry Andric           return true;
1073*06c3fb27SDimitry Andric 
1074*06c3fb27SDimitry Andric         for (auto &OtherEle : OtherGroup->Collection) {
1075*06c3fb27SDimitry Andric           for (auto &Pred : OtherEle->Preds) {
1076*06c3fb27SDimitry Andric             if (Pred.getSUnit()->getInstr()->getOpcode() ==
1077*06c3fb27SDimitry Andric                 AMDGPU::V_PERM_B32_e64)
1078*06c3fb27SDimitry Andric               Cache->push_back(Pred.getSUnit());
1079*06c3fb27SDimitry Andric           }
1080*06c3fb27SDimitry Andric         }
1081*06c3fb27SDimitry Andric       }
1082*06c3fb27SDimitry Andric 
1083*06c3fb27SDimitry Andric       assert(Cache->size());
1084*06c3fb27SDimitry Andric       auto DAG = SyncPipe[0].DAG;
1085*06c3fb27SDimitry Andric       // Does the previous DS_WRITE share a V_PERM predecessor with this
1086*06c3fb27SDimitry Andric       // VMEM_READ
1087*06c3fb27SDimitry Andric       return (
1088*06c3fb27SDimitry Andric           std::any_of(Cache->begin(), Cache->end(), [&SU, &DAG](SUnit *Elt) {
1089*06c3fb27SDimitry Andric             return DAG->IsReachable(const_cast<SUnit *>(SU), Elt);
1090*06c3fb27SDimitry Andric           }));
1091*06c3fb27SDimitry Andric     }
1092*06c3fb27SDimitry Andric     SharesPredWithPrevNthGroup(unsigned Distance, const SIInstrInfo *TII,
1093*06c3fb27SDimitry Andric                                unsigned SGID, bool NeedsCache = false)
1094*06c3fb27SDimitry Andric         : InstructionRule(TII, SGID, NeedsCache), Distance(Distance) {}
1095*06c3fb27SDimitry Andric   };
1096*06c3fb27SDimitry Andric 
1097*06c3fb27SDimitry Andric public:
1098*06c3fb27SDimitry Andric   void applyIGLPStrategy(
1099*06c3fb27SDimitry Andric       DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1100*06c3fb27SDimitry Andric       DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override;
1101*06c3fb27SDimitry Andric 
1102*06c3fb27SDimitry Andric   bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; }
1103*06c3fb27SDimitry Andric 
1104*06c3fb27SDimitry Andric   MFMASmallGemmSingleWaveOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
1105*06c3fb27SDimitry Andric       : IGLPStrategy(DAG, TII) {
1106*06c3fb27SDimitry Andric     IsBottomUp = 0;
1107*06c3fb27SDimitry Andric   }
1108*06c3fb27SDimitry Andric };
1109*06c3fb27SDimitry Andric 
1110*06c3fb27SDimitry Andric void MFMASmallGemmSingleWaveOpt::applyIGLPStrategy(
1111*06c3fb27SDimitry Andric     DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs,
1112*06c3fb27SDimitry Andric     DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) {
1113*06c3fb27SDimitry Andric   unsigned MFMACount = 0;
1114*06c3fb27SDimitry Andric   unsigned DSWCount = 0;
1115*06c3fb27SDimitry Andric   unsigned DSWWithPermCount = 0;
1116*06c3fb27SDimitry Andric   unsigned DSWWithSharedVMEMCount = 0;
1117*06c3fb27SDimitry Andric   unsigned DSRCount = 0;
1118*06c3fb27SDimitry Andric   SmallVector<SUnit *, 6> DSWithPerms;
1119*06c3fb27SDimitry Andric   for (auto &SU : DAG->SUnits) {
1120*06c3fb27SDimitry Andric     auto I = SU.getInstr();
1121*06c3fb27SDimitry Andric     if (TII->isMFMAorWMMA(*I))
1122*06c3fb27SDimitry Andric       ++MFMACount;
1123*06c3fb27SDimitry Andric     else if (TII->isDS(*I)) {
1124*06c3fb27SDimitry Andric       if (I->mayLoad())
1125*06c3fb27SDimitry Andric         ++DSRCount;
1126*06c3fb27SDimitry Andric       else if (I->mayStore()) {
1127*06c3fb27SDimitry Andric         ++DSWCount;
1128*06c3fb27SDimitry Andric         for (auto Pred : SU.Preds) {
1129*06c3fb27SDimitry Andric           if (Pred.getSUnit()->getInstr()->getOpcode() ==
1130*06c3fb27SDimitry Andric               AMDGPU::V_PERM_B32_e64) {
1131*06c3fb27SDimitry Andric             DSWithPerms.push_back(&SU);
1132*06c3fb27SDimitry Andric             break;
1133*06c3fb27SDimitry Andric           }
1134*06c3fb27SDimitry Andric         }
1135*06c3fb27SDimitry Andric       }
1136*06c3fb27SDimitry Andric     }
1137*06c3fb27SDimitry Andric   }
1138*06c3fb27SDimitry Andric   DSWWithPermCount = DSWithPerms.size();
1139*06c3fb27SDimitry Andric   auto I = DSWithPerms.begin();
1140*06c3fb27SDimitry Andric   auto E = DSWithPerms.end();
1141*06c3fb27SDimitry Andric 
1142*06c3fb27SDimitry Andric   // Get the count of DS_WRITES with V_PERM predecessors which
1143*06c3fb27SDimitry Andric   // have loop carried dependencies (WAR) on the same VMEM_READs.
1144*06c3fb27SDimitry Andric   // We consider partial overlap as a miss -- in other words,
1145*06c3fb27SDimitry Andric   // for a given DS_W, we only consider another DS_W as matching
1146*06c3fb27SDimitry Andric   // if there is a corresponding (in terms of the VMEM_R it uses) V_PERM pred
1147*06c3fb27SDimitry Andric   // for every V_PERM pred of this DS_W.
1148*06c3fb27SDimitry Andric   DenseMap<MachineInstr *, SUnit *> VMEMLookup;
1149*06c3fb27SDimitry Andric   SmallVector<SUnit *, 6> Counted;
1150*06c3fb27SDimitry Andric   for (; I != E; I++) {
1151*06c3fb27SDimitry Andric     SUnit *Cand = nullptr;
1152*06c3fb27SDimitry Andric     bool MissedAny = false;
1153*06c3fb27SDimitry Andric     for (auto &Pred : (*I)->Preds) {
1154*06c3fb27SDimitry Andric       if (Pred.getSUnit()->getInstr()->getOpcode() != AMDGPU::V_PERM_B32_e64)
1155*06c3fb27SDimitry Andric         continue;
1156*06c3fb27SDimitry Andric 
1157*06c3fb27SDimitry Andric       if (Cand &&
1158*06c3fb27SDimitry Andric           std::find(Counted.begin(), Counted.end(), Cand) != Counted.end())
1159*06c3fb27SDimitry Andric         break;
1160*06c3fb27SDimitry Andric 
1161*06c3fb27SDimitry Andric       for (auto &Succ : Pred.getSUnit()->Succs) {
1162*06c3fb27SDimitry Andric         auto MI = Succ.getSUnit()->getInstr();
1163*06c3fb27SDimitry Andric         if (!TII->isVMEM(*MI) || !MI->mayLoad())
1164*06c3fb27SDimitry Andric           continue;
1165*06c3fb27SDimitry Andric 
1166*06c3fb27SDimitry Andric         if (MissedAny || !VMEMLookup.size()) {
1167*06c3fb27SDimitry Andric           MissedAny = true;
1168*06c3fb27SDimitry Andric           VMEMLookup[MI] = *I;
1169*06c3fb27SDimitry Andric           continue;
1170*06c3fb27SDimitry Andric         }
1171*06c3fb27SDimitry Andric 
1172*06c3fb27SDimitry Andric         if (!VMEMLookup.contains(MI)) {
1173*06c3fb27SDimitry Andric           MissedAny = true;
1174*06c3fb27SDimitry Andric           VMEMLookup[MI] = *I;
1175*06c3fb27SDimitry Andric           continue;
1176*06c3fb27SDimitry Andric         }
1177*06c3fb27SDimitry Andric 
1178*06c3fb27SDimitry Andric         Cand = VMEMLookup[MI];
1179*06c3fb27SDimitry Andric         if (std::find(Counted.begin(), Counted.end(), Cand) != Counted.end()) {
1180*06c3fb27SDimitry Andric           MissedAny = true;
1181*06c3fb27SDimitry Andric           break;
1182*06c3fb27SDimitry Andric         }
1183*06c3fb27SDimitry Andric       }
1184*06c3fb27SDimitry Andric     }
1185*06c3fb27SDimitry Andric     if (!MissedAny && Cand) {
1186*06c3fb27SDimitry Andric       DSWWithSharedVMEMCount += 2;
1187*06c3fb27SDimitry Andric       Counted.push_back(Cand);
1188*06c3fb27SDimitry Andric       Counted.push_back(*I);
1189*06c3fb27SDimitry Andric     }
1190*06c3fb27SDimitry Andric   }
1191*06c3fb27SDimitry Andric 
1192*06c3fb27SDimitry Andric   assert(DSWWithSharedVMEMCount <= DSWWithPermCount);
1193*06c3fb27SDimitry Andric   SchedGroup *SG;
1194*06c3fb27SDimitry Andric   unsigned PipelineSyncID = 0;
1195*06c3fb27SDimitry Andric   // For kernels with V_PERM, there are enough VALU to mix in between MFMAs
1196*06c3fb27SDimitry Andric   if (DSWWithPermCount) {
1197*06c3fb27SDimitry Andric     for (unsigned I = 0; I < MFMACount; I++) {
1198*06c3fb27SDimitry Andric       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1199*06c3fb27SDimitry Andric           SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1200*06c3fb27SDimitry Andric       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1201*06c3fb27SDimitry Andric 
1202*06c3fb27SDimitry Andric       SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1203*06c3fb27SDimitry Andric           SchedGroupMask::VALU, 2, PipelineSyncID, DAG, TII);
1204*06c3fb27SDimitry Andric       SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1205*06c3fb27SDimitry Andric     }
1206*06c3fb27SDimitry Andric   }
1207*06c3fb27SDimitry Andric 
1208*06c3fb27SDimitry Andric   PipelineSyncID = 1;
1209*06c3fb27SDimitry Andric   // Phase 1: Break up DS_READ and MFMA clusters.
1210*06c3fb27SDimitry Andric   // First DS_READ to make ready initial MFMA, then interleave MFMA with DS_READ
1211*06c3fb27SDimitry Andric   // prefetch
1212*06c3fb27SDimitry Andric 
1213*06c3fb27SDimitry Andric   // Make ready initial MFMA
1214*06c3fb27SDimitry Andric   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1215*06c3fb27SDimitry Andric       SchedGroupMask::DS_READ, 4, PipelineSyncID, DAG, TII);
1216*06c3fb27SDimitry Andric   SG->addRule(std::make_shared<EnablesInitialMFMA>(TII, SG->getSGID(), true));
1217*06c3fb27SDimitry Andric   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1218*06c3fb27SDimitry Andric 
1219*06c3fb27SDimitry Andric   SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1220*06c3fb27SDimitry Andric       SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1221*06c3fb27SDimitry Andric   SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1222*06c3fb27SDimitry Andric 
1223*06c3fb27SDimitry Andric   // Interleave MFMA with DS_READ prefetch
1224*06c3fb27SDimitry Andric   for (unsigned I = 0; I < DSRCount - 4; ++I) {
1225*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1226*06c3fb27SDimitry Andric         SchedGroupMask::DS_READ, 1, PipelineSyncID, DAG, TII);
1227*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1228*06c3fb27SDimitry Andric 
1229*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1230*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1231*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1232*06c3fb27SDimitry Andric   }
1233*06c3fb27SDimitry Andric 
1234*06c3fb27SDimitry Andric   // Phase 2a: Loop carried dependency with V_PERM
1235*06c3fb27SDimitry Andric   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1236*06c3fb27SDimitry Andric   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1237*06c3fb27SDimitry Andric   for (unsigned I = 0; I < DSWWithPermCount - DSWWithSharedVMEMCount; ++I) {
1238*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1239*06c3fb27SDimitry Andric         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1240*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1241*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1242*06c3fb27SDimitry Andric 
1243*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1244*06c3fb27SDimitry Andric         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1245*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1246*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1247*06c3fb27SDimitry Andric 
1248*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1249*06c3fb27SDimitry Andric         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1250*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1251*06c3fb27SDimitry Andric         1, TII, SG->getSGID(), true));
1252*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1253*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1254*06c3fb27SDimitry Andric 
1255*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1256*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1257*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1258*06c3fb27SDimitry Andric 
1259*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1260*06c3fb27SDimitry Andric         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1261*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1262*06c3fb27SDimitry Andric         3, TII, SG->getSGID(), true));
1263*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1264*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1265*06c3fb27SDimitry Andric 
1266*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1267*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1268*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1269*06c3fb27SDimitry Andric   }
1270*06c3fb27SDimitry Andric 
1271*06c3fb27SDimitry Andric   // Phase 2b: Loop carried dependency without V_PERM
1272*06c3fb27SDimitry Andric   // Schedule DS_WRITE as closely as possible to the VMEM_READ they depend on.
1273*06c3fb27SDimitry Andric   // Interleave MFMA to keep XDL unit busy throughout.
1274*06c3fb27SDimitry Andric   for (unsigned I = 0; I < DSWCount - DSWWithPermCount; I++) {
1275*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1276*06c3fb27SDimitry Andric         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1277*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1278*06c3fb27SDimitry Andric 
1279*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1280*06c3fb27SDimitry Andric         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1281*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1282*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1283*06c3fb27SDimitry Andric 
1284*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1285*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1286*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1287*06c3fb27SDimitry Andric   }
1288*06c3fb27SDimitry Andric 
1289*06c3fb27SDimitry Andric   // Phase 2c: Loop carried dependency with V_PERM, VMEM_READs are
1290*06c3fb27SDimitry Andric   // ultimately used by two DS_WRITE
1291*06c3fb27SDimitry Andric   // Schedule VPerm & DS_WRITE as closely as possible to the VMEM_READ they
1292*06c3fb27SDimitry Andric   // depend on. Interleave MFMA to keep XDL unit busy throughout.
1293*06c3fb27SDimitry Andric 
1294*06c3fb27SDimitry Andric   for (unsigned I = 0; I < DSWWithSharedVMEMCount; ++I) {
1295*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1296*06c3fb27SDimitry Andric         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1297*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1298*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1299*06c3fb27SDimitry Andric 
1300*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1301*06c3fb27SDimitry Andric         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1302*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1303*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1304*06c3fb27SDimitry Andric 
1305*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1306*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1307*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1308*06c3fb27SDimitry Andric 
1309*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1310*06c3fb27SDimitry Andric         SchedGroupMask::VALU, 4, PipelineSyncID, DAG, TII);
1311*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsPermForDSW>(TII, SG->getSGID(), true));
1312*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1313*06c3fb27SDimitry Andric 
1314*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1315*06c3fb27SDimitry Andric         SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
1316*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<IsSuccOfPrevGroup>(TII, SG->getSGID(), false));
1317*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1318*06c3fb27SDimitry Andric 
1319*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1320*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1321*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1322*06c3fb27SDimitry Andric 
1323*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1324*06c3fb27SDimitry Andric         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1325*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1326*06c3fb27SDimitry Andric         2, TII, SG->getSGID(), true));
1327*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1328*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1329*06c3fb27SDimitry Andric 
1330*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1331*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1332*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1333*06c3fb27SDimitry Andric 
1334*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1335*06c3fb27SDimitry Andric         SchedGroupMask::VMEM_READ, 4, PipelineSyncID, DAG, TII);
1336*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<SharesPredWithPrevNthGroup>(
1337*06c3fb27SDimitry Andric         4, TII, SG->getSGID(), true));
1338*06c3fb27SDimitry Andric     SG->addRule(std::make_shared<VMEMSize>(TII, SG->getSGID(), false));
1339*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1340*06c3fb27SDimitry Andric 
1341*06c3fb27SDimitry Andric     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
1342*06c3fb27SDimitry Andric         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
1343*06c3fb27SDimitry Andric     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
1344*06c3fb27SDimitry Andric   }
1345*06c3fb27SDimitry Andric }
1346*06c3fb27SDimitry Andric 
1347bdd1243dSDimitry Andric static std::unique_ptr<IGLPStrategy>
1348bdd1243dSDimitry Andric createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG,
1349bdd1243dSDimitry Andric                    const SIInstrInfo *TII) {
1350bdd1243dSDimitry Andric   switch (ID) {
1351bdd1243dSDimitry Andric   case MFMASmallGemmOptID:
1352bdd1243dSDimitry Andric     return std::make_unique<MFMASmallGemmOpt>(DAG, TII);
1353*06c3fb27SDimitry Andric   case MFMASmallGemmSingleWaveOptID:
1354*06c3fb27SDimitry Andric     return std::make_unique<MFMASmallGemmSingleWaveOpt>(DAG, TII);
1355bdd1243dSDimitry Andric   }
1356bdd1243dSDimitry Andric 
1357bdd1243dSDimitry Andric   llvm_unreachable("Unknown IGLPStrategyID");
1358bdd1243dSDimitry Andric }
1359bdd1243dSDimitry Andric 
1360bdd1243dSDimitry Andric class IGroupLPDAGMutation : public ScheduleDAGMutation {
1361bdd1243dSDimitry Andric private:
1362bdd1243dSDimitry Andric   const SIInstrInfo *TII;
1363bdd1243dSDimitry Andric 
1364bdd1243dSDimitry Andric   ScheduleDAGMI *DAG;
1365bdd1243dSDimitry Andric 
1366bdd1243dSDimitry Andric   // Organize lists of SchedGroups by their SyncID. SchedGroups /
1367bdd1243dSDimitry Andric   // SCHED_GROUP_BARRIERs with different SyncIDs will have no edges added
1368bdd1243dSDimitry Andric   // between then.
1369bdd1243dSDimitry Andric   DenseMap<int, SmallVector<SchedGroup, 4>> SyncedSchedGroups;
1370bdd1243dSDimitry Andric 
1371bdd1243dSDimitry Andric   // Used to track instructions that can be mapped to multiple sched groups
1372bdd1243dSDimitry Andric   DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs;
1373bdd1243dSDimitry Andric 
1374bdd1243dSDimitry Andric   // Add DAG edges that enforce SCHED_BARRIER ordering.
1375bdd1243dSDimitry Andric   void addSchedBarrierEdges(SUnit &SU);
1376bdd1243dSDimitry Andric 
1377bdd1243dSDimitry Andric   // Use a SCHED_BARRIER's mask to identify instruction SchedGroups that should
1378bdd1243dSDimitry Andric   // not be reordered accross the SCHED_BARRIER. This is used for the base
1379bdd1243dSDimitry Andric   // SCHED_BARRIER, and not SCHED_GROUP_BARRIER. The difference is that
1380bdd1243dSDimitry Andric   // SCHED_BARRIER will always block all instructions that can be classified
1381bdd1243dSDimitry Andric   // into a particular SchedClass, whereas SCHED_GROUP_BARRIER has a fixed size
1382bdd1243dSDimitry Andric   // and may only synchronize with some SchedGroups. Returns the inverse of
1383bdd1243dSDimitry Andric   // Mask. SCHED_BARRIER's mask describes which instruction types should be
1384bdd1243dSDimitry Andric   // allowed to be scheduled across it. Invert the mask to get the
1385bdd1243dSDimitry Andric   // SchedGroupMask of instructions that should be barred.
1386bdd1243dSDimitry Andric   SchedGroupMask invertSchedBarrierMask(SchedGroupMask Mask) const;
1387bdd1243dSDimitry Andric 
1388bdd1243dSDimitry Andric   // Create SchedGroups for a SCHED_GROUP_BARRIER.
1389bdd1243dSDimitry Andric   void initSchedGroupBarrierPipelineStage(
1390bdd1243dSDimitry Andric       std::vector<SUnit>::reverse_iterator RIter);
1391bdd1243dSDimitry Andric 
1392bdd1243dSDimitry Andric   void initIGLPOpt(SUnit &SU);
1393bdd1243dSDimitry Andric 
1394bdd1243dSDimitry Andric public:
1395bdd1243dSDimitry Andric   void apply(ScheduleDAGInstrs *DAGInstrs) override;
1396bdd1243dSDimitry Andric 
1397*06c3fb27SDimitry Andric   // The order in which the PipelineSolver should process the candidate
1398*06c3fb27SDimitry Andric   // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last
1399*06c3fb27SDimitry Andric   // created SchedGroup first, and will consider that as the ultimate
1400*06c3fb27SDimitry Andric   // predecessor group when linking. TOP_DOWN instead links and processes the
1401*06c3fb27SDimitry Andric   // first created SchedGroup first.
1402*06c3fb27SDimitry Andric   bool IsBottomUp = 1;
1403*06c3fb27SDimitry Andric 
1404bdd1243dSDimitry Andric   IGroupLPDAGMutation() = default;
1405bdd1243dSDimitry Andric };
1406bdd1243dSDimitry Andric 
1407bdd1243dSDimitry Andric unsigned SchedGroup::NumSchedGroups = 0;
1408bdd1243dSDimitry Andric 
1409bdd1243dSDimitry Andric bool SchedGroup::tryAddEdge(SUnit *A, SUnit *B) {
1410bdd1243dSDimitry Andric   if (A != B && DAG->canAddEdge(B, A)) {
1411bdd1243dSDimitry Andric     DAG->addEdge(B, SDep(A, SDep::Artificial));
1412bdd1243dSDimitry Andric     return true;
1413bdd1243dSDimitry Andric   }
1414bdd1243dSDimitry Andric   return false;
1415bdd1243dSDimitry Andric }
1416bdd1243dSDimitry Andric 
1417bdd1243dSDimitry Andric bool SchedGroup::canAddMI(const MachineInstr &MI) const {
1418bdd1243dSDimitry Andric   bool Result = false;
1419bdd1243dSDimitry Andric   if (MI.isMetaInstruction())
1420bdd1243dSDimitry Andric     Result = false;
1421bdd1243dSDimitry Andric 
1422bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::ALU) != SchedGroupMask::NONE) &&
1423bdd1243dSDimitry Andric            (TII->isVALU(MI) || TII->isMFMAorWMMA(MI) || TII->isSALU(MI)))
1424bdd1243dSDimitry Andric     Result = true;
1425bdd1243dSDimitry Andric 
1426bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::VALU) != SchedGroupMask::NONE) &&
1427bdd1243dSDimitry Andric            TII->isVALU(MI) && !TII->isMFMAorWMMA(MI))
1428bdd1243dSDimitry Andric     Result = true;
1429bdd1243dSDimitry Andric 
1430bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::SALU) != SchedGroupMask::NONE) &&
1431bdd1243dSDimitry Andric            TII->isSALU(MI))
1432bdd1243dSDimitry Andric     Result = true;
1433bdd1243dSDimitry Andric 
1434bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::MFMA) != SchedGroupMask::NONE) &&
1435bdd1243dSDimitry Andric            TII->isMFMAorWMMA(MI))
1436bdd1243dSDimitry Andric     Result = true;
1437bdd1243dSDimitry Andric 
1438bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::VMEM) != SchedGroupMask::NONE) &&
1439bdd1243dSDimitry Andric            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1440bdd1243dSDimitry Andric     Result = true;
1441bdd1243dSDimitry Andric 
1442bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::VMEM_READ) != SchedGroupMask::NONE) &&
1443bdd1243dSDimitry Andric            MI.mayLoad() &&
1444bdd1243dSDimitry Andric            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1445bdd1243dSDimitry Andric     Result = true;
1446bdd1243dSDimitry Andric 
1447bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::VMEM_WRITE) != SchedGroupMask::NONE) &&
1448bdd1243dSDimitry Andric            MI.mayStore() &&
1449bdd1243dSDimitry Andric            (TII->isVMEM(MI) || (TII->isFLAT(MI) && !TII->isDS(MI))))
1450bdd1243dSDimitry Andric     Result = true;
1451bdd1243dSDimitry Andric 
1452bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::DS) != SchedGroupMask::NONE) &&
1453bdd1243dSDimitry Andric            TII->isDS(MI))
1454bdd1243dSDimitry Andric     Result = true;
1455bdd1243dSDimitry Andric 
1456bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::DS_READ) != SchedGroupMask::NONE) &&
1457bdd1243dSDimitry Andric            MI.mayLoad() && TII->isDS(MI))
1458bdd1243dSDimitry Andric     Result = true;
1459bdd1243dSDimitry Andric 
1460bdd1243dSDimitry Andric   else if (((SGMask & SchedGroupMask::DS_WRITE) != SchedGroupMask::NONE) &&
1461bdd1243dSDimitry Andric            MI.mayStore() && TII->isDS(MI))
1462bdd1243dSDimitry Andric     Result = true;
1463bdd1243dSDimitry Andric 
1464bdd1243dSDimitry Andric   LLVM_DEBUG(
1465bdd1243dSDimitry Andric       dbgs() << "For SchedGroup with mask " << format_hex((int)SGMask, 10, true)
1466bdd1243dSDimitry Andric              << (Result ? " could classify " : " unable to classify ") << MI);
1467bdd1243dSDimitry Andric 
1468bdd1243dSDimitry Andric   return Result;
1469bdd1243dSDimitry Andric }
1470bdd1243dSDimitry Andric 
1471bdd1243dSDimitry Andric int SchedGroup::link(SUnit &SU, bool MakePred,
1472bdd1243dSDimitry Andric                      std::vector<std::pair<SUnit *, SUnit *>> &AddedEdges) {
1473bdd1243dSDimitry Andric   int MissedEdges = 0;
1474bdd1243dSDimitry Andric   for (auto *A : Collection) {
1475bdd1243dSDimitry Andric     SUnit *B = &SU;
1476bdd1243dSDimitry Andric     if (A == B || A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
1477bdd1243dSDimitry Andric       continue;
1478bdd1243dSDimitry Andric     if (MakePred)
1479bdd1243dSDimitry Andric       std::swap(A, B);
1480bdd1243dSDimitry Andric 
1481bdd1243dSDimitry Andric     if (DAG->IsReachable(B, A))
1482bdd1243dSDimitry Andric       continue;
1483*06c3fb27SDimitry Andric 
1484bdd1243dSDimitry Andric     // tryAddEdge returns false if there is a dependency that makes adding
1485bdd1243dSDimitry Andric     // the A->B edge impossible, otherwise it returns true;
1486bdd1243dSDimitry Andric     bool Added = tryAddEdge(A, B);
1487bdd1243dSDimitry Andric     if (Added)
1488bdd1243dSDimitry Andric       AddedEdges.push_back(std::pair(A, B));
1489bdd1243dSDimitry Andric     else
1490bdd1243dSDimitry Andric       ++MissedEdges;
1491bdd1243dSDimitry Andric   }
1492bdd1243dSDimitry Andric 
1493bdd1243dSDimitry Andric   return MissedEdges;
1494bdd1243dSDimitry Andric }
1495bdd1243dSDimitry Andric 
1496bdd1243dSDimitry Andric void SchedGroup::link(SUnit &SU, bool MakePred) {
1497bdd1243dSDimitry Andric   for (auto *A : Collection) {
1498bdd1243dSDimitry Andric     SUnit *B = &SU;
1499bdd1243dSDimitry Andric     if (A->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER)
1500bdd1243dSDimitry Andric       continue;
1501bdd1243dSDimitry Andric     if (MakePred)
1502bdd1243dSDimitry Andric       std::swap(A, B);
1503bdd1243dSDimitry Andric 
1504bdd1243dSDimitry Andric     tryAddEdge(A, B);
1505bdd1243dSDimitry Andric   }
1506bdd1243dSDimitry Andric }
1507bdd1243dSDimitry Andric 
1508bdd1243dSDimitry Andric void SchedGroup::link(SUnit &SU,
1509bdd1243dSDimitry Andric                       function_ref<bool(const SUnit *A, const SUnit *B)> P) {
1510bdd1243dSDimitry Andric   for (auto *A : Collection) {
1511bdd1243dSDimitry Andric     SUnit *B = &SU;
1512bdd1243dSDimitry Andric     if (P(A, B))
1513bdd1243dSDimitry Andric       std::swap(A, B);
1514bdd1243dSDimitry Andric 
1515bdd1243dSDimitry Andric     tryAddEdge(A, B);
1516bdd1243dSDimitry Andric   }
1517bdd1243dSDimitry Andric }
1518bdd1243dSDimitry Andric 
1519bdd1243dSDimitry Andric void SchedGroup::link(SchedGroup &OtherGroup) {
1520bdd1243dSDimitry Andric   for (auto *B : OtherGroup.Collection)
1521bdd1243dSDimitry Andric     link(*B);
1522bdd1243dSDimitry Andric }
1523bdd1243dSDimitry Andric 
1524bdd1243dSDimitry Andric bool SchedGroup::canAddSU(SUnit &SU) const {
1525bdd1243dSDimitry Andric   MachineInstr &MI = *SU.getInstr();
1526bdd1243dSDimitry Andric   if (MI.getOpcode() != TargetOpcode::BUNDLE)
1527bdd1243dSDimitry Andric     return canAddMI(MI);
1528bdd1243dSDimitry Andric 
1529bdd1243dSDimitry Andric   // Special case for bundled MIs.
1530bdd1243dSDimitry Andric   const MachineBasicBlock *MBB = MI.getParent();
1531bdd1243dSDimitry Andric   MachineBasicBlock::instr_iterator B = MI.getIterator(), E = ++B;
1532bdd1243dSDimitry Andric   while (E != MBB->end() && E->isBundledWithPred())
1533bdd1243dSDimitry Andric     ++E;
1534bdd1243dSDimitry Andric 
1535bdd1243dSDimitry Andric   // Return true if all of the bundled MIs can be added to this group.
1536bdd1243dSDimitry Andric   return std::all_of(B, E, [this](MachineInstr &MI) { return canAddMI(MI); });
1537bdd1243dSDimitry Andric }
1538bdd1243dSDimitry Andric 
1539bdd1243dSDimitry Andric void SchedGroup::initSchedGroup() {
1540bdd1243dSDimitry Andric   for (auto &SU : DAG->SUnits) {
1541bdd1243dSDimitry Andric     if (isFull())
1542bdd1243dSDimitry Andric       break;
1543bdd1243dSDimitry Andric 
1544bdd1243dSDimitry Andric     if (canAddSU(SU))
1545bdd1243dSDimitry Andric       add(SU);
1546bdd1243dSDimitry Andric   }
1547bdd1243dSDimitry Andric }
1548bdd1243dSDimitry Andric 
1549bdd1243dSDimitry Andric void SchedGroup::initSchedGroup(std::vector<SUnit>::reverse_iterator RIter,
1550bdd1243dSDimitry Andric                                 SUnitsToCandidateSGsMap &SyncedInstrs) {
1551bdd1243dSDimitry Andric   SUnit &InitSU = *RIter;
1552bdd1243dSDimitry Andric   for (auto E = DAG->SUnits.rend(); RIter != E; ++RIter) {
1553bdd1243dSDimitry Andric     auto &SU = *RIter;
1554bdd1243dSDimitry Andric     if (isFull())
1555bdd1243dSDimitry Andric       break;
1556bdd1243dSDimitry Andric 
1557bdd1243dSDimitry Andric     if (canAddSU(SU))
1558bdd1243dSDimitry Andric       SyncedInstrs[&SU].push_back(SGID);
1559bdd1243dSDimitry Andric   }
1560bdd1243dSDimitry Andric 
1561bdd1243dSDimitry Andric   add(InitSU);
1562bdd1243dSDimitry Andric   assert(MaxSize);
1563bdd1243dSDimitry Andric   (*MaxSize)++;
1564bdd1243dSDimitry Andric }
1565bdd1243dSDimitry Andric 
1566bdd1243dSDimitry Andric void SchedGroup::initSchedGroup(SUnitsToCandidateSGsMap &SyncedInstrs) {
1567bdd1243dSDimitry Andric   auto I = DAG->SUnits.rbegin();
1568bdd1243dSDimitry Andric   auto E = DAG->SUnits.rend();
1569bdd1243dSDimitry Andric   for (; I != E; ++I) {
1570bdd1243dSDimitry Andric     auto &SU = *I;
1571bdd1243dSDimitry Andric     if (isFull())
1572bdd1243dSDimitry Andric       break;
1573bdd1243dSDimitry Andric 
1574bdd1243dSDimitry Andric     if (canAddSU(SU))
1575bdd1243dSDimitry Andric       SyncedInstrs[&SU].push_back(SGID);
1576bdd1243dSDimitry Andric   }
1577bdd1243dSDimitry Andric }
1578bdd1243dSDimitry Andric 
1579bdd1243dSDimitry Andric void IGroupLPDAGMutation::apply(ScheduleDAGInstrs *DAGInstrs) {
158081ad6265SDimitry Andric   const TargetSchedModel *TSchedModel = DAGInstrs->getSchedModel();
158181ad6265SDimitry Andric   if (!TSchedModel || DAGInstrs->SUnits.empty())
158281ad6265SDimitry Andric     return;
158381ad6265SDimitry Andric 
1584bdd1243dSDimitry Andric   LLVM_DEBUG(dbgs() << "Applying IGroupLPDAGMutation...\n");
158581ad6265SDimitry Andric   const GCNSubtarget &ST = DAGInstrs->MF.getSubtarget<GCNSubtarget>();
158681ad6265SDimitry Andric   TII = ST.getInstrInfo();
158781ad6265SDimitry Andric   DAG = static_cast<ScheduleDAGMI *>(DAGInstrs);
1588bdd1243dSDimitry Andric   SyncedSchedGroups.clear();
1589bdd1243dSDimitry Andric   SyncedInstrs.clear();
1590bdd1243dSDimitry Andric   bool foundSB = false;
1591bdd1243dSDimitry Andric   bool foundIGLP = false;
1592bdd1243dSDimitry Andric   for (auto R = DAG->SUnits.rbegin(), E = DAG->SUnits.rend(); R != E; ++R) {
1593bdd1243dSDimitry Andric     unsigned Opc = R->getInstr()->getOpcode();
1594bdd1243dSDimitry Andric     // SCHED_[GROUP_]BARRIER and IGLP are mutually exclusive.
1595bdd1243dSDimitry Andric     if (Opc == AMDGPU::SCHED_BARRIER) {
1596bdd1243dSDimitry Andric       addSchedBarrierEdges(*R);
1597bdd1243dSDimitry Andric       foundSB = true;
1598bdd1243dSDimitry Andric     } else if (Opc == AMDGPU::SCHED_GROUP_BARRIER) {
1599bdd1243dSDimitry Andric       initSchedGroupBarrierPipelineStage(R);
1600bdd1243dSDimitry Andric       foundSB = true;
1601bdd1243dSDimitry Andric     } else if (Opc == AMDGPU::IGLP_OPT) {
1602bdd1243dSDimitry Andric       resetEdges(*R, DAG);
1603bdd1243dSDimitry Andric       if (!foundSB && !foundIGLP)
1604bdd1243dSDimitry Andric         initIGLPOpt(*R);
1605bdd1243dSDimitry Andric       foundIGLP = true;
1606bdd1243dSDimitry Andric     }
160781ad6265SDimitry Andric   }
160881ad6265SDimitry Andric 
1609bdd1243dSDimitry Andric   if (foundSB || foundIGLP) {
1610*06c3fb27SDimitry Andric     PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, IsBottomUp);
1611bdd1243dSDimitry Andric     // PipelineSolver performs the mutation by adding the edges it
1612bdd1243dSDimitry Andric     // determined as the best
1613bdd1243dSDimitry Andric     PS.solve();
1614bdd1243dSDimitry Andric   }
1615bdd1243dSDimitry Andric }
1616bdd1243dSDimitry Andric 
1617bdd1243dSDimitry Andric void IGroupLPDAGMutation::addSchedBarrierEdges(SUnit &SchedBarrier) {
161881ad6265SDimitry Andric   MachineInstr &MI = *SchedBarrier.getInstr();
161981ad6265SDimitry Andric   assert(MI.getOpcode() == AMDGPU::SCHED_BARRIER);
162081ad6265SDimitry Andric   // Remove all existing edges from the SCHED_BARRIER that were added due to the
162181ad6265SDimitry Andric   // instruction having side effects.
1622bdd1243dSDimitry Andric   resetEdges(SchedBarrier, DAG);
1623bdd1243dSDimitry Andric   auto InvertedMask =
1624bdd1243dSDimitry Andric       invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm());
1625bdd1243dSDimitry Andric   SchedGroup SG(InvertedMask, std::nullopt, DAG, TII);
1626bdd1243dSDimitry Andric   SG.initSchedGroup();
1627bdd1243dSDimitry Andric   // Preserve original instruction ordering relative to the SCHED_BARRIER.
1628bdd1243dSDimitry Andric   SG.link(
1629bdd1243dSDimitry Andric       SchedBarrier,
1630bdd1243dSDimitry Andric       (function_ref<bool(const SUnit *A, const SUnit *B)>)[](
1631bdd1243dSDimitry Andric           const SUnit *A, const SUnit *B) { return A->NodeNum > B->NodeNum; });
163281ad6265SDimitry Andric }
163381ad6265SDimitry Andric 
1634bdd1243dSDimitry Andric SchedGroupMask
1635bdd1243dSDimitry Andric IGroupLPDAGMutation::invertSchedBarrierMask(SchedGroupMask Mask) const {
1636bdd1243dSDimitry Andric   // Invert mask and erase bits for types of instructions that are implied to be
1637bdd1243dSDimitry Andric   // allowed past the SCHED_BARRIER.
1638bdd1243dSDimitry Andric   SchedGroupMask InvertedMask = ~Mask;
1639bdd1243dSDimitry Andric 
1640bdd1243dSDimitry Andric   // ALU implies VALU, SALU, MFMA.
1641bdd1243dSDimitry Andric   if ((InvertedMask & SchedGroupMask::ALU) == SchedGroupMask::NONE)
1642bdd1243dSDimitry Andric     InvertedMask &=
1643bdd1243dSDimitry Andric         ~SchedGroupMask::VALU & ~SchedGroupMask::SALU & ~SchedGroupMask::MFMA;
1644bdd1243dSDimitry Andric   // VALU, SALU, MFMA implies ALU.
1645bdd1243dSDimitry Andric   else if ((InvertedMask & SchedGroupMask::VALU) == SchedGroupMask::NONE ||
1646bdd1243dSDimitry Andric            (InvertedMask & SchedGroupMask::SALU) == SchedGroupMask::NONE ||
1647bdd1243dSDimitry Andric            (InvertedMask & SchedGroupMask::MFMA) == SchedGroupMask::NONE)
1648bdd1243dSDimitry Andric     InvertedMask &= ~SchedGroupMask::ALU;
1649bdd1243dSDimitry Andric 
1650bdd1243dSDimitry Andric   // VMEM implies VMEM_READ, VMEM_WRITE.
1651bdd1243dSDimitry Andric   if ((InvertedMask & SchedGroupMask::VMEM) == SchedGroupMask::NONE)
1652bdd1243dSDimitry Andric     InvertedMask &= ~SchedGroupMask::VMEM_READ & ~SchedGroupMask::VMEM_WRITE;
1653bdd1243dSDimitry Andric   // VMEM_READ, VMEM_WRITE implies VMEM.
1654bdd1243dSDimitry Andric   else if ((InvertedMask & SchedGroupMask::VMEM_READ) == SchedGroupMask::NONE ||
1655bdd1243dSDimitry Andric            (InvertedMask & SchedGroupMask::VMEM_WRITE) == SchedGroupMask::NONE)
1656bdd1243dSDimitry Andric     InvertedMask &= ~SchedGroupMask::VMEM;
1657bdd1243dSDimitry Andric 
1658bdd1243dSDimitry Andric   // DS implies DS_READ, DS_WRITE.
1659bdd1243dSDimitry Andric   if ((InvertedMask & SchedGroupMask::DS) == SchedGroupMask::NONE)
1660bdd1243dSDimitry Andric     InvertedMask &= ~SchedGroupMask::DS_READ & ~SchedGroupMask::DS_WRITE;
1661bdd1243dSDimitry Andric   // DS_READ, DS_WRITE implies DS.
1662bdd1243dSDimitry Andric   else if ((InvertedMask & SchedGroupMask::DS_READ) == SchedGroupMask::NONE ||
1663bdd1243dSDimitry Andric            (InvertedMask & SchedGroupMask::DS_WRITE) == SchedGroupMask::NONE)
1664bdd1243dSDimitry Andric     InvertedMask &= ~SchedGroupMask::DS;
1665bdd1243dSDimitry Andric 
1666bdd1243dSDimitry Andric   return InvertedMask;
166781ad6265SDimitry Andric }
166881ad6265SDimitry Andric 
1669bdd1243dSDimitry Andric void IGroupLPDAGMutation::initSchedGroupBarrierPipelineStage(
1670bdd1243dSDimitry Andric     std::vector<SUnit>::reverse_iterator RIter) {
1671bdd1243dSDimitry Andric   // Remove all existing edges from the SCHED_GROUP_BARRIER that were added due
1672bdd1243dSDimitry Andric   // to the instruction having side effects.
1673bdd1243dSDimitry Andric   resetEdges(*RIter, DAG);
1674bdd1243dSDimitry Andric   MachineInstr &SGB = *RIter->getInstr();
1675bdd1243dSDimitry Andric   assert(SGB.getOpcode() == AMDGPU::SCHED_GROUP_BARRIER);
1676bdd1243dSDimitry Andric   int32_t SGMask = SGB.getOperand(0).getImm();
1677bdd1243dSDimitry Andric   int32_t Size = SGB.getOperand(1).getImm();
1678bdd1243dSDimitry Andric   int32_t SyncID = SGB.getOperand(2).getImm();
1679bdd1243dSDimitry Andric 
1680bdd1243dSDimitry Andric   auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask,
1681bdd1243dSDimitry Andric                                                     Size, SyncID, DAG, TII);
1682bdd1243dSDimitry Andric 
1683bdd1243dSDimitry Andric   SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]);
168481ad6265SDimitry Andric }
168581ad6265SDimitry Andric 
1686bdd1243dSDimitry Andric void IGroupLPDAGMutation::initIGLPOpt(SUnit &SU) {
1687bdd1243dSDimitry Andric   IGLPStrategyID StrategyID =
1688bdd1243dSDimitry Andric       (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm();
1689bdd1243dSDimitry Andric   auto S = createIGLPStrategy(StrategyID, DAG, TII);
1690*06c3fb27SDimitry Andric   if (S->shouldApplyStrategy(DAG)) {
1691*06c3fb27SDimitry Andric     IsBottomUp = S->IsBottomUp;
1692bdd1243dSDimitry Andric     S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups);
169381ad6265SDimitry Andric   }
1694*06c3fb27SDimitry Andric }
169581ad6265SDimitry Andric 
169681ad6265SDimitry Andric } // namespace
169781ad6265SDimitry Andric 
169881ad6265SDimitry Andric namespace llvm {
169981ad6265SDimitry Andric 
170081ad6265SDimitry Andric std::unique_ptr<ScheduleDAGMutation> createIGroupLPDAGMutation() {
1701bdd1243dSDimitry Andric   return std::make_unique<IGroupLPDAGMutation>();
170281ad6265SDimitry Andric }
170381ad6265SDimitry Andric 
170481ad6265SDimitry Andric } // end namespace llvm
1705