//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
88067ced1SRahul Joshi
9*e95e94adSJeff Niu #include "TestOps.h"
108067ced1SRahul Joshi #include "mlir/Pass/Pass.h"
118067ced1SRahul Joshi
128067ced1SRahul Joshi using namespace mlir;
138067ced1SRahul Joshi
getStageDescription(const WalkStage & stage)148067ced1SRahul Joshi static std::string getStageDescription(const WalkStage &stage) {
158067ced1SRahul Joshi if (stage.isBeforeAllRegions())
168067ced1SRahul Joshi return "before all regions";
178067ced1SRahul Joshi if (stage.isAfterAllRegions())
188067ced1SRahul Joshi return "after all regions";
198067ced1SRahul Joshi return "before region #" + std::to_string(stage.getNextRegion());
208067ced1SRahul Joshi }
218067ced1SRahul Joshi
228067ced1SRahul Joshi namespace {
238067ced1SRahul Joshi /// This pass exercises generic visitor with void callbacks and prints the order
248067ced1SRahul Joshi /// and stage in which operations are visited.
255e50dd04SRiver Riddle struct TestGenericIRVisitorPass
268067ced1SRahul Joshi : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorPass275e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass)
285e50dd04SRiver Riddle
298067ced1SRahul Joshi StringRef getArgument() const final { return "test-generic-ir-visitors"; }
getDescription__anon059fcff50111::TestGenericIRVisitorPass308067ced1SRahul Joshi StringRef getDescription() const final { return "Test generic IR visitors."; }
runOnOperation__anon059fcff50111::TestGenericIRVisitorPass318067ced1SRahul Joshi void runOnOperation() override {
328067ced1SRahul Joshi Operation *outerOp = getOperation();
338067ced1SRahul Joshi int stepNo = 0;
348067ced1SRahul Joshi outerOp->walk([&](Operation *op, const WalkStage &stage) {
358067ced1SRahul Joshi llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
368067ced1SRahul Joshi << getStageDescription(stage) << "\n";
378067ced1SRahul Joshi });
388067ced1SRahul Joshi
398067ced1SRahul Joshi // Exercise static inference of operation type.
408067ced1SRahul Joshi outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
418067ced1SRahul Joshi llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
428067ced1SRahul Joshi << getStageDescription(stage) << "\n";
438067ced1SRahul Joshi });
448067ced1SRahul Joshi }
458067ced1SRahul Joshi };
468067ced1SRahul Joshi
478067ced1SRahul Joshi /// This pass exercises the generic visitor with non-void callbacks and prints
488067ced1SRahul Joshi /// the order and stage in which operations are visited. It will interrupt the
498067ced1SRahul Joshi /// walk based on attributes peesent in the IR.
505e50dd04SRiver Riddle struct TestGenericIRVisitorInterruptPass
518067ced1SRahul Joshi : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRVisitorInterruptPass525e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
535e50dd04SRiver Riddle TestGenericIRVisitorInterruptPass)
545e50dd04SRiver Riddle
558067ced1SRahul Joshi StringRef getArgument() const final {
568067ced1SRahul Joshi return "test-generic-ir-visitors-interrupt";
578067ced1SRahul Joshi }
getDescription__anon059fcff50111::TestGenericIRVisitorInterruptPass588067ced1SRahul Joshi StringRef getDescription() const final {
598067ced1SRahul Joshi return "Test generic IR visitors with interrupts.";
608067ced1SRahul Joshi }
runOnOperation__anon059fcff50111::TestGenericIRVisitorInterruptPass618067ced1SRahul Joshi void runOnOperation() override {
628067ced1SRahul Joshi Operation *outerOp = getOperation();
638067ced1SRahul Joshi int stepNo = 0;
648067ced1SRahul Joshi
658067ced1SRahul Joshi auto walker = [&](Operation *op, const WalkStage &stage) {
668067ced1SRahul Joshi if (auto interruptBeforeAall =
678067ced1SRahul Joshi op->getAttrOfType<BoolAttr>("interrupt_before_all"))
688067ced1SRahul Joshi if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions())
698067ced1SRahul Joshi return WalkResult::interrupt();
708067ced1SRahul Joshi
718067ced1SRahul Joshi if (auto interruptAfterAll =
728067ced1SRahul Joshi op->getAttrOfType<BoolAttr>("interrupt_after_all"))
738067ced1SRahul Joshi if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
748067ced1SRahul Joshi return WalkResult::interrupt();
758067ced1SRahul Joshi
768067ced1SRahul Joshi if (auto interruptAfterRegion =
778067ced1SRahul Joshi op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
788067ced1SRahul Joshi if (stage.isAfterRegion(
798067ced1SRahul Joshi static_cast<int>(interruptAfterRegion.getInt())))
808067ced1SRahul Joshi return WalkResult::interrupt();
818067ced1SRahul Joshi
828067ced1SRahul Joshi if (auto skipBeforeAall = op->getAttrOfType<BoolAttr>("skip_before_all"))
838067ced1SRahul Joshi if (skipBeforeAall.getValue() && stage.isBeforeAllRegions())
848067ced1SRahul Joshi return WalkResult::skip();
858067ced1SRahul Joshi
868067ced1SRahul Joshi if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
878067ced1SRahul Joshi if (skipAfterAll.getValue() && stage.isAfterAllRegions())
888067ced1SRahul Joshi return WalkResult::skip();
898067ced1SRahul Joshi
908067ced1SRahul Joshi if (auto skipAfterRegion =
918067ced1SRahul Joshi op->getAttrOfType<IntegerAttr>("skip_after_region"))
928067ced1SRahul Joshi if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
938067ced1SRahul Joshi return WalkResult::skip();
948067ced1SRahul Joshi
958067ced1SRahul Joshi llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
968067ced1SRahul Joshi << getStageDescription(stage) << "\n";
978067ced1SRahul Joshi return WalkResult::advance();
988067ced1SRahul Joshi };
998067ced1SRahul Joshi
1008067ced1SRahul Joshi // Interrupt the walk based on attributes on the operation.
1018067ced1SRahul Joshi auto result = outerOp->walk(walker);
1028067ced1SRahul Joshi
1038067ced1SRahul Joshi if (result.wasInterrupted())
1048067ced1SRahul Joshi llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
1058067ced1SRahul Joshi
1068067ced1SRahul Joshi // Exercise static inference of operation type.
1078067ced1SRahul Joshi result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
1088067ced1SRahul Joshi return walker(op, stage);
1098067ced1SRahul Joshi });
1108067ced1SRahul Joshi
1118067ced1SRahul Joshi if (result.wasInterrupted())
1128067ced1SRahul Joshi llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
1138067ced1SRahul Joshi }
1148067ced1SRahul Joshi };
1158067ced1SRahul Joshi
116f2b94bd7SAshay Rane struct TestGenericIRBlockVisitorInterruptPass
117f2b94bd7SAshay Rane : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
118f2b94bd7SAshay Rane OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass119f2b94bd7SAshay Rane MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
120f2b94bd7SAshay Rane TestGenericIRBlockVisitorInterruptPass)
121f2b94bd7SAshay Rane
122f2b94bd7SAshay Rane StringRef getArgument() const final {
123f2b94bd7SAshay Rane return "test-generic-ir-block-visitors-interrupt";
124f2b94bd7SAshay Rane }
getDescription__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass125f2b94bd7SAshay Rane StringRef getDescription() const final {
126f2b94bd7SAshay Rane return "Test generic IR visitors with interrupts, starting with Blocks.";
127f2b94bd7SAshay Rane }
128f2b94bd7SAshay Rane
runOnOperation__anon059fcff50111::TestGenericIRBlockVisitorInterruptPass129f2b94bd7SAshay Rane void runOnOperation() override {
130f2b94bd7SAshay Rane int stepNo = 0;
131f2b94bd7SAshay Rane
132f2b94bd7SAshay Rane auto walker = [&](Block *block) {
133f2b94bd7SAshay Rane for (Operation &op : *block)
134179588eaSAshay Rane if (op.getAttrOfType<BoolAttr>("interrupt"))
135f2b94bd7SAshay Rane return WalkResult::interrupt();
136f2b94bd7SAshay Rane
137f2b94bd7SAshay Rane llvm::outs() << "step " << stepNo++ << "\n";
138f2b94bd7SAshay Rane return WalkResult::advance();
139f2b94bd7SAshay Rane };
140f2b94bd7SAshay Rane
141f2b94bd7SAshay Rane auto result = getOperation()->walk(walker);
142f2b94bd7SAshay Rane if (result.wasInterrupted())
143f2b94bd7SAshay Rane llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
144f2b94bd7SAshay Rane }
145f2b94bd7SAshay Rane };
146f2b94bd7SAshay Rane
147f2b94bd7SAshay Rane struct TestGenericIRRegionVisitorInterruptPass
148f2b94bd7SAshay Rane : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
149f2b94bd7SAshay Rane OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass150f2b94bd7SAshay Rane MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
151f2b94bd7SAshay Rane TestGenericIRRegionVisitorInterruptPass)
152f2b94bd7SAshay Rane
153f2b94bd7SAshay Rane StringRef getArgument() const final {
154f2b94bd7SAshay Rane return "test-generic-ir-region-visitors-interrupt";
155f2b94bd7SAshay Rane }
getDescription__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass156f2b94bd7SAshay Rane StringRef getDescription() const final {
157f2b94bd7SAshay Rane return "Test generic IR visitors with interrupts, starting with Regions.";
158f2b94bd7SAshay Rane }
159f2b94bd7SAshay Rane
runOnOperation__anon059fcff50111::TestGenericIRRegionVisitorInterruptPass160f2b94bd7SAshay Rane void runOnOperation() override {
161f2b94bd7SAshay Rane int stepNo = 0;
162f2b94bd7SAshay Rane
163f2b94bd7SAshay Rane auto walker = [&](Region *region) {
164179588eaSAshay Rane for (Operation &op : region->getOps())
165179588eaSAshay Rane if (op.getAttrOfType<BoolAttr>("interrupt"))
166f2b94bd7SAshay Rane return WalkResult::interrupt();
167f2b94bd7SAshay Rane
168f2b94bd7SAshay Rane llvm::outs() << "step " << stepNo++ << "\n";
169f2b94bd7SAshay Rane return WalkResult::advance();
170f2b94bd7SAshay Rane };
171f2b94bd7SAshay Rane
172f2b94bd7SAshay Rane auto result = getOperation()->walk(walker);
173f2b94bd7SAshay Rane if (result.wasInterrupted())
174f2b94bd7SAshay Rane llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
175f2b94bd7SAshay Rane }
176f2b94bd7SAshay Rane };
177f2b94bd7SAshay Rane
1788067ced1SRahul Joshi } // namespace
1798067ced1SRahul Joshi
1808067ced1SRahul Joshi namespace mlir {
1818067ced1SRahul Joshi namespace test {
registerTestGenericIRVisitorsPass()1828067ced1SRahul Joshi void registerTestGenericIRVisitorsPass() {
1838067ced1SRahul Joshi PassRegistration<TestGenericIRVisitorPass>();
1848067ced1SRahul Joshi PassRegistration<TestGenericIRVisitorInterruptPass>();
185f2b94bd7SAshay Rane PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
186f2b94bd7SAshay Rane PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
1878067ced1SRahul Joshi }
1888067ced1SRahul Joshi
1898067ced1SRahul Joshi } // namespace test
1908067ced1SRahul Joshi } // namespace mlir
191