xref: /llvm-project/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp (revision 0f8a6b7d03550cb58cf49535af2de2230abfe997)
1 //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implements mlir::walkAndApplyPatterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
14 
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/OperationSupport.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/Verifier.h"
19 #include "mlir/IR/Visitors.h"
20 #include "mlir/Rewrite/PatternApplicator.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ErrorHandling.h"
23 
24 #define DEBUG_TYPE "walk-rewriter"
25 
26 namespace mlir {
27 
28 namespace {
29 struct WalkAndApplyPatternsAction final
30     : tracing::ActionImpl<WalkAndApplyPatternsAction> {
31   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
32   using ActionImpl::ActionImpl;
33   static constexpr StringLiteral tag = "walk-and-apply-patterns";
34   void print(raw_ostream &os) const override { os << tag; }
35 };
36 
37 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
38 // Forwarding listener to guard against unsupported erasures of non-descendant
39 // ops/blocks. Because we use walk-based pattern application, erasing the
40 // op/block from the *next* iteration (e.g., a user of the visited op) is not
41 // valid. Note that this is only used with expensive pattern API checks.
42 struct ErasedOpsListener final : RewriterBase::ForwardingListener {
43   using RewriterBase::ForwardingListener::ForwardingListener;
44 
45   void notifyOperationErased(Operation *op) override {
46     checkErasure(op);
47     ForwardingListener::notifyOperationErased(op);
48   }
49 
50   void notifyBlockErased(Block *block) override {
51     checkErasure(block->getParentOp());
52     ForwardingListener::notifyBlockErased(block);
53   }
54 
55   void checkErasure(Operation *op) const {
56     Operation *ancestorOp = op;
57     while (ancestorOp && ancestorOp != visitedOp)
58       ancestorOp = ancestorOp->getParentOp();
59 
60     if (ancestorOp != visitedOp)
61       llvm::report_fatal_error(
62           "unsupported erasure in WalkPatternRewriter; "
63           "erasure is only supported for matched ops and their descendants");
64   }
65 
66   Operation *visitedOp = nullptr;
67 };
68 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69 } // namespace
70 
71 void walkAndApplyPatterns(Operation *op,
72                           const FrozenRewritePatternSet &patterns,
73                           RewriterBase::Listener *listener) {
74 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75   if (failed(verify(op)))
76     llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
78 
79   MLIRContext *ctx = op->getContext();
80   PatternRewriter rewriter(ctx);
81 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82   ErasedOpsListener erasedListener(listener);
83   rewriter.setListener(&erasedListener);
84 #else
85   rewriter.setListener(listener);
86 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
87 
88   PatternApplicator applicator(patterns);
89   applicator.applyDefaultCostModel();
90 
91   ctx->executeAction<WalkAndApplyPatternsAction>(
92       [&] {
93         for (Region &region : op->getRegions()) {
94           region.walk([&](Operation *visitedOp) {
95             LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96                 llvm::dbgs(), OpPrintingFlags().skipRegions());
97                        llvm::dbgs() << "\n";);
98 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99             erasedListener.visitedOp = visitedOp;
100 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101             if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
102               LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103             }
104           });
105         }
106       },
107       {op});
108 
109 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110   if (failed(verify(op)))
111     llvm::report_fatal_error(
112         "walk pattern rewriter result IR failed to verify");
113 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
114 }
115 
116 } // namespace mlir
117