//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

/// Return `true` if the given operation is ready to be scheduled.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  // An operation is ready to be scheduled if all its operands are ready. An
  // operand is ready if:
  const auto isReady = [&](Value value) {
    // - the user-provided callback marks it as ready,
    if (isOperandReady && isOperandReady(value, op))
      return true;
    Operation *parent = value.getDefiningOp();
    // - it is a block argument,
    if (!parent)
      return true;
    // - or it is not defined by an unscheduled op (and also not nested within
    //   an unscheduled op).
    do {
      // Stop traversal when op under examination is reached.
      if (parent == op)
        return true;
      if (unscheduledOps.contains(parent))
        return false;
    } while ((parent = parent->getParentOp()));
    // No unscheduled op found.
    return true;
  };

  // An operation is recursively ready to be scheduled if it and its nested
  // operations are ready.
  WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
    return llvm::all_of(nestedOp->getOperands(),
                        [&](Value operand) { return isReady(operand); })
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !readyToSchedule.wasInterrupted();
}
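
// Illustration of the readiness rules in `isOpReady` (a hypothetical IR
// snippet, not taken from this file): with both ops still unscheduled,
//
//   %0 = "test.producer"() : () -> i32
//   "test.consumer"(%0) : (i32) -> ()
//
// "test.consumer" is not ready because its operand %0 is produced by the
// still-unscheduled "test.producer"; "test.producer" itself has no operands
// and is ready immediately.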

bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;
  // Mark all operations as unscheduled.
  for (Operation &op : ops)
    unscheduledOps.insert(&op);

  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet. Try to find the ones that are
    // "ready", i.e. those with no operands produced by an op still in the set,
    // and "schedule" them (move them before `nextScheduledOp`).
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // Move the iterator forward if we schedule the operation at the front.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // If no operations were scheduled, give up and advance the iterator.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }

  return allOpsScheduled;
}
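
// Example usage (a minimal sketch; `block` and the enclosing pass are
// hypothetical and not part of this file): sort every op in a block except
// its terminator, relying on the default operand-readiness behavior.
//
//   Block *block = ...;
//   bool fullySorted =
//       sortTopologically(block, block->without_terminator());
//   if (!fullySorted) {
//     // A cycle (e.g. through a graph region) prevented a complete sort;
//     // the block still contains a valid permutation of the original ops.
//   }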

bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;
  if (block->back().hasTrait<OpTrait::IsTerminator>())
    return sortTopologically(block, block->without_terminator(),
                             isOperandReady);
  return sortTopologically(block, *block, isOperandReady);
}

bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // The set of operations that have not yet been scheduled.
  DenseSet<Operation *> unscheduledOps;

  // Mark all operations as unscheduled.
  for (Operation *op : ops)
    unscheduledOps.insert(op);

  unsigned nextScheduledOp = 0;

  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;

    // Loop over the ops that are not sorted yet. Try to find the ones that are
    // "ready", i.e. those with no operands produced by an op still in the set,
    // and "schedule" them (swap them with the op at `nextScheduledOp`).
    for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
      if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
        continue;

      // Schedule the operation by moving it to the start.
      unscheduledOps.erase(ops[i]);
      std::swap(ops[i], ops[nextScheduledOp]);
      scheduledAtLeastOnce = true;
      ++nextScheduledOp;
    }

    // If no operations were scheduled, just schedule the first op and continue.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(ops[nextScheduledOp++]);
    }
  }

  return allOpsScheduled;
}
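
// Example usage (a minimal sketch; `ops` is a hypothetical pre-collected
// list of operations, e.g. a backward slice): reorder the vector in place so
// that producers come before their users, without moving any IR.
//
//   SmallVector<Operation *> ops = ...;
//   bool fullySorted = computeTopologicalSorting(ops);
//   (void)fullySorted; // `false` if a cycle made a complete sort impossible.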

SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
  // For each block that has not been visited yet (i.e. that is not reachable
  // from a previously visited block), add it and all blocks reachable from it
  // in reverse post-order.
  SetVector<Block *> blocks;
  for (Block &b : region) {
    if (blocks.count(&b) == 0) {
      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
      blocks.insert(traversal.begin(), traversal.end());
    }
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");

  return blocks;
}
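
// Example usage (a minimal sketch; `region` is a hypothetical region, e.g.
// the body of a function-like op): visit blocks so that dominators are seen
// before the blocks they dominate.
//
//   for (Block *block : getBlocksSortedByDominance(region)) {
//     // Process `block`; any block dominating it was already processed.
//   }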

namespace {
class TopoSortHelper {
public:
  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
      : toSort(toSort) {}

  /// Executes the topological sort of the operations this instance was
  /// constructed with. This function will destroy the internal state of the
  /// instance.
  SetVector<Operation *> sort() {
    if (toSort.size() <= 1) {
      // Note: Creates a copy on purpose.
      return toSort;
    }

    // First, find the root region to start the traversal through the IR. This
    // additionally enriches the internal caches with all relevant ancestor
    // regions and blocks.
    Region *rootRegion = findCommonAncestorRegion();
    assert(rootRegion && "expected all ops to have a common ancestor");

    // Sort all elements in `toSort` by traversing the IR in the appropriate
    // order.
    SetVector<Operation *> result = topoSortRegion(*rootRegion);
    assert(result.size() == toSort.size() &&
           "expected all operations to be present in the result");
    return result;
  }

private:
  /// Computes the closest common ancestor region of all operations in `toSort`.
  Region *findCommonAncestorRegion() {
    // Map to count the number of times a region was encountered.
    DenseMap<Region *, size_t> regionCounts;
    size_t expectedCount = toSort.size();

    // Walk the region tree for each operation towards the root and add to the
    // region count.
    Region *res = nullptr;
    for (Operation *op : toSort) {
      Region *current = op->getParentRegion();
      // Store the block as an ancestor block.
      ancestorBlocks.insert(op->getBlock());
      while (current) {
        // Insert or update the count and compare it.
        if (++regionCounts[current] == expectedCount) {
          res = current;
          break;
        }
        ancestorBlocks.insert(current->getParentOp()->getBlock());
        current = current->getParentRegion();
      }
    }
    auto firstRange = llvm::make_first_range(regionCounts);
    ancestorRegions.insert(firstRange.begin(), firstRange.end());
    return res;
  }
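
  // Illustration of the counting scheme above (hypothetical nesting, not
  // taken from any test): if `toSort` holds two ops, one nested in an
  // `scf.for` inside a `func.func` and one directly inside the same
  // `func.func`, only the function body region reaches a count of 2 and is
  // returned as the closest common ancestor; the `scf.for` body region only
  // ever reaches a count of 1.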

  /// Performs the dominance respecting IR walk to collect the topological
  /// order of the operations to sort.
  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
    using StackT = PointerUnion<Region *, Block *, Operation *>;

    SetVector<Operation *> result;
    // Stack that stores the different IR constructs to traverse.
    SmallVector<StackT> stack;
    stack.push_back(&rootRegion);

    // Traverse the IR in a dominance respecting pre-order walk.
    while (!stack.empty()) {
      StackT current = stack.pop_back_val();
      if (auto *region = dyn_cast<Region *>(current)) {
        // A region's blocks need to be traversed in dominance order.
        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
        for (Block *block : llvm::reverse(sortedBlocks)) {
          // Only add blocks to the stack that are ancestors of the operations
          // to sort.
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
        continue;
      }

      if (auto *block = dyn_cast<Block *>(current)) {
        // Add all of the block's operations to the stack.
        for (Operation &op : llvm::reverse(*block))
          stack.push_back(&op);
        continue;
      }

      auto *op = cast<Operation *>(current);
      if (toSort.contains(op))
        result.insert(op);

      // Add all the subregions that are ancestors of the operations to sort.
      for (Region &subRegion : op->getRegions())
        if (ancestorRegions.contains(&subRegion))
          stack.push_back(&subRegion);
    }
    return result;
  }

  /// Operations to sort.
  const SetVector<Operation *> &toSort;
  /// Set containing all the ancestor regions of the operations to sort.
  DenseSet<Region *> ancestorRegions;
  /// Set containing all the ancestor blocks of the operations to sort.
  DenseSet<Block *> ancestorBlocks;
};
} // namespace

SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  return TopoSortHelper(toSort).sort();
}
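
// Example usage (a minimal sketch; the way `slice` is gathered is
// hypothetical and not part of this file): order an arbitrary set of
// operations, possibly spread across different blocks and regions, by the
// dominance-respecting IR walk implemented above.
//
//   SetVector<Operation *> slice = ...; // e.g. ops collected by a slice walk.
//   SetVector<Operation *> ordered = topologicalSort(slice);
//   for (Operation *op : ordered) {
//     // Ops are returned in the order the walk encounters them, so an op in
//     // a dominating block appears before ops in blocks it dominates.
//   }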