xref: /llvm-project/mlir/lib/IR/Visitors.cpp (revision df067f13de569979b0d8ad8e9fc91ca06630e58f)
1 //===- Visitors.cpp - MLIR Visitor Utilities ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Visitors.h"
10 #include "mlir/IR/Operation.h"
11 
12 using namespace mlir;
13 
WalkStage(Operation * op)14 WalkStage::WalkStage(Operation *op)
15     : numRegions(op->getNumRegions()), nextRegion(0) {}
16 
makeIterable(Operation & range)17 MutableArrayRef<Region> ForwardIterator::makeIterable(Operation &range) {
18   return range.getRegions();
19 }
20 
walk(Operation * op,function_ref<void (Operation *,const WalkStage &)> callback)21 void detail::walk(Operation *op,
22                   function_ref<void(Operation *, const WalkStage &)> callback) {
23   WalkStage stage(op);
24 
25   for (Region &region : op->getRegions()) {
26     // Invoke callback on the parent op before visiting each child region.
27     callback(op, stage);
28     stage.advance();
29 
30     for (Block &block : region) {
31       for (Operation &nestedOp : block)
32         walk(&nestedOp, callback);
33     }
34   }
35 
36   // Invoke callback after all regions have been visited.
37   callback(op, stage);
38 }
39 
walk(Operation * op,function_ref<WalkResult (Operation *,const WalkStage &)> callback)40 WalkResult detail::walk(
41     Operation *op,
42     function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
43   WalkStage stage(op);
44 
45   for (Region &region : op->getRegions()) {
46     // Invoke callback on the parent op before visiting each child region.
47     WalkResult result = callback(op, stage);
48 
49     if (result.wasSkipped())
50       return WalkResult::advance();
51     if (result.wasInterrupted())
52       return WalkResult::interrupt();
53 
54     stage.advance();
55 
56     for (Block &block : region) {
57       // Early increment here in the case where the operation is erased.
58       for (Operation &nestedOp : llvm::make_early_inc_range(block))
59         if (walk(&nestedOp, callback).wasInterrupted())
60           return WalkResult::interrupt();
61     }
62   }
63   return callback(op, stage);
64 }
65