1 //===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/TopologicalSortUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/OpDefinition.h"
12 #include "mlir/IR/RegionGraphTraits.h"
13
14 #include "llvm/ADT/PostOrderIterator.h"
15 #include "llvm/ADT/SetVector.h"
16
17 using namespace mlir;
18
19 /// Return `true` if the given operation is ready to be scheduled.
isOpReady(Operation * op,DenseSet<Operation * > & unscheduledOps,function_ref<bool (Value,Operation *)> isOperandReady)20 static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
21 function_ref<bool(Value, Operation *)> isOperandReady) {
22 // An operation is ready to be scheduled if all its operands are ready. An
23 // operation is ready if:
24 const auto isReady = [&](Value value) {
25 // - the user-provided callback marks it as ready,
26 if (isOperandReady && isOperandReady(value, op))
27 return true;
28 Operation *parent = value.getDefiningOp();
29 // - it is a block argument,
30 if (!parent)
31 return true;
32 // - or it is not defined by an unscheduled op (and also not nested within
33 // an unscheduled op).
34 do {
35 // Stop traversal when op under examination is reached.
36 if (parent == op)
37 return true;
38 if (unscheduledOps.contains(parent))
39 return false;
40 } while ((parent = parent->getParentOp()));
41 // No unscheduled op found.
42 return true;
43 };
44
45 // An operation is recursively ready to be scheduled of it and its nested
46 // operations are ready.
47 WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
48 return llvm::all_of(nestedOp->getOperands(),
49 [&](Value operand) { return isReady(operand); })
50 ? WalkResult::advance()
51 : WalkResult::interrupt();
52 });
53 return !readyToSchedule.wasInterrupted();
54 }
55
sortTopologically(Block * block,llvm::iterator_range<Block::iterator> ops,function_ref<bool (Value,Operation *)> isOperandReady)56 bool mlir::sortTopologically(
57 Block *block, llvm::iterator_range<Block::iterator> ops,
58 function_ref<bool(Value, Operation *)> isOperandReady) {
59 if (ops.empty())
60 return true;
61
62 // The set of operations that have not yet been scheduled.
63 DenseSet<Operation *> unscheduledOps;
64 // Mark all operations as unscheduled.
65 for (Operation &op : ops)
66 unscheduledOps.insert(&op);
67
68 Block::iterator nextScheduledOp = ops.begin();
69 Block::iterator end = ops.end();
70
71 bool allOpsScheduled = true;
72 while (!unscheduledOps.empty()) {
73 bool scheduledAtLeastOnce = false;
74
75 // Loop over the ops that are not sorted yet, try to find the ones "ready",
76 // i.e. the ones for which there aren't any operand produced by an op in the
77 // set, and "schedule" it (move it before the `nextScheduledOp`).
78 for (Operation &op :
79 llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
80 if (!isOpReady(&op, unscheduledOps, isOperandReady))
81 continue;
82
83 // Schedule the operation by moving it to the start.
84 unscheduledOps.erase(&op);
85 op.moveBefore(block, nextScheduledOp);
86 scheduledAtLeastOnce = true;
87 // Move the iterator forward if we schedule the operation at the front.
88 if (&op == &*nextScheduledOp)
89 ++nextScheduledOp;
90 }
91 // If no operations were scheduled, give up and advance the iterator.
92 if (!scheduledAtLeastOnce) {
93 allOpsScheduled = false;
94 unscheduledOps.erase(&*nextScheduledOp);
95 ++nextScheduledOp;
96 }
97 }
98
99 return allOpsScheduled;
100 }
101
sortTopologically(Block * block,function_ref<bool (Value,Operation *)> isOperandReady)102 bool mlir::sortTopologically(
103 Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
104 if (block->empty())
105 return true;
106 if (block->back().hasTrait<OpTrait::IsTerminator>())
107 return sortTopologically(block, block->without_terminator(),
108 isOperandReady);
109 return sortTopologically(block, *block, isOperandReady);
110 }
111
computeTopologicalSorting(MutableArrayRef<Operation * > ops,function_ref<bool (Value,Operation *)> isOperandReady)112 bool mlir::computeTopologicalSorting(
113 MutableArrayRef<Operation *> ops,
114 function_ref<bool(Value, Operation *)> isOperandReady) {
115 if (ops.empty())
116 return true;
117
118 // The set of operations that have not yet been scheduled.
119 DenseSet<Operation *> unscheduledOps;
120
121 // Mark all operations as unscheduled.
122 for (Operation *op : ops)
123 unscheduledOps.insert(op);
124
125 unsigned nextScheduledOp = 0;
126
127 bool allOpsScheduled = true;
128 while (!unscheduledOps.empty()) {
129 bool scheduledAtLeastOnce = false;
130
131 // Loop over the ops that are not sorted yet, try to find the ones "ready",
132 // i.e. the ones for which there aren't any operand produced by an op in the
133 // set, and "schedule" it (swap it with the op at `nextScheduledOp`).
134 for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
135 if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
136 continue;
137
138 // Schedule the operation by moving it to the start.
139 unscheduledOps.erase(ops[i]);
140 std::swap(ops[i], ops[nextScheduledOp]);
141 scheduledAtLeastOnce = true;
142 ++nextScheduledOp;
143 }
144
145 // If no operations were scheduled, just schedule the first op and continue.
146 if (!scheduledAtLeastOnce) {
147 allOpsScheduled = false;
148 unscheduledOps.erase(ops[nextScheduledOp++]);
149 }
150 }
151
152 return allOpsScheduled;
153 }
154
getBlocksSortedByDominance(Region & region)155 SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) {
156 // For each block that has not been visited yet (i.e. that has no
157 // predecessors), add it to the list as well as its successors.
158 SetVector<Block *> blocks;
159 for (Block &b : region) {
160 if (blocks.count(&b) == 0) {
161 llvm::ReversePostOrderTraversal<Block *> traversal(&b);
162 blocks.insert(traversal.begin(), traversal.end());
163 }
164 }
165 assert(blocks.size() == region.getBlocks().size() &&
166 "some blocks are not sorted");
167
168 return blocks;
169 }
170
171 namespace {
172 class TopoSortHelper {
173 public:
TopoSortHelper(const SetVector<Operation * > & toSort)174 explicit TopoSortHelper(const SetVector<Operation *> &toSort)
175 : toSort(toSort) {}
176
177 /// Executes the topological sort of the operations this instance was
178 /// constructed with. This function will destroy the internal state of the
179 /// instance.
sort()180 SetVector<Operation *> sort() {
181 if (toSort.size() <= 1) {
182 // Note: Creates a copy on purpose.
183 return toSort;
184 }
185
186 // First, find the root region to start the traversal through the IR. This
187 // additionally enriches the internal caches with all relevant ancestor
188 // regions and blocks.
189 Region *rootRegion = findCommonAncestorRegion();
190 assert(rootRegion && "expected all ops to have a common ancestor");
191
192 // Sort all elements in `toSort` by traversing the IR in the appropriate
193 // order.
194 SetVector<Operation *> result = topoSortRegion(*rootRegion);
195 assert(result.size() == toSort.size() &&
196 "expected all operations to be present in the result");
197 return result;
198 }
199
200 private:
201 /// Computes the closest common ancestor region of all operations in `toSort`.
findCommonAncestorRegion()202 Region *findCommonAncestorRegion() {
203 // Map to count the number of times a region was encountered.
204 DenseMap<Region *, size_t> regionCounts;
205 size_t expectedCount = toSort.size();
206
207 // Walk the region tree for each operation towards the root and add to the
208 // region count.
209 Region *res = nullptr;
210 for (Operation *op : toSort) {
211 Region *current = op->getParentRegion();
212 // Store the block as an ancestor block.
213 ancestorBlocks.insert(op->getBlock());
214 while (current) {
215 // Insert or update the count and compare it.
216 if (++regionCounts[current] == expectedCount) {
217 res = current;
218 break;
219 }
220 ancestorBlocks.insert(current->getParentOp()->getBlock());
221 current = current->getParentRegion();
222 }
223 }
224 auto firstRange = llvm::make_first_range(regionCounts);
225 ancestorRegions.insert(firstRange.begin(), firstRange.end());
226 return res;
227 }
228
229 /// Performs the dominance respecting IR walk to collect the topological order
230 /// of the operation to sort.
topoSortRegion(Region & rootRegion)231 SetVector<Operation *> topoSortRegion(Region &rootRegion) {
232 using StackT = PointerUnion<Region *, Block *, Operation *>;
233
234 SetVector<Operation *> result;
235 // Stack that stores the different IR constructs to traverse.
236 SmallVector<StackT> stack;
237 stack.push_back(&rootRegion);
238
239 // Traverse the IR in a dominance respecting pre-order walk.
240 while (!stack.empty()) {
241 StackT current = stack.pop_back_val();
242 if (auto *region = dyn_cast<Region *>(current)) {
243 // A region's blocks need to be traversed in dominance order.
244 SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
245 for (Block *block : llvm::reverse(sortedBlocks)) {
246 // Only add blocks to the stack that are ancestors of the operations
247 // to sort.
248 if (ancestorBlocks.contains(block))
249 stack.push_back(block);
250 }
251 continue;
252 }
253
254 if (auto *block = dyn_cast<Block *>(current)) {
255 // Add all of the blocks operations to the stack.
256 for (Operation &op : llvm::reverse(*block))
257 stack.push_back(&op);
258 continue;
259 }
260
261 auto *op = cast<Operation *>(current);
262 if (toSort.contains(op))
263 result.insert(op);
264
265 // Add all the subregions that are ancestors of the operations to sort.
266 for (Region &subRegion : op->getRegions())
267 if (ancestorRegions.contains(&subRegion))
268 stack.push_back(&subRegion);
269 }
270 return result;
271 }
272
273 /// Operations to sort.
274 const SetVector<Operation *> &toSort;
275 /// Set containing all the ancestor regions of the operations to sort.
276 DenseSet<Region *> ancestorRegions;
277 /// Set containing all the ancestor blocks of the operations to sort.
278 DenseSet<Block *> ancestorBlocks;
279 };
280 } // namespace
281
282 SetVector<Operation *>
topologicalSort(const SetVector<Operation * > & toSort)283 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
284 return TopoSortHelper(toSort).sort();
285 }
286