//===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines utilities for walking and visiting operations.
//
//===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/STLExtras.h"
18
19 namespace mlir {
// Forward declarations of the IR and diagnostic types that the walk utilities
// in this header only refer to by pointer or reference.
class Diagnostic;
class InFlightDiagnostic;
class Operation;
class Block;
class Region;
25
26 /// A utility result that is used to signal how to proceed with an ongoing walk:
27 /// * Interrupt: the walk will be interrupted and no more operations, regions
28 /// or blocks will be visited.
29 /// * Advance: the walk will continue.
30 /// * Skip: the walk of the current operation, region or block and their
31 /// nested elements that haven't been visited already will be skipped and will
32 /// continue with the next operation, region or block.
33 class WalkResult {
34 enum ResultEnum { Interrupt, Advance, Skip } result;
35
36 public:
result(result)37 WalkResult(ResultEnum result = Advance) : result(result) {}
38
39 /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)40 WalkResult(LogicalResult result)
41 : result(failed(result) ? Interrupt : Advance) {}
42
43 /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)44 WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)45 WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
46
47 bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
48 bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
49
interrupt()50 static WalkResult interrupt() { return {Interrupt}; }
advance()51 static WalkResult advance() { return {Advance}; }
skip()52 static WalkResult skip() { return {Skip}; }
53
54 /// Returns true if the walk was interrupted.
wasInterrupted()55 bool wasInterrupted() const { return result == Interrupt; }
56
57 /// Returns true if the walk was skipped.
wasSkipped()58 bool wasSkipped() const { return result == Skip; }
59 };
60
/// Traversal order for region, block and operation walk utilities.
enum class WalkOrder { PreOrder, PostOrder };
63
64 /// This iterator enumerates the elements in "forward" order.
65 struct ForwardIterator {
66 /// Make operations iterable: return the list of regions.
67 static MutableArrayRef<Region> makeIterable(Operation &range);
68
69 /// Regions and block are already iterable.
70 template <typename T>
makeIterableForwardIterator71 static constexpr T &makeIterable(T &range) {
72 return range;
73 }
74 };
75
76 /// A utility class to encode the current walk stage for "generic" walkers.
77 /// When walking an operation, we can either choose a Pre/Post order walker
78 /// which invokes the callback on an operation before/after all its attached
79 /// regions have been visited, or choose a "generic" walker where the callback
80 /// is invoked on the operation N+1 times where N is the number of regions
81 /// attached to that operation. The `WalkStage` class below encodes the current
82 /// stage of the walk, i.e., which regions have already been visited, and the
83 /// callback accepts an additional argument for the current stage. Such
84 /// generic walkers that accept stage-aware callbacks are only applicable when
85 /// the callback operates on an operation (i.e., not applicable for callbacks
86 /// on Blocks or Regions).
87 class WalkStage {
88 public:
89 explicit WalkStage(Operation *op);
90
91 /// Return true if parent operation is being visited before all regions.
isBeforeAllRegions()92 bool isBeforeAllRegions() const { return nextRegion == 0; }
93 /// Returns true if parent operation is being visited just before visiting
94 /// region number `region`.
isBeforeRegion(int region)95 bool isBeforeRegion(int region) const { return nextRegion == region; }
96 /// Returns true if parent operation is being visited just after visiting
97 /// region number `region`.
isAfterRegion(int region)98 bool isAfterRegion(int region) const { return nextRegion == region + 1; }
99 /// Return true if parent operation is being visited after all regions.
isAfterAllRegions()100 bool isAfterAllRegions() const { return nextRegion == numRegions; }
101 /// Advance the walk stage.
advance()102 void advance() { nextRegion++; }
103 /// Returns the next region that will be visited.
getNextRegion()104 int getNextRegion() const { return nextRegion; }
105
106 private:
107 const int numRegions;
108 int nextRegion;
109 };
110
111 namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
/// These are declarations only: they are meant to be used inside `decltype`
/// and are never called.
///
/// Overload for plain function pointers.
template <typename Ret, typename Arg, typename... Rest>
Arg first_argument_type(Ret (*)(Arg, Rest...));
/// Overload for pointers to non-const member functions.
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...));
/// Overload for pointers to const member functions.
template <typename Ret, typename F, typename Arg, typename... Rest>
Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
/// Overload for function objects (including lambdas): forwards to the member
/// function overloads using the object's call operator.
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
125
126 /// Walk all of the regions, blocks, or operations nested under (and including)
127 /// the given operation. The order in which regions, blocks and operations at
128 /// the same nesting level are visited (e.g., lexicographical or reverse
129 /// lexicographical order) is determined by 'Iterator'. The walk order for
130 /// enclosing regions, blocks and operations with respect to their nested ones
131 /// is specified by 'order'. These methods are invoked for void-returning
132 /// callbacks. A callback on a block or operation is allowed to erase that block
133 /// or operation only if the walk is in post-order. See non-void method for
134 /// pre-order erasure.
135 template <typename Iterator>
walk(Operation * op,function_ref<void (Region *)> callback,WalkOrder order)136 void walk(Operation *op, function_ref<void(Region *)> callback,
137 WalkOrder order) {
138 // We don't use early increment for regions because they can't be erased from
139 // a callback.
140 for (auto ®ion : Iterator::makeIterable(*op)) {
141 if (order == WalkOrder::PreOrder)
142 callback(®ion);
143 for (auto &block : Iterator::makeIterable(region)) {
144 for (auto &nestedOp : Iterator::makeIterable(block))
145 walk<Iterator>(&nestedOp, callback, order);
146 }
147 if (order == WalkOrder::PostOrder)
148 callback(®ion);
149 }
150 }
151
152 template <typename Iterator>
walk(Operation * op,function_ref<void (Block *)> callback,WalkOrder order)153 void walk(Operation *op, function_ref<void(Block *)> callback,
154 WalkOrder order) {
155 for (auto ®ion : Iterator::makeIterable(*op)) {
156 // Early increment here in the case where the block is erased.
157 for (auto &block :
158 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
159 if (order == WalkOrder::PreOrder)
160 callback(&block);
161 for (auto &nestedOp : Iterator::makeIterable(block))
162 walk<Iterator>(&nestedOp, callback, order);
163 if (order == WalkOrder::PostOrder)
164 callback(&block);
165 }
166 }
167 }
168
169 template <typename Iterator>
walk(Operation * op,function_ref<void (Operation *)> callback,WalkOrder order)170 void walk(Operation *op, function_ref<void(Operation *)> callback,
171 WalkOrder order) {
172 if (order == WalkOrder::PreOrder)
173 callback(op);
174
175 // TODO: This walk should be iterative over the operations.
176 for (auto ®ion : Iterator::makeIterable(*op)) {
177 for (auto &block : Iterator::makeIterable(region)) {
178 // Early increment here in the case where the operation is erased.
179 for (auto &nestedOp :
180 llvm::make_early_inc_range(Iterator::makeIterable(block)))
181 walk<Iterator>(&nestedOp, callback, order);
182 }
183 }
184
185 if (order == WalkOrder::PostOrder)
186 callback(op);
187 }
188
189 /// Walk all of the regions, blocks, or operations nested under (and including)
190 /// the given operation. The order in which regions, blocks and operations at
191 /// the same nesting level are visited (e.g., lexicographical or reverse
192 /// lexicographical order) is determined by 'Iterator'. The walk order for
193 /// enclosing regions, blocks and operations with respect to their nested ones
194 /// is specified by 'order'. This method is invoked for skippable or
195 /// interruptible callbacks. A callback on a block or operation is allowed to
196 /// erase that block or operation if either:
197 /// * the walk is in post-order, or
198 /// * the walk is in pre-order and the walk is skipped after the erasure.
199 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Region *)> callback,WalkOrder order)200 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
201 WalkOrder order) {
202 // We don't use early increment for regions because they can't be erased from
203 // a callback.
204 for (auto ®ion : Iterator::makeIterable(*op)) {
205 if (order == WalkOrder::PreOrder) {
206 WalkResult result = callback(®ion);
207 if (result.wasSkipped())
208 continue;
209 if (result.wasInterrupted())
210 return WalkResult::interrupt();
211 }
212 for (auto &block : Iterator::makeIterable(region)) {
213 for (auto &nestedOp : Iterator::makeIterable(block))
214 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
215 return WalkResult::interrupt();
216 }
217 if (order == WalkOrder::PostOrder) {
218 if (callback(®ion).wasInterrupted())
219 return WalkResult::interrupt();
220 // We don't check if this region was skipped because its walk already
221 // finished and the walk will continue with the next region.
222 }
223 }
224 return WalkResult::advance();
225 }
226
227 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Block *)> callback,WalkOrder order)228 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
229 WalkOrder order) {
230 for (auto ®ion : Iterator::makeIterable(*op)) {
231 // Early increment here in the case where the block is erased.
232 for (auto &block :
233 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
234 if (order == WalkOrder::PreOrder) {
235 WalkResult result = callback(&block);
236 if (result.wasSkipped())
237 continue;
238 if (result.wasInterrupted())
239 return WalkResult::interrupt();
240 }
241 for (auto &nestedOp : Iterator::makeIterable(block))
242 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
243 return WalkResult::interrupt();
244 if (order == WalkOrder::PostOrder) {
245 if (callback(&block).wasInterrupted())
246 return WalkResult::interrupt();
247 // We don't check if this block was skipped because its walk already
248 // finished and the walk will continue with the next block.
249 }
250 }
251 }
252 return WalkResult::advance();
253 }
254
255 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Operation *)> callback,WalkOrder order)256 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
257 WalkOrder order) {
258 if (order == WalkOrder::PreOrder) {
259 WalkResult result = callback(op);
260 // If skipped, caller will continue the walk on the next operation.
261 if (result.wasSkipped())
262 return WalkResult::advance();
263 if (result.wasInterrupted())
264 return WalkResult::interrupt();
265 }
266
267 // TODO: This walk should be iterative over the operations.
268 for (auto ®ion : Iterator::makeIterable(*op)) {
269 for (auto &block : Iterator::makeIterable(region)) {
270 // Early increment here in the case where the operation is erased.
271 for (auto &nestedOp :
272 llvm::make_early_inc_range(Iterator::makeIterable(block))) {
273 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
274 return WalkResult::interrupt();
275 }
276 }
277 }
278
279 if (order == WalkOrder::PostOrder)
280 return callback(op);
281 return WalkResult::advance();
282 }
283
284 // Below are a set of functions to walk nested operations. Users should favor
285 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
286 // methods. They are also templated to allow for statically dispatching based
287 // upon the type of the callback function.
288
289 /// Walk all of the regions, blocks, or operations nested under (and including)
290 /// the given operation. The order in which regions, blocks and operations at
291 /// the same nesting level are visited (e.g., lexicographical or reverse
292 /// lexicographical order) is determined by 'Iterator'. The walk order for
293 /// enclosing regions, blocks and operations with respect to their nested ones
294 /// is specified by 'Order' (post-order by default). A callback on a block or
295 /// operation is allowed to erase that block or operation if either:
296 /// * the walk is in post-order, or
297 /// * the walk is in pre-order and the walk is skipped after the erasure.
298 /// This method is selected for callbacks that operate on Region*, Block*, and
299 /// Operation*.
300 ///
301 /// Example:
302 /// op->walk([](Region *r) { ... });
303 /// op->walk([](Block *b) { ... });
304 /// op->walk([](Operation *op) { ... });
305 template <
306 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
307 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
308 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
309 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
310 RetT>
walk(Operation * op,FuncTy && callback)311 walk(Operation *op, FuncTy &&callback) {
312 return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
313 }
314
315 /// Walk all of the operations of type 'ArgT' nested under and including the
316 /// given operation. The order in which regions, blocks and operations at
317 /// the same nesting are visited (e.g., lexicographical or reverse
318 /// lexicographical order) is determined by 'Iterator'. The walk order for
319 /// enclosing regions, blocks and operations with respect to their nested ones
320 /// is specified by 'order' (post-order by default). This method is selected for
321 /// void-returning callbacks that operate on a specific derived operation type.
322 /// A callback on an operation is allowed to erase that operation only if the
323 /// walk is in post-order. See non-void method for pre-order erasure.
324 ///
325 /// Example:
326 /// op->walk([](ReturnOp op) { ... });
327 template <
328 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
329 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
330 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
331 std::enable_if_t<
332 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
333 std::is_same<RetT, void>::value,
334 RetT>
walk(Operation * op,FuncTy && callback)335 walk(Operation *op, FuncTy &&callback) {
336 auto wrapperFn = [&](Operation *op) {
337 if (auto derivedOp = dyn_cast<ArgT>(op))
338 callback(derivedOp);
339 };
340 return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
341 Order);
342 }
343
344 /// Walk all of the operations of type 'ArgT' nested under and including the
345 /// given operation. The order in which regions, blocks and operations at
346 /// the same nesting are visited (e.g., lexicographical or reverse
347 /// lexicographical order) is determined by 'Iterator'. The walk order for
348 /// enclosing regions, blocks and operations with respect to their nested ones
349 /// is specified by 'Order' (post-order by default). This method is selected for
350 /// WalkReturn returning skippable or interruptible callbacks that operate on a
351 /// specific derived operation type. A callback on an operation is allowed to
352 /// erase that operation if either:
353 /// * the walk is in post-order, or
354 /// * the walk is in pre-order and the walk is skipped after the erasure.
355 ///
356 /// Example:
357 /// op->walk([](ReturnOp op) {
358 /// if (some_invariant)
359 /// return WalkResult::skip();
360 /// if (another_invariant)
361 /// return WalkResult::interrupt();
362 /// return WalkResult::advance();
363 /// });
364 template <
365 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
366 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
367 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
368 std::enable_if_t<
369 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
370 std::is_same<RetT, WalkResult>::value,
371 RetT>
walk(Operation * op,FuncTy && callback)372 walk(Operation *op, FuncTy &&callback) {
373 auto wrapperFn = [&](Operation *op) {
374 if (auto derivedOp = dyn_cast<ArgT>(op))
375 return callback(derivedOp);
376 return WalkResult::advance();
377 };
378 return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
379 Order);
380 }
381
382 /// Generic walkers with stage aware callbacks.
383
384 /// Walk all the operations nested under (and including) the given operation,
385 /// with the callback being invoked on each operation N+1 times, where N is the
386 /// number of regions attached to the operation. The `stage` input to the
387 /// callback indicates the current walk stage. This method is invoked for void
388 /// returning callbacks.
389 void walk(Operation *op,
390 function_ref<void(Operation *, const WalkStage &stage)> callback);
391
392 /// Walk all the operations nested under (and including) the given operation,
393 /// with the callback being invoked on each operation N+1 times, where N is the
394 /// number of regions attached to the operation. The `stage` input to the
395 /// callback indicates the current walk stage. This method is invoked for
396 /// skippable or interruptible callbacks.
397 WalkResult
398 walk(Operation *op,
399 function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
400
401 /// Walk all of the operations nested under and including the given operation.
402 /// This method is selected for stage-aware callbacks that operate on
403 /// Operation*.
404 ///
405 /// Example:
406 /// op->walk([](Operation *op, const WalkStage &stage) { ... });
407 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
408 typename RetT = decltype(std::declval<FuncTy>()(
409 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
410 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
walk(Operation * op,FuncTy && callback)411 walk(Operation *op, FuncTy &&callback) {
412 return detail::walk(op,
413 function_ref<RetT(ArgT, const WalkStage &)>(callback));
414 }
415
416 /// Walk all of the operations of type 'ArgT' nested under and including the
417 /// given operation. This method is selected for void returning callbacks that
418 /// operate on a specific derived operation type.
419 ///
420 /// Example:
421 /// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
422 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
423 typename RetT = decltype(std::declval<FuncTy>()(
424 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
425 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
426 std::is_same<RetT, void>::value,
427 RetT>
walk(Operation * op,FuncTy && callback)428 walk(Operation *op, FuncTy &&callback) {
429 auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
430 if (auto derivedOp = dyn_cast<ArgT>(op))
431 callback(derivedOp, stage);
432 };
433 return detail::walk(
434 op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
435 }
436
437 /// Walk all of the operations of type 'ArgT' nested under and including the
438 /// given operation. This method is selected for WalkReturn returning
439 /// interruptible callbacks that operate on a specific derived operation type.
440 ///
441 /// Example:
442 /// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
443 /// if (some_invariant)
444 /// return WalkResult::interrupt();
445 /// return WalkResult::advance();
446 /// });
447 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
448 typename RetT = decltype(std::declval<FuncTy>()(
449 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
450 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
451 std::is_same<RetT, WalkResult>::value,
452 RetT>
walk(Operation * op,FuncTy && callback)453 walk(Operation *op, FuncTy &&callback) {
454 auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
455 if (auto derivedOp = dyn_cast<ArgT>(op))
456 return callback(derivedOp, stage);
457 return WalkResult::advance();
458 };
459 return detail::walk(
460 op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
461 }
462
/// Utility to provide the return type of a templated walk method. The type is
/// whatever the `walk` overload selected for a callback of type 'FnT' returns.
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
466 } // namespace detail
467
468 } // namespace mlir
469
470 #endif
471