xref: /llvm-project/mlir/include/mlir/IR/Visitors.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
17 #include "llvm/ADT/STLExtras.h"
18 
19 namespace mlir {
// IR and diagnostic classes referenced (by pointer or reference only) by the
// walk utilities in this header.
class Block;
class Diagnostic;
class InFlightDiagnostic;
class Operation;
class Region;
25 
26 /// A utility result that is used to signal how to proceed with an ongoing walk:
27 ///   * Interrupt: the walk will be interrupted and no more operations, regions
28 ///   or blocks will be visited.
29 ///   * Advance: the walk will continue.
30 ///   * Skip: the walk of the current operation, region or block and their
31 ///   nested elements that haven't been visited already will be skipped and will
32 ///   continue with the next operation, region or block.
33 class WalkResult {
34   enum ResultEnum { Interrupt, Advance, Skip } result;
35 
36 public:
result(result)37   WalkResult(ResultEnum result = Advance) : result(result) {}
38 
39   /// Allow LogicalResult to interrupt the walk on failure.
WalkResult(LogicalResult result)40   WalkResult(LogicalResult result)
41       : result(failed(result) ? Interrupt : Advance) {}
42 
43   /// Allow diagnostics to interrupt the walk.
WalkResult(Diagnostic &&)44   WalkResult(Diagnostic &&) : result(Interrupt) {}
WalkResult(InFlightDiagnostic &&)45   WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
46 
47   bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
48   bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
49 
interrupt()50   static WalkResult interrupt() { return {Interrupt}; }
advance()51   static WalkResult advance() { return {Advance}; }
skip()52   static WalkResult skip() { return {Skip}; }
53 
54   /// Returns true if the walk was interrupted.
wasInterrupted()55   bool wasInterrupted() const { return result == Interrupt; }
56 
57   /// Returns true if the walk was skipped.
wasSkipped()58   bool wasSkipped() const { return result == Skip; }
59 };
60 
/// Traversal order used by the region, block and operation walk utilities.
/// It decides whether an enclosing entity is visited before or after the
/// entities nested inside it.
enum class WalkOrder {
  PreOrder,  ///< Visit the enclosing entity first, then its nested entities.
  PostOrder, ///< Visit the nested entities first, then the enclosing entity.
};
63 
64 /// This iterator enumerates the elements in "forward" order.
65 struct ForwardIterator {
66   /// Make operations iterable: return the list of regions.
67   static MutableArrayRef<Region> makeIterable(Operation &range);
68 
69   /// Regions and block are already iterable.
70   template <typename T>
makeIterableForwardIterator71   static constexpr T &makeIterable(T &range) {
72     return range;
73   }
74 };
75 
76 /// A utility class to encode the current walk stage for "generic" walkers.
77 /// When walking an operation, we can either choose a Pre/Post order walker
78 /// which invokes the callback on an operation before/after all its attached
79 /// regions have been visited, or choose a "generic" walker where the callback
80 /// is invoked on the operation N+1 times where N is the number of regions
81 /// attached to that operation. The `WalkStage` class below encodes the current
82 /// stage of the walk, i.e., which regions have already been visited, and the
83 /// callback accepts an additional argument for the current stage. Such
84 /// generic walkers that accept stage-aware callbacks are only applicable when
85 /// the callback operates on an operation (i.e., not applicable for callbacks
86 /// on Blocks or Regions).
87 class WalkStage {
88 public:
89   explicit WalkStage(Operation *op);
90 
91   /// Return true if parent operation is being visited before all regions.
isBeforeAllRegions()92   bool isBeforeAllRegions() const { return nextRegion == 0; }
93   /// Returns true if parent operation is being visited just before visiting
94   /// region number `region`.
isBeforeRegion(int region)95   bool isBeforeRegion(int region) const { return nextRegion == region; }
96   /// Returns true if parent operation is being visited just after visiting
97   /// region number `region`.
isAfterRegion(int region)98   bool isAfterRegion(int region) const { return nextRegion == region + 1; }
99   /// Return true if parent operation is being visited after all regions.
isAfterAllRegions()100   bool isAfterAllRegions() const { return nextRegion == numRegions; }
101   /// Advance the walk stage.
advance()102   void advance() { nextRegion++; }
103   /// Returns the next region that will be visited.
getNextRegion()104   int getNextRegion() const { return nextRegion; }
105 
106 private:
107   const int numRegions;
108   int nextRegion;
109 };
110 
111 namespace detail {
/// Declaration-only overloads, used in unevaluated contexts, that deduce the
/// type of the first parameter of a callable. They cover:
///   * free function pointers,
///   * pointers to non-const and const member functions,
///   * function objects whose call operator can be named as `&F::operator()`
///     (i.e. not overloaded and not a template).
template <typename ResultT, typename FirstT, typename... OtherTs>
FirstT first_argument_type(ResultT (*)(FirstT, OtherTs...));
template <typename ResultT, typename ClassT, typename FirstT,
          typename... OtherTs>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, OtherTs...));
template <typename ResultT, typename ClassT, typename FirstT,
          typename... OtherTs>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, OtherTs...) const);
template <typename CallableT>
decltype(first_argument_type(&CallableT::operator()))
first_argument_type(CallableT);

/// The type of the first parameter accepted by the callable type `T`.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
125 
126 /// Walk all of the regions, blocks, or operations nested under (and including)
127 /// the given operation. The order in which regions, blocks and operations at
128 /// the same nesting level are visited (e.g., lexicographical or reverse
129 /// lexicographical order) is determined by 'Iterator'. The walk order for
130 /// enclosing regions, blocks and operations with respect to their nested ones
131 /// is specified by 'order'. These methods are invoked for void-returning
132 /// callbacks. A callback on a block or operation is allowed to erase that block
133 /// or operation only if the walk is in post-order. See non-void method for
134 /// pre-order erasure.
135 template <typename Iterator>
walk(Operation * op,function_ref<void (Region *)> callback,WalkOrder order)136 void walk(Operation *op, function_ref<void(Region *)> callback,
137           WalkOrder order) {
138   // We don't use early increment for regions because they can't be erased from
139   // a callback.
140   for (auto &region : Iterator::makeIterable(*op)) {
141     if (order == WalkOrder::PreOrder)
142       callback(&region);
143     for (auto &block : Iterator::makeIterable(region)) {
144       for (auto &nestedOp : Iterator::makeIterable(block))
145         walk<Iterator>(&nestedOp, callback, order);
146     }
147     if (order == WalkOrder::PostOrder)
148       callback(&region);
149   }
150 }
151 
152 template <typename Iterator>
walk(Operation * op,function_ref<void (Block *)> callback,WalkOrder order)153 void walk(Operation *op, function_ref<void(Block *)> callback,
154           WalkOrder order) {
155   for (auto &region : Iterator::makeIterable(*op)) {
156     // Early increment here in the case where the block is erased.
157     for (auto &block :
158          llvm::make_early_inc_range(Iterator::makeIterable(region))) {
159       if (order == WalkOrder::PreOrder)
160         callback(&block);
161       for (auto &nestedOp : Iterator::makeIterable(block))
162         walk<Iterator>(&nestedOp, callback, order);
163       if (order == WalkOrder::PostOrder)
164         callback(&block);
165     }
166   }
167 }
168 
169 template <typename Iterator>
walk(Operation * op,function_ref<void (Operation *)> callback,WalkOrder order)170 void walk(Operation *op, function_ref<void(Operation *)> callback,
171           WalkOrder order) {
172   if (order == WalkOrder::PreOrder)
173     callback(op);
174 
175   // TODO: This walk should be iterative over the operations.
176   for (auto &region : Iterator::makeIterable(*op)) {
177     for (auto &block : Iterator::makeIterable(region)) {
178       // Early increment here in the case where the operation is erased.
179       for (auto &nestedOp :
180            llvm::make_early_inc_range(Iterator::makeIterable(block)))
181         walk<Iterator>(&nestedOp, callback, order);
182     }
183   }
184 
185   if (order == WalkOrder::PostOrder)
186     callback(op);
187 }
188 
189 /// Walk all of the regions, blocks, or operations nested under (and including)
190 /// the given operation. The order in which regions, blocks and operations at
191 /// the same nesting level are visited (e.g., lexicographical or reverse
192 /// lexicographical order) is determined by 'Iterator'. The walk order for
193 /// enclosing regions, blocks and operations with respect to their nested ones
194 /// is specified by 'order'. This method is invoked for skippable or
195 /// interruptible callbacks. A callback on a block or operation is allowed to
196 /// erase that block or operation if either:
197 ///   * the walk is in post-order, or
198 ///   * the walk is in pre-order and the walk is skipped after the erasure.
199 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Region *)> callback,WalkOrder order)200 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
201                 WalkOrder order) {
202   // We don't use early increment for regions because they can't be erased from
203   // a callback.
204   for (auto &region : Iterator::makeIterable(*op)) {
205     if (order == WalkOrder::PreOrder) {
206       WalkResult result = callback(&region);
207       if (result.wasSkipped())
208         continue;
209       if (result.wasInterrupted())
210         return WalkResult::interrupt();
211     }
212     for (auto &block : Iterator::makeIterable(region)) {
213       for (auto &nestedOp : Iterator::makeIterable(block))
214         if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
215           return WalkResult::interrupt();
216     }
217     if (order == WalkOrder::PostOrder) {
218       if (callback(&region).wasInterrupted())
219         return WalkResult::interrupt();
220       // We don't check if this region was skipped because its walk already
221       // finished and the walk will continue with the next region.
222     }
223   }
224   return WalkResult::advance();
225 }
226 
227 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Block *)> callback,WalkOrder order)228 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
229                 WalkOrder order) {
230   for (auto &region : Iterator::makeIterable(*op)) {
231     // Early increment here in the case where the block is erased.
232     for (auto &block :
233          llvm::make_early_inc_range(Iterator::makeIterable(region))) {
234       if (order == WalkOrder::PreOrder) {
235         WalkResult result = callback(&block);
236         if (result.wasSkipped())
237           continue;
238         if (result.wasInterrupted())
239           return WalkResult::interrupt();
240       }
241       for (auto &nestedOp : Iterator::makeIterable(block))
242         if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
243           return WalkResult::interrupt();
244       if (order == WalkOrder::PostOrder) {
245         if (callback(&block).wasInterrupted())
246           return WalkResult::interrupt();
247         // We don't check if this block was skipped because its walk already
248         // finished and the walk will continue with the next block.
249       }
250     }
251   }
252   return WalkResult::advance();
253 }
254 
255 template <typename Iterator>
walk(Operation * op,function_ref<WalkResult (Operation *)> callback,WalkOrder order)256 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
257                 WalkOrder order) {
258   if (order == WalkOrder::PreOrder) {
259     WalkResult result = callback(op);
260     // If skipped, caller will continue the walk on the next operation.
261     if (result.wasSkipped())
262       return WalkResult::advance();
263     if (result.wasInterrupted())
264       return WalkResult::interrupt();
265   }
266 
267   // TODO: This walk should be iterative over the operations.
268   for (auto &region : Iterator::makeIterable(*op)) {
269     for (auto &block : Iterator::makeIterable(region)) {
270       // Early increment here in the case where the operation is erased.
271       for (auto &nestedOp :
272            llvm::make_early_inc_range(Iterator::makeIterable(block))) {
273         if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
274           return WalkResult::interrupt();
275       }
276     }
277   }
278 
279   if (order == WalkOrder::PostOrder)
280     return callback(op);
281   return WalkResult::advance();
282 }
283 
// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
288 
289 /// Walk all of the regions, blocks, or operations nested under (and including)
290 /// the given operation. The order in which regions, blocks and operations at
291 /// the same nesting level are visited (e.g., lexicographical or reverse
292 /// lexicographical order) is determined by 'Iterator'. The walk order for
293 /// enclosing regions, blocks and operations with respect to their nested ones
294 /// is specified by 'Order' (post-order by default). A callback on a block or
295 /// operation is allowed to erase that block or operation if either:
296 ///   * the walk is in post-order, or
297 ///   * the walk is in pre-order and the walk is skipped after the erasure.
298 /// This method is selected for callbacks that operate on Region*, Block*, and
299 /// Operation*.
300 ///
301 /// Example:
302 ///   op->walk([](Region *r) { ... });
303 ///   op->walk([](Block *b) { ... });
304 ///   op->walk([](Operation *op) { ... });
305 template <
306     WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
307     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
308     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
309 std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
310                  RetT>
walk(Operation * op,FuncTy && callback)311 walk(Operation *op, FuncTy &&callback) {
312   return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
313 }
314 
315 /// Walk all of the operations of type 'ArgT' nested under and including the
316 /// given operation. The order in which regions, blocks and operations at
317 /// the same nesting are visited (e.g., lexicographical or reverse
318 /// lexicographical order) is determined by 'Iterator'. The walk order for
319 /// enclosing regions, blocks and operations with respect to their nested ones
320 /// is specified by 'order' (post-order by default). This method is selected for
321 /// void-returning callbacks that operate on a specific derived operation type.
322 /// A callback on an operation is allowed to erase that operation only if the
323 /// walk is in post-order. See non-void method for pre-order erasure.
324 ///
325 /// Example:
326 ///   op->walk([](ReturnOp op) { ... });
327 template <
328     WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
329     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
330     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
331 std::enable_if_t<
332     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
333         std::is_same<RetT, void>::value,
334     RetT>
walk(Operation * op,FuncTy && callback)335 walk(Operation *op, FuncTy &&callback) {
336   auto wrapperFn = [&](Operation *op) {
337     if (auto derivedOp = dyn_cast<ArgT>(op))
338       callback(derivedOp);
339   };
340   return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
341                                 Order);
342 }
343 
344 /// Walk all of the operations of type 'ArgT' nested under and including the
345 /// given operation. The order in which regions, blocks and operations at
346 /// the same nesting are visited (e.g., lexicographical or reverse
347 /// lexicographical order) is determined by 'Iterator'. The walk order for
348 /// enclosing regions, blocks and operations with respect to their nested ones
349 /// is specified by 'Order' (post-order by default). This method is selected for
350 /// WalkReturn returning skippable or interruptible callbacks that operate on a
351 /// specific derived operation type. A callback on an operation is allowed to
352 /// erase that operation if either:
353 ///   * the walk is in post-order, or
354 ///   * the walk is in pre-order and the walk is skipped after the erasure.
355 ///
356 /// Example:
357 ///   op->walk([](ReturnOp op) {
358 ///     if (some_invariant)
359 ///       return WalkResult::skip();
360 ///     if (another_invariant)
361 ///       return WalkResult::interrupt();
362 ///     return WalkResult::advance();
363 ///   });
364 template <
365     WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
366     typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
367     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
368 std::enable_if_t<
369     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
370         std::is_same<RetT, WalkResult>::value,
371     RetT>
walk(Operation * op,FuncTy && callback)372 walk(Operation *op, FuncTy &&callback) {
373   auto wrapperFn = [&](Operation *op) {
374     if (auto derivedOp = dyn_cast<ArgT>(op))
375       return callback(derivedOp);
376     return WalkResult::advance();
377   };
378   return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
379                                 Order);
380 }
381 
/// Generic walkers with stage-aware callbacks.
383 
384 /// Walk all the operations nested under (and including) the given operation,
385 /// with the callback being invoked on each operation N+1 times, where N is the
386 /// number of regions attached to the operation. The `stage` input to the
387 /// callback indicates the current walk stage. This method is invoked for void
388 /// returning callbacks.
389 void walk(Operation *op,
390           function_ref<void(Operation *, const WalkStage &stage)> callback);
391 
392 /// Walk all the operations nested under (and including) the given operation,
393 /// with the callback being invoked on each operation N+1 times, where N is the
394 /// number of regions attached to the operation. The `stage` input to the
395 /// callback indicates the current walk stage. This method is invoked for
396 /// skippable or interruptible callbacks.
397 WalkResult
398 walk(Operation *op,
399      function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
400 
401 /// Walk all of the operations nested under and including the given operation.
402 /// This method is selected for stage-aware callbacks that operate on
403 /// Operation*.
404 ///
405 /// Example:
406 ///   op->walk([](Operation *op, const WalkStage &stage) { ... });
407 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
408           typename RetT = decltype(std::declval<FuncTy>()(
409               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
410 std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
walk(Operation * op,FuncTy && callback)411 walk(Operation *op, FuncTy &&callback) {
412   return detail::walk(op,
413                       function_ref<RetT(ArgT, const WalkStage &)>(callback));
414 }
415 
416 /// Walk all of the operations of type 'ArgT' nested under and including the
417 /// given operation. This method is selected for void returning callbacks that
418 /// operate on a specific derived operation type.
419 ///
420 /// Example:
421 ///   op->walk([](ReturnOp op, const WalkStage &stage) { ... });
422 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
423           typename RetT = decltype(std::declval<FuncTy>()(
424               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
425 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
426                      std::is_same<RetT, void>::value,
427                  RetT>
walk(Operation * op,FuncTy && callback)428 walk(Operation *op, FuncTy &&callback) {
429   auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
430     if (auto derivedOp = dyn_cast<ArgT>(op))
431       callback(derivedOp, stage);
432   };
433   return detail::walk(
434       op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
435 }
436 
437 /// Walk all of the operations of type 'ArgT' nested under and including the
438 /// given operation. This method is selected for WalkReturn returning
439 /// interruptible callbacks that operate on a specific derived operation type.
440 ///
441 /// Example:
442 ///   op->walk(op, [](ReturnOp op, const WalkStage &stage) {
443 ///     if (some_invariant)
444 ///       return WalkResult::interrupt();
445 ///     return WalkResult::advance();
446 ///   });
447 template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
448           typename RetT = decltype(std::declval<FuncTy>()(
449               std::declval<ArgT>(), std::declval<const WalkStage &>()))>
450 std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
451                      std::is_same<RetT, WalkResult>::value,
452                  RetT>
walk(Operation * op,FuncTy && callback)453 walk(Operation *op, FuncTy &&callback) {
454   auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
455     if (auto derivedOp = dyn_cast<ArgT>(op))
456       return callback(derivedOp, stage);
457     return WalkResult::advance();
458   };
459   return detail::walk(
460       op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
461 }
462 
/// The type returned by the templated `walk` functions above when they are
/// invoked with a callback of type `FnT` (e.g. void or WalkResult).
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
466 } // namespace detail
467 
468 } // namespace mlir
469 
470 #endif
471