1*0f8a6b7dSJakub Kuderski //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===// 2*0f8a6b7dSJakub Kuderski // 3*0f8a6b7dSJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*0f8a6b7dSJakub Kuderski // See https://llvm.org/LICENSE.txt for license information. 5*0f8a6b7dSJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*0f8a6b7dSJakub Kuderski // 7*0f8a6b7dSJakub Kuderski //===----------------------------------------------------------------------===// 8*0f8a6b7dSJakub Kuderski // 9*0f8a6b7dSJakub Kuderski // Implements mlir::walkAndApplyPatterns. 10*0f8a6b7dSJakub Kuderski // 11*0f8a6b7dSJakub Kuderski //===----------------------------------------------------------------------===// 12*0f8a6b7dSJakub Kuderski 13*0f8a6b7dSJakub Kuderski #include "mlir/Transforms/WalkPatternRewriteDriver.h" 14*0f8a6b7dSJakub Kuderski 15*0f8a6b7dSJakub Kuderski #include "mlir/IR/MLIRContext.h" 16*0f8a6b7dSJakub Kuderski #include "mlir/IR/OperationSupport.h" 17*0f8a6b7dSJakub Kuderski #include "mlir/IR/PatternMatch.h" 18*0f8a6b7dSJakub Kuderski #include "mlir/IR/Verifier.h" 19*0f8a6b7dSJakub Kuderski #include "mlir/IR/Visitors.h" 20*0f8a6b7dSJakub Kuderski #include "mlir/Rewrite/PatternApplicator.h" 21*0f8a6b7dSJakub Kuderski #include "llvm/Support/Debug.h" 22*0f8a6b7dSJakub Kuderski #include "llvm/Support/ErrorHandling.h" 23*0f8a6b7dSJakub Kuderski 24*0f8a6b7dSJakub Kuderski #define DEBUG_TYPE "walk-rewriter" 25*0f8a6b7dSJakub Kuderski 26*0f8a6b7dSJakub Kuderski namespace mlir { 27*0f8a6b7dSJakub Kuderski 28*0f8a6b7dSJakub Kuderski namespace { 29*0f8a6b7dSJakub Kuderski struct WalkAndApplyPatternsAction final 30*0f8a6b7dSJakub Kuderski : tracing::ActionImpl<WalkAndApplyPatternsAction> { 31*0f8a6b7dSJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction) 32*0f8a6b7dSJakub Kuderski using ActionImpl::ActionImpl; 33*0f8a6b7dSJakub Kuderski static constexpr StringLiteral tag = "walk-and-apply-patterns"; 34*0f8a6b7dSJakub Kuderski void print(raw_ostream &os) const override { os << tag; } 35*0f8a6b7dSJakub Kuderski }; 36*0f8a6b7dSJakub Kuderski 37*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 38*0f8a6b7dSJakub Kuderski // Forwarding listener to guard against unsupported erasures of non-descendant 39*0f8a6b7dSJakub Kuderski // ops/blocks. Because we use walk-based pattern application, erasing the 40*0f8a6b7dSJakub Kuderski // op/block from the *next* iteration (e.g., a user of the visited op) is not 41*0f8a6b7dSJakub Kuderski // valid. Note that this is only used with expensive pattern API checks. 42*0f8a6b7dSJakub Kuderski struct ErasedOpsListener final : RewriterBase::ForwardingListener { 43*0f8a6b7dSJakub Kuderski using RewriterBase::ForwardingListener::ForwardingListener; 44*0f8a6b7dSJakub Kuderski 45*0f8a6b7dSJakub Kuderski void notifyOperationErased(Operation *op) override { 46*0f8a6b7dSJakub Kuderski checkErasure(op); 47*0f8a6b7dSJakub Kuderski ForwardingListener::notifyOperationErased(op); 48*0f8a6b7dSJakub Kuderski } 49*0f8a6b7dSJakub Kuderski 50*0f8a6b7dSJakub Kuderski void notifyBlockErased(Block *block) override { 51*0f8a6b7dSJakub Kuderski checkErasure(block->getParentOp()); 52*0f8a6b7dSJakub Kuderski ForwardingListener::notifyBlockErased(block); 53*0f8a6b7dSJakub Kuderski } 54*0f8a6b7dSJakub Kuderski 55*0f8a6b7dSJakub Kuderski void checkErasure(Operation *op) const { 56*0f8a6b7dSJakub Kuderski Operation *ancestorOp = op; 57*0f8a6b7dSJakub Kuderski while (ancestorOp && ancestorOp != visitedOp) 58*0f8a6b7dSJakub Kuderski ancestorOp = ancestorOp->getParentOp(); 59*0f8a6b7dSJakub Kuderski 60*0f8a6b7dSJakub Kuderski if (ancestorOp != visitedOp) 61*0f8a6b7dSJakub Kuderski llvm::report_fatal_error( 62*0f8a6b7dSJakub Kuderski "unsupported erasure in WalkPatternRewriter; " 63*0f8a6b7dSJakub Kuderski "erasure is only supported for matched ops and their descendants"); 64*0f8a6b7dSJakub Kuderski } 65*0f8a6b7dSJakub Kuderski 66*0f8a6b7dSJakub Kuderski Operation *visitedOp = nullptr; 67*0f8a6b7dSJakub Kuderski }; 68*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 69*0f8a6b7dSJakub Kuderski } // namespace 70*0f8a6b7dSJakub Kuderski 71*0f8a6b7dSJakub Kuderski void walkAndApplyPatterns(Operation *op, 72*0f8a6b7dSJakub Kuderski const FrozenRewritePatternSet &patterns, 73*0f8a6b7dSJakub Kuderski RewriterBase::Listener *listener) { 74*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 75*0f8a6b7dSJakub Kuderski if (failed(verify(op))) 76*0f8a6b7dSJakub Kuderski llvm::report_fatal_error("walk pattern rewriter input IR failed to verify"); 77*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 78*0f8a6b7dSJakub Kuderski 79*0f8a6b7dSJakub Kuderski MLIRContext *ctx = op->getContext(); 80*0f8a6b7dSJakub Kuderski PatternRewriter rewriter(ctx); 81*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 82*0f8a6b7dSJakub Kuderski ErasedOpsListener erasedListener(listener); 83*0f8a6b7dSJakub Kuderski rewriter.setListener(&erasedListener); 84*0f8a6b7dSJakub Kuderski #else 85*0f8a6b7dSJakub Kuderski rewriter.setListener(listener); 86*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 87*0f8a6b7dSJakub Kuderski 88*0f8a6b7dSJakub Kuderski PatternApplicator applicator(patterns); 89*0f8a6b7dSJakub Kuderski applicator.applyDefaultCostModel(); 90*0f8a6b7dSJakub Kuderski 91*0f8a6b7dSJakub Kuderski ctx->executeAction<WalkAndApplyPatternsAction>( 92*0f8a6b7dSJakub Kuderski [&] { 93*0f8a6b7dSJakub Kuderski for (Region ®ion : op->getRegions()) { 94*0f8a6b7dSJakub Kuderski region.walk([&](Operation *visitedOp) { 95*0f8a6b7dSJakub Kuderski LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( 96*0f8a6b7dSJakub Kuderski llvm::dbgs(), OpPrintingFlags().skipRegions()); 97*0f8a6b7dSJakub Kuderski llvm::dbgs() << "\n";); 98*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 99*0f8a6b7dSJakub Kuderski erasedListener.visitedOp = visitedOp; 100*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 101*0f8a6b7dSJakub Kuderski if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { 102*0f8a6b7dSJakub Kuderski LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); 103*0f8a6b7dSJakub Kuderski } 104*0f8a6b7dSJakub Kuderski }); 105*0f8a6b7dSJakub Kuderski } 106*0f8a6b7dSJakub Kuderski }, 107*0f8a6b7dSJakub Kuderski {op}); 108*0f8a6b7dSJakub Kuderski 109*0f8a6b7dSJakub Kuderski #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 110*0f8a6b7dSJakub Kuderski if (failed(verify(op))) 111*0f8a6b7dSJakub Kuderski llvm::report_fatal_error( 112*0f8a6b7dSJakub Kuderski "walk pattern rewriter result IR failed to verify"); 113*0f8a6b7dSJakub Kuderski #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 114*0f8a6b7dSJakub Kuderski } 115*0f8a6b7dSJakub Kuderski 116*0f8a6b7dSJakub Kuderski } // namespace mlir 117