//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

/// Return `true` if the given operation is ready to be scheduled.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  // An operation is ready to be scheduled if all its operands are ready. An
  // operation is ready if:
  const auto isReady = [&](Value value) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, op))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    // - or it is not defined by an unscheduled op (and also not nested within
    //   an unscheduled op).
    do {
      // Stop traversal when the op under examination is reached.
      if (parent == op)
        return true;
      if (unscheduledOps.contains(parent))
        return false;
    } while ((parent = parent->getParentOp()));
    // No unscheduled op found.
    return true;
  };

  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
    return llvm::all_of(nestedOp->getOperands(),
                        [&](Value operand) { return isReady(operand); })
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !readyToSchedule.wasInterrupted();
}

bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet; try to find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op still
    // in the set, and "schedule" them (move them before `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it before `nextScheduledOp`.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}
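
// Example usage (a minimal sketch; `containerOp` is a hypothetical op and not
// part of this file): reorder the ops of every block nested under it so that
// producers come before their users. The terminator, if present, is left in
// place by the overload above; a `false` return value signals that a use-def
// cycle prevented a complete sort.
//
//   containerOp->walk([](Block *block) {
//     if (!sortTopologically(block)) {
//       // The block contains a cycle; it was sorted "best effort".
//     }
//   });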

bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;

  // Mark all operations as unscheduled.
  for (Operation *op : ops)
    unscheduledOps.insert(op);

  unsigned nextScheduledOp = 0;

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet; try to find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op still
    // in the set, and "schedule" them (swap them with the op at
    // `nextScheduledOp`).
    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by swapping it into position `nextScheduledOp`.
      unscheduledOps.erase(ops[i]);
      std::swap(ops[i], ops[nextScheduledOp]);
      scheduledAtLeastOnce = true;
      ++nextScheduledOp;
    }

    // If no operations were scheduled, give up on the op at `nextScheduledOp`
    // and treat it as scheduled anyway.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(ops[nextScheduledOp++]);
    }
  }

  return allOpsScheduled;
}
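
// Example (a sketch; `worklist` and `collectOpsSomehow` stand for operations
// gathered by a caller and are not part of this file): the ops are reordered
// in place without touching the IR, and the return value reports whether a
// full topological order was found.
//
//   SmallVector<Operation *> worklist = collectOpsSomehow();
//   if (computeTopologicalSorting(worklist)) {
//     // Each worklist[i] only uses values defined by earlier worklist entries
//     // or by ops outside the worklist.
//   }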

SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
  // For each block that has not been visited yet (e.g. because it has no
  // predecessors), add it to the list along with all blocks reachable from it.
  SetVector<Block *> blocks;
  for (Block &b : region) {
    if (blocks.count(&b) == 0) {
      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
      blocks.insert(traversal.begin(), traversal.end());
    }
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");

  return blocks;
}
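
// Example (sketch): visit the blocks of a function-like op's body so that a
// block's dominators are visited before the block itself. `funcOp` and
// `visit` are hypothetical placeholders supplied by the caller.
//
//   for (Block *block : getBlocksSortedByDominance(funcOp.getBody()))
//     visit(block);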

namespace {
class TopoSortHelper {
public:
  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
      : toSort(toSort) {}

  /// Executes the topological sort of the operations this instance was
  /// constructed with. This function will destroy the internal state of the
  /// instance.
  SetVector<Operation *> sort() {
    if (toSort.size() <= 1) {
      // Note: Creates a copy on purpose.
      return toSort;
    }

    // First, find the root region to start the traversal through the IR. This
    // additionally enriches the internal caches with all relevant ancestor
    // regions and blocks.
    Region *rootRegion = findCommonAncestorRegion();
    assert(rootRegion && "expected all ops to have a common ancestor");

    // Sort all elements in `toSort` by traversing the IR in the appropriate
    // order.
    SetVector<Operation *> result = topoSortRegion(*rootRegion);
    assert(result.size() == toSort.size() &&
           "expected all operations to be present in the result");
    return result;
  }

private:
  /// Computes the closest common ancestor region of all operations in
  /// `toSort`.
  Region *findCommonAncestorRegion() {
    // Map to count the number of times a region was encountered.
    DenseMap<Region *, size_t> regionCounts;
    size_t expectedCount = toSort.size();

    // Walk the region tree for each operation towards the root and add to the
    // region count.
    Region *res = nullptr;
    for (Operation *op : toSort) {
      Region *current = op->getParentRegion();
      // Store the block as an ancestor block.
      ancestorBlocks.insert(op->getBlock());
      while (current) {
        // Insert or update the count and compare it.
        if (++regionCounts[current] == expectedCount) {
          res = current;
          break;
        }
        ancestorBlocks.insert(current->getParentOp()->getBlock());
        current = current->getParentRegion();
      }
    }
    auto firstRange = llvm::make_first_range(regionCounts);
    ancestorRegions.insert(firstRange.begin(), firstRange.end());
    return res;
  }

  /// Performs the dominance-respecting IR walk to collect the topological
  /// order of the operations to sort.
  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
    using StackT = PointerUnion<Region *, Block *, Operation *>;

    SetVector<Operation *> result;
    // Stack that stores the different IR constructs to traverse.
    SmallVector<StackT> stack;
    stack.push_back(&rootRegion);

    // Traverse the IR in a dominance respecting pre-order walk.
    while (!stack.empty()) {
      StackT current = stack.pop_back_val();
      if (auto *region = dyn_cast<Region *>(current)) {
        // A region's blocks need to be traversed in dominance order.
        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
        for (Block *block : llvm::reverse(sortedBlocks)) {
          // Only add blocks to the stack that are ancestors of the operations
          // to sort.
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
        continue;
      }

      if (auto *block = dyn_cast<Block *>(current)) {
        // Add all of the block's operations to the stack.
        for (Operation &op : llvm::reverse(*block))
          stack.push_back(&op);
        continue;
      }

      auto *op = cast<Operation *>(current);
      if (toSort.contains(op))
        result.insert(op);

      // Add all the subregions that are ancestors of the operations to sort.
      for (Region &subRegion : op->getRegions())
        if (ancestorRegions.contains(&subRegion))
          stack.push_back(&subRegion);
    }
    return result;
  }

  /// Operations to sort.
  const SetVector<Operation *> &toSort;
  /// Set containing all the ancestor regions of the operations to sort.
  DenseSet<Region *> ancestorRegions;
  /// Set containing all the ancestor blocks of the operations to sort.
  DenseSet<Block *> ancestorBlocks;
};
} // namespace

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  return TopoSortHelper(toSort).sort();
}
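
// Example (a minimal sketch; `sliceOps` stands for operations collected by
// some caller-side analysis and `process` is a hypothetical per-op action):
// the result contains the same ops, reordered so that producers appear before
// their users, even when the ops live in different blocks or regions.
//
//   SetVector<Operation *> sliceOps = ...;
//   SetVector<Operation *> ordered = topologicalSort(sliceOps);
//   for (Operation *op : ordered)
//     process(op);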