xref: /llvm-project/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp (revision 0f8a6b7d03550cb58cf49535af2de2230abfe997)
1*0f8a6b7dSJakub Kuderski //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2*0f8a6b7dSJakub Kuderski //
3*0f8a6b7dSJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0f8a6b7dSJakub Kuderski // See https://llvm.org/LICENSE.txt for license information.
5*0f8a6b7dSJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0f8a6b7dSJakub Kuderski //
7*0f8a6b7dSJakub Kuderski //===----------------------------------------------------------------------===//
8*0f8a6b7dSJakub Kuderski //
9*0f8a6b7dSJakub Kuderski // Implements mlir::walkAndApplyPatterns.
10*0f8a6b7dSJakub Kuderski //
11*0f8a6b7dSJakub Kuderski //===----------------------------------------------------------------------===//
12*0f8a6b7dSJakub Kuderski 
13*0f8a6b7dSJakub Kuderski #include "mlir/Transforms/WalkPatternRewriteDriver.h"
14*0f8a6b7dSJakub Kuderski 
15*0f8a6b7dSJakub Kuderski #include "mlir/IR/MLIRContext.h"
16*0f8a6b7dSJakub Kuderski #include "mlir/IR/OperationSupport.h"
17*0f8a6b7dSJakub Kuderski #include "mlir/IR/PatternMatch.h"
18*0f8a6b7dSJakub Kuderski #include "mlir/IR/Verifier.h"
19*0f8a6b7dSJakub Kuderski #include "mlir/IR/Visitors.h"
20*0f8a6b7dSJakub Kuderski #include "mlir/Rewrite/PatternApplicator.h"
21*0f8a6b7dSJakub Kuderski #include "llvm/Support/Debug.h"
22*0f8a6b7dSJakub Kuderski #include "llvm/Support/ErrorHandling.h"
23*0f8a6b7dSJakub Kuderski 
24*0f8a6b7dSJakub Kuderski #define DEBUG_TYPE "walk-rewriter"
25*0f8a6b7dSJakub Kuderski 
26*0f8a6b7dSJakub Kuderski namespace mlir {
27*0f8a6b7dSJakub Kuderski 
28*0f8a6b7dSJakub Kuderski namespace {
29*0f8a6b7dSJakub Kuderski struct WalkAndApplyPatternsAction final
30*0f8a6b7dSJakub Kuderski     : tracing::ActionImpl<WalkAndApplyPatternsAction> {
31*0f8a6b7dSJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
32*0f8a6b7dSJakub Kuderski   using ActionImpl::ActionImpl;
33*0f8a6b7dSJakub Kuderski   static constexpr StringLiteral tag = "walk-and-apply-patterns";
34*0f8a6b7dSJakub Kuderski   void print(raw_ostream &os) const override { os << tag; }
35*0f8a6b7dSJakub Kuderski };
36*0f8a6b7dSJakub Kuderski 
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Listener that is only compiled in when expensive pattern API checks are
/// enabled. It wraps the caller's listener and guards against erasures that
/// a walk-based driver cannot tolerate.
///
/// Because patterns are applied while walking the IR, only the op currently
/// being visited and the ops/blocks nested inside it may be erased. Erasing
/// anything else (e.g. a user of the visited op that the walk has not reached
/// yet) is reported as a fatal error. Erasure notifications are checked first
/// and then forwarded to the wrapped listener.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  void notifyBlockErased(Block *block) override {
    // A block is judged by the op that owns it. If that is null, the erasure
    // is rejected unless `visitedOp` is null as well.
    checkErasure(block->getParentOp());
    ForwardingListener::notifyBlockErased(block);
  }

  /// Reports a fatal error unless `visitedOp` is `op` itself or one of its
  /// ancestors (walking up via `getParentOp`).
  void checkErasure(Operation *op) const {
    for (Operation *current = op;; current = current->getParentOp()) {
      if (current == visitedOp)
        return;
      if (!current)
        break;
    }
    llvm::report_fatal_error(
        "unsupported erasure in WalkPatternRewriter; "
        "erasure is only supported for matched ops and their descendants");
  }

  /// The op the walk is currently visiting; the driver assigns it before each
  /// match attempt.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69*0f8a6b7dSJakub Kuderski } // namespace
70*0f8a6b7dSJakub Kuderski 
71*0f8a6b7dSJakub Kuderski void walkAndApplyPatterns(Operation *op,
72*0f8a6b7dSJakub Kuderski                           const FrozenRewritePatternSet &patterns,
73*0f8a6b7dSJakub Kuderski                           RewriterBase::Listener *listener) {
74*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75*0f8a6b7dSJakub Kuderski   if (failed(verify(op)))
76*0f8a6b7dSJakub Kuderski     llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
78*0f8a6b7dSJakub Kuderski 
79*0f8a6b7dSJakub Kuderski   MLIRContext *ctx = op->getContext();
80*0f8a6b7dSJakub Kuderski   PatternRewriter rewriter(ctx);
81*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82*0f8a6b7dSJakub Kuderski   ErasedOpsListener erasedListener(listener);
83*0f8a6b7dSJakub Kuderski   rewriter.setListener(&erasedListener);
84*0f8a6b7dSJakub Kuderski #else
85*0f8a6b7dSJakub Kuderski   rewriter.setListener(listener);
86*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
87*0f8a6b7dSJakub Kuderski 
88*0f8a6b7dSJakub Kuderski   PatternApplicator applicator(patterns);
89*0f8a6b7dSJakub Kuderski   applicator.applyDefaultCostModel();
90*0f8a6b7dSJakub Kuderski 
91*0f8a6b7dSJakub Kuderski   ctx->executeAction<WalkAndApplyPatternsAction>(
92*0f8a6b7dSJakub Kuderski       [&] {
93*0f8a6b7dSJakub Kuderski         for (Region &region : op->getRegions()) {
94*0f8a6b7dSJakub Kuderski           region.walk([&](Operation *visitedOp) {
95*0f8a6b7dSJakub Kuderski             LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96*0f8a6b7dSJakub Kuderski                 llvm::dbgs(), OpPrintingFlags().skipRegions());
97*0f8a6b7dSJakub Kuderski                        llvm::dbgs() << "\n";);
98*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99*0f8a6b7dSJakub Kuderski             erasedListener.visitedOp = visitedOp;
100*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101*0f8a6b7dSJakub Kuderski             if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
102*0f8a6b7dSJakub Kuderski               LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103*0f8a6b7dSJakub Kuderski             }
104*0f8a6b7dSJakub Kuderski           });
105*0f8a6b7dSJakub Kuderski         }
106*0f8a6b7dSJakub Kuderski       },
107*0f8a6b7dSJakub Kuderski       {op});
108*0f8a6b7dSJakub Kuderski 
109*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110*0f8a6b7dSJakub Kuderski   if (failed(verify(op)))
111*0f8a6b7dSJakub Kuderski     llvm::report_fatal_error(
112*0f8a6b7dSJakub Kuderski         "walk pattern rewriter result IR failed to verify");
113*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
114*0f8a6b7dSJakub Kuderski }
115*0f8a6b7dSJakub Kuderski 
116*0f8a6b7dSJakub Kuderski } // namespace mlir
117