//===- TopologicalSortUtils.cpp - Topological sort utilities -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

/// Return `true` if the given operation is ready to be scheduled.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, op))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    // - or it is not defined by an unscheduled op (and also not nested within
    //   an unscheduled op).
    do {
      // Stop the traversal when the op under examination is reached.
      if (parent == op)
        return true;
      if (unscheduledOps.contains(parent))
        return false;
    } while ((parent = parent->getParentOp()));
    // No unscheduled op found.
    return true;
  };

  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
    return llvm::all_of(nestedOp->getOperands(),
                        [&](Value operand) { return isReady(operand); })
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !readyToSchedule.wasInterrupted();
}

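// The `isOperandReady` callback taken by the entry points below lets callers
// relax this readiness check for selected operands. A minimal sketch of such a
// callback (the lambda and the `limit` region are hypothetical, not part of
// this file): treat every value defined outside `limit` as ready, so only
// producer/consumer pairs inside `limit` constrain the order.
//
//   auto isOperandReady = [&](Value value, Operation *op) {
//     Operation *def = value.getDefiningOp();
//     return !def || def->getParentRegion() != limit;
//   };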
bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet, find the ones that are
    // "ready", i.e. those with no operand produced by an op in the set, and
    // "schedule" them (move them before `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled in this round, the remaining ops cannot
    // be fully ordered (e.g. due to a cycle); give up on the op at
    // `nextScheduledOp` and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}

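// Example use of the block-based overload below (an illustrative sketch;
// `block` stands for a hypothetical caller-provided block whose ops may be in
// def-after-use order):
//
//   bool allScheduled = sortTopologically(block);
//   // The block's ops are reordered in place; the result is false if some ops
//   // (e.g. those in a use-def cycle) could not be placed after all of their
//   // producers.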
bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}

bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;

  // Mark all operations as unscheduled.
  for (Operation *op : ops)
    unscheduledOps.insert(op);

  unsigned nextScheduledOp = 0;

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet, find the ones that are
    // "ready", i.e. those with no operand produced by an op in the set, and
    // "schedule" them (swap them with the op at `nextScheduledOp`).
    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(ops[i]);
      std::swap(ops[i], ops[nextScheduledOp]);
      scheduledAtLeastOnce = true;
      ++nextScheduledOp;
    }

    // If no operations were scheduled, just schedule the first unscheduled op
    // and continue.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(ops[nextScheduledOp++]);
    }
  }

  return allOpsScheduled;
}

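// Example use of computeTopologicalSorting above (an illustrative sketch;
// `candidates` and `collectCandidateOps` are hypothetical caller-side names):
//
//   SmallVector<Operation *> candidates = collectCandidateOps(block);
//   bool acyclic = computeTopologicalSorting(candidates);
//   // `candidates` is now reordered in place so that producers appear before
//   // their consumers wherever possible; `acyclic` is false otherwise.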
SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
  // For each block that has not been visited yet (i.e. that has no
  // predecessors), add it to the list together with all blocks reachable from
  // it, in reverse post-order.
  SetVector<Block *> blocks;
  for (Block &b : region) {
    if (blocks.count(&b) == 0) {
      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
      blocks.insert(traversal.begin(), traversal.end());
    }
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");

  return blocks;
}

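// Example use of getBlocksSortedByDominance above (an illustrative sketch;
// `region` and `visit` are hypothetical caller-side names):
//
//   for (Block *block : getBlocksSortedByDominance(region))
//     visit(block); // a block is visited only after the blocks dominating it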
namespace {
class TopoSortHelper {
public:
  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
      : toSort(toSort) {}

  /// Executes the topological sort of the operations this instance was
  /// constructed with. This function will destroy the internal state of the
  /// instance.
  SetVector<Operation *> sort() {
    if (toSort.size() <= 1) {
      // Note: Creates a copy on purpose.
      return toSort;
    }

    // First, find the root region to start the traversal through the IR. This
    // additionally enriches the internal caches with all relevant ancestor
    // regions and blocks.
    Region *rootRegion = findCommonAncestorRegion();
    assert(rootRegion && "expected all ops to have a common ancestor");

    // Sort all elements in `toSort` by traversing the IR in the appropriate
    // order.
    SetVector<Operation *> result = topoSortRegion(*rootRegion);
    assert(result.size() == toSort.size() &&
           "expected all operations to be present in the result");
    return result;
  }

private:
  /// Computes the closest common ancestor region of all operations in
  /// `toSort`.
  Region *findCommonAncestorRegion() {
    // Map to count the number of times a region was encountered.
    DenseMap<Region *, size_t> regionCounts;
    size_t expectedCount = toSort.size();

    // Walk the region tree for each operation towards the root and add to the
    // region count.
    Region *res = nullptr;
    for (Operation *op : toSort) {
      Region *current = op->getParentRegion();
      // Store the block as an ancestor block.
      ancestorBlocks.insert(op->getBlock());
      while (current) {
        // Insert or update the count and compare it.
        if (++regionCounts[current] == expectedCount) {
          res = current;
          break;
        }
        ancestorBlocks.insert(current->getParentOp()->getBlock());
        current = current->getParentRegion();
      }
    }
    auto firstRange = llvm::make_first_range(regionCounts);
    ancestorRegions.insert(firstRange.begin(), firstRange.end());
    return res;
  }
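  // An illustrative note on the counting above: each op in `toSort` increments
  // the count of every region on its ancestor chain, so the first region whose
  // count reaches `toSort.size()` lies on every chain and is therefore the
  // closest common ancestor. For two ops nested as, e.g.,
  //   func.func { scf.for { opA }  scf.if { opB } }
  // that region is the body region of `func.func` (a hypothetical example, not
  // taken from the original source).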

  /// Performs the dominance-respecting IR walk to collect the topological
  /// order of the operations to sort.
  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
    using StackT = PointerUnion<Region *, Block *, Operation *>;

    SetVector<Operation *> result;
    // Stack that stores the different IR constructs to traverse.
    SmallVector<StackT> stack;
    stack.push_back(&rootRegion);

    // Traverse the IR in a dominance-respecting pre-order walk.
    while (!stack.empty()) {
      StackT current = stack.pop_back_val();
      if (auto *region = dyn_cast<Region *>(current)) {
        // A region's blocks need to be traversed in dominance order.
        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
        for (Block *block : llvm::reverse(sortedBlocks)) {
          // Only add blocks to the stack that are ancestors of the operations
          // to sort.
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
        continue;
      }

      if (auto *block = dyn_cast<Block *>(current)) {
        // Add all of the block's operations to the stack.
        for (Operation &op : llvm::reverse(*block))
          stack.push_back(&op);
        continue;
      }

      auto *op = cast<Operation *>(current);
      if (toSort.contains(op))
        result.insert(op);

      // Add all the subregions that are ancestors of the operations to sort.
      for (Region &subRegion : op->getRegions())
        if (ancestorRegions.contains(&subRegion))
          stack.push_back(&subRegion);
    }
    return result;
  }

  /// Operations to sort.
  const SetVector<Operation *> &toSort;
  /// Set containing all the ancestor regions of the operations to sort.
  DenseSet<Region *> ancestorRegions;
  /// Set containing all the ancestor blocks of the operations to sort.
  DenseSet<Block *> ancestorBlocks;
};
} // namespace

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  return TopoSortHelper(toSort).sort();
}
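
// Example use of mlir::topologicalSort above (an illustrative sketch;
// `opsToOutline`, `funcOp`, and `SomeDialectOp` are hypothetical caller-side
// names):
//
//   SetVector<Operation *> opsToOutline;
//   funcOp->walk([&](SomeDialectOp op) { opsToOutline.insert(op); });
//   SetVector<Operation *> ordered = topologicalSort(opsToOutline);
//   // `ordered` holds the same ops, rearranged to follow a
//   // dominance-respecting pre-order walk of the enclosing IR.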