xref: /llvm-project/mlir/test/lib/IR/TestVisitorsGeneric.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
//===- TestVisitorsGeneric.cpp - Pass to test the Generic IR visitors -----===//
28067ced1SRahul Joshi //
38067ced1SRahul Joshi // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48067ced1SRahul Joshi // See https://llvm.org/LICENSE.txt for license information.
58067ced1SRahul Joshi // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68067ced1SRahul Joshi //
78067ced1SRahul Joshi //===----------------------------------------------------------------------===//
88067ced1SRahul Joshi 
9*e95e94adSJeff Niu #include "TestOps.h"
108067ced1SRahul Joshi #include "mlir/Pass/Pass.h"
118067ced1SRahul Joshi 
128067ced1SRahul Joshi using namespace mlir;
138067ced1SRahul Joshi 
getStageDescription(const WalkStage & stage)148067ced1SRahul Joshi static std::string getStageDescription(const WalkStage &stage) {
158067ced1SRahul Joshi   if (stage.isBeforeAllRegions())
168067ced1SRahul Joshi     return "before all regions";
178067ced1SRahul Joshi   if (stage.isAfterAllRegions())
188067ced1SRahul Joshi     return "after all regions";
198067ced1SRahul Joshi   return "before region #" + std::to_string(stage.getNextRegion());
208067ced1SRahul Joshi }
218067ced1SRahul Joshi 
228067ced1SRahul Joshi namespace {
238067ced1SRahul Joshi /// This pass exercises generic visitor with void callbacks and prints the order
248067ced1SRahul Joshi /// and stage in which operations are visited.
255e50dd04SRiver Riddle struct TestGenericIRVisitorPass
268067ced1SRahul Joshi     : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorPass275e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
285e50dd04SRiver Riddle 
298067ced1SRahul Joshi   StringRef getArgument() const final { return "test-generic-ir-visitors"; }
getDescription__anon059fcff50111::TestGenericIRVisitorPass308067ced1SRahul Joshi   StringRef getDescription() const final { return "Test generic IR visitors."; }
runOnOperation__anon059fcff50111::TestGenericIRVisitorPass318067ced1SRahul Joshi   void runOnOperation() override {
328067ced1SRahul Joshi     Operation *outerOp = getOperation();
338067ced1SRahul Joshi     int stepNo = 0;
348067ced1SRahul Joshi     outerOp->walk([&](Operation *op, const WalkStage &stage) {
358067ced1SRahul Joshi       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
368067ced1SRahul Joshi                    << getStageDescription(stage) << "\n";
378067ced1SRahul Joshi     });
388067ced1SRahul Joshi 
398067ced1SRahul Joshi     // Exercise static inference of operation type.
408067ced1SRahul Joshi     outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
418067ced1SRahul Joshi       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
428067ced1SRahul Joshi                    << getStageDescription(stage) << "\n";
438067ced1SRahul Joshi     });
448067ced1SRahul Joshi   }
458067ced1SRahul Joshi };
468067ced1SRahul Joshi 
478067ced1SRahul Joshi /// This pass exercises the generic visitor with non-void callbacks and prints
488067ced1SRahul Joshi /// the order and stage in which operations are visited. It will interrupt the
498067ced1SRahul Joshi /// walk based on attributes peesent in the IR.
505e50dd04SRiver Riddle struct TestGenericIRVisitorInterruptPass
518067ced1SRahul Joshi     : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorInterruptPass525e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
535e50dd04SRiver Riddle       TestGenericIRVisitorInterruptPass)
545e50dd04SRiver Riddle 
558067ced1SRahul Joshi   StringRef getArgument() const final {
568067ced1SRahul Joshi     return "test-generic-ir-visitors-interrupt";
578067ced1SRahul Joshi   }
getDescription__anon059fcff50111::TestGenericIRVisitorInterruptPass588067ced1SRahul Joshi   StringRef getDescription() const final {
598067ced1SRahul Joshi     return "Test generic IR visitors with interrupts.";
608067ced1SRahul Joshi   }
runOnOperation__anon059fcff50111::TestGenericIRVisitorInterruptPass618067ced1SRahul Joshi   void runOnOperation() override {
628067ced1SRahul Joshi     Operation *outerOp = getOperation();
638067ced1SRahul Joshi     int stepNo = 0;
648067ced1SRahul Joshi 
658067ced1SRahul Joshi     auto walker = [&](Operation *op, const WalkStage &stage) {
668067ced1SRahul Joshi       if (auto interruptBeforeAall =
678067ced1SRahul Joshi               op->getAttrOfType<BoolAttr>("interrupt_before_all"))
688067ced1SRahul Joshi         if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
698067ced1SRahul Joshi           return WalkResult::interrupt();
708067ced1SRahul Joshi 
718067ced1SRahul Joshi       if (auto interruptAfterAll =
728067ced1SRahul Joshi               op->getAttrOfType<BoolAttr>("interrupt_after_all"))
738067ced1SRahul Joshi         if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
748067ced1SRahul Joshi           return WalkResult::interrupt();
758067ced1SRahul Joshi 
768067ced1SRahul Joshi       if (auto interruptAfterRegion =
778067ced1SRahul Joshi               op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
788067ced1SRahul Joshi         if (stage.isAfterRegion(
798067ced1SRahul Joshi                 static_cast<int>(interruptAfterRegion.getInt())))
808067ced1SRahul Joshi           return WalkResult::interrupt();
818067ced1SRahul Joshi 
828067ced1SRahul Joshi       if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
838067ced1SRahul Joshi         if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
848067ced1SRahul Joshi           return WalkResult::skip();
858067ced1SRahul Joshi 
868067ced1SRahul Joshi       if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
878067ced1SRahul Joshi         if (skipAfterAll.getValue() && stage.isAfterAllRegions())
888067ced1SRahul Joshi           return WalkResult::skip();
898067ced1SRahul Joshi 
908067ced1SRahul Joshi       if (auto skipAfterRegion =
918067ced1SRahul Joshi               op->getAttrOfType<IntegerAttr>("skip_after_region"))
928067ced1SRahul Joshi         if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
938067ced1SRahul Joshi           return WalkResult::skip();
948067ced1SRahul Joshi 
958067ced1SRahul Joshi       llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
968067ced1SRahul Joshi                    << getStageDescription(stage) << "\n";
978067ced1SRahul Joshi       return WalkResult::advance();
988067ced1SRahul Joshi     };
998067ced1SRahul Joshi 
1008067ced1SRahul Joshi     // Interrupt the walk based on attributes on the operation.
1018067ced1SRahul Joshi     auto result = outerOp->walk(walker);
1028067ced1SRahul Joshi 
1038067ced1SRahul Joshi     if (result.wasInterrupted())
1048067ced1SRahul Joshi       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
1058067ced1SRahul Joshi 
1068067ced1SRahul Joshi     // Exercise static inference of operation type.
1078067ced1SRahul Joshi     result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
1088067ced1SRahul Joshi       return walker(op, stage);
1098067ced1SRahul Joshi     });
1108067ced1SRahul Joshi 
1118067ced1SRahul Joshi     if (result.wasInterrupted())
1128067ced1SRahul Joshi       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
1138067ced1SRahul Joshi   }
1148067ced1SRahul Joshi };
1158067ced1SRahul Joshi 
116f2b94bd7SAshay Rane struct TestGenericIRBlockVisitorInterruptPass
117f2b94bd7SAshay Rane     : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118f2b94bd7SAshay Rane                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass119f2b94bd7SAshay Rane   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120f2b94bd7SAshay Rane       TestGenericIRBlockVisitorInterruptPass)
121f2b94bd7SAshay Rane 
122f2b94bd7SAshay Rane   StringRef getArgument() const final {
123f2b94bd7SAshay Rane     return "test-generic-ir-block-visitors-interrupt";
124f2b94bd7SAshay Rane   }
getDescription__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass125f2b94bd7SAshay Rane   StringRef getDescription() const final {
126f2b94bd7SAshay Rane     return "Test generic IR visitors with interrupts, starting with Blocks.";
127f2b94bd7SAshay Rane   }
128f2b94bd7SAshay Rane 
runOnOperation__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass129f2b94bd7SAshay Rane   void runOnOperation() override {
130f2b94bd7SAshay Rane     int stepNo = 0;
131f2b94bd7SAshay Rane 
132f2b94bd7SAshay Rane     auto walker = [&](Block *block) {
133f2b94bd7SAshay Rane       for (Operation &op : *block)
134179588eaSAshay Rane         if (op.getAttrOfType<BoolAttr>("interrupt"))
135f2b94bd7SAshay Rane           return WalkResult::interrupt();
136f2b94bd7SAshay Rane 
137f2b94bd7SAshay Rane       llvm::outs() << "step " << stepNo++ << "\n";
138f2b94bd7SAshay Rane       return WalkResult::advance();
139f2b94bd7SAshay Rane     };
140f2b94bd7SAshay Rane 
141f2b94bd7SAshay Rane     auto result = getOperation()->walk(walker);
142f2b94bd7SAshay Rane     if (result.wasInterrupted())
143f2b94bd7SAshay Rane       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
144f2b94bd7SAshay Rane   }
145f2b94bd7SAshay Rane };
146f2b94bd7SAshay Rane 
147f2b94bd7SAshay Rane struct TestGenericIRRegionVisitorInterruptPass
148f2b94bd7SAshay Rane     : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
149f2b94bd7SAshay Rane                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass150f2b94bd7SAshay Rane   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
151f2b94bd7SAshay Rane       TestGenericIRRegionVisitorInterruptPass)
152f2b94bd7SAshay Rane 
153f2b94bd7SAshay Rane   StringRef getArgument() const final {
154f2b94bd7SAshay Rane     return "test-generic-ir-region-visitors-interrupt";
155f2b94bd7SAshay Rane   }
getDescription__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass156f2b94bd7SAshay Rane   StringRef getDescription() const final {
157f2b94bd7SAshay Rane     return "Test generic IR visitors with interrupts, starting with Regions.";
158f2b94bd7SAshay Rane   }
159f2b94bd7SAshay Rane 
runOnOperation__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass160f2b94bd7SAshay Rane   void runOnOperation() override {
161f2b94bd7SAshay Rane     int stepNo = 0;
162f2b94bd7SAshay Rane 
163f2b94bd7SAshay Rane     auto walker = [&](Region *region) {
164179588eaSAshay Rane       for (Operation &op : region->getOps())
165179588eaSAshay Rane         if (op.getAttrOfType<BoolAttr>("interrupt"))
166f2b94bd7SAshay Rane           return WalkResult::interrupt();
167f2b94bd7SAshay Rane 
168f2b94bd7SAshay Rane       llvm::outs() << "step " << stepNo++ << "\n";
169f2b94bd7SAshay Rane       return WalkResult::advance();
170f2b94bd7SAshay Rane     };
171f2b94bd7SAshay Rane 
172f2b94bd7SAshay Rane     auto result = getOperation()->walk(walker);
173f2b94bd7SAshay Rane     if (result.wasInterrupted())
174f2b94bd7SAshay Rane       llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
175f2b94bd7SAshay Rane   }
176f2b94bd7SAshay Rane };
177f2b94bd7SAshay Rane 
1788067ced1SRahul Joshi } // namespace
1798067ced1SRahul Joshi 
1808067ced1SRahul Joshi namespace mlir {
1818067ced1SRahul Joshi namespace test {
registerTestGenericIRVisitorsPass()1828067ced1SRahul Joshi void registerTestGenericIRVisitorsPass() {
1838067ced1SRahul Joshi   PassRegistration<TestGenericIRVisitorPass>();
1848067ced1SRahul Joshi   PassRegistration<TestGenericIRVisitorInterruptPass>();
185f2b94bd7SAshay Rane   PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
186f2b94bd7SAshay Rane   PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
1878067ced1SRahul Joshi }
1888067ced1SRahul Joshi 
1898067ced1SRahul Joshi } // namespace test
1908067ced1SRahul Joshi } // namespace mlir
191