1 //===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements mlir::walkAndApplyPatterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/WalkPatternRewriteDriver.h" 14 15 #include "mlir/IR/MLIRContext.h" 16 #include "mlir/IR/OperationSupport.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/IR/Verifier.h" 19 #include "mlir/IR/Visitors.h" 20 #include "mlir/Rewrite/PatternApplicator.h" 21 #include "llvm/Support/Debug.h" 22 #include "llvm/Support/ErrorHandling.h" 23 24 #define DEBUG_TYPE "walk-rewriter" 25 26 namespace mlir { 27 28 namespace { 29 struct WalkAndApplyPatternsAction final 30 : tracing::ActionImpl<WalkAndApplyPatternsAction> { 31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction) 32 using ActionImpl::ActionImpl; 33 static constexpr StringLiteral tag = "walk-and-apply-patterns"; 34 void print(raw_ostream &os) const override { os << tag; } 35 }; 36 37 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 38 // Forwarding listener to guard against unsupported erasures of non-descendant 39 // ops/blocks. Because we use walk-based pattern application, erasing the 40 // op/block from the *next* iteration (e.g., a user of the visited op) is not 41 // valid. Note that this is only used with expensive pattern API checks. 42 struct ErasedOpsListener final : RewriterBase::ForwardingListener { 43 using RewriterBase::ForwardingListener::ForwardingListener; 44 45 void notifyOperationErased(Operation *op) override { 46 checkErasure(op); 47 ForwardingListener::notifyOperationErased(op); 48 } 49 50 void notifyBlockErased(Block *block) override { 51 checkErasure(block->getParentOp()); 52 ForwardingListener::notifyBlockErased(block); 53 } 54 55 void checkErasure(Operation *op) const { 56 Operation *ancestorOp = op; 57 while (ancestorOp && ancestorOp != visitedOp) 58 ancestorOp = ancestorOp->getParentOp(); 59 60 if (ancestorOp != visitedOp) 61 llvm::report_fatal_error( 62 "unsupported erasure in WalkPatternRewriter; " 63 "erasure is only supported for matched ops and their descendants"); 64 } 65 66 Operation *visitedOp = nullptr; 67 }; 68 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 69 } // namespace 70 71 void walkAndApplyPatterns(Operation *op, 72 const FrozenRewritePatternSet &patterns, 73 RewriterBase::Listener *listener) { 74 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 75 if (failed(verify(op))) 76 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify"); 77 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 78 79 MLIRContext *ctx = op->getContext(); 80 PatternRewriter rewriter(ctx); 81 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 82 ErasedOpsListener erasedListener(listener); 83 rewriter.setListener(&erasedListener); 84 #else 85 rewriter.setListener(listener); 86 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 87 88 PatternApplicator applicator(patterns); 89 applicator.applyDefaultCostModel(); 90 91 ctx->executeAction<WalkAndApplyPatternsAction>( 92 [&] { 93 for (Region ®ion : op->getRegions()) { 94 region.walk([&](Operation *visitedOp) { 95 LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( 96 llvm::dbgs(), OpPrintingFlags().skipRegions()); 97 llvm::dbgs() << "\n";); 98 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 99 erasedListener.visitedOp = visitedOp; 100 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 101 if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { 102 LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); 103 } 104 }); 105 } 106 }, 107 {op}); 108 109 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 110 if (failed(verify(op))) 111 llvm::report_fatal_error( 112 "walk pattern rewriter result IR failed to verify"); 113 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 114 } 115 116 } // namespace mlir 117