1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass converts operations into their canonical forms by 10 // folding constants, applying operation identity transformations etc. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/Passes.h" 15 16 #include "mlir/Dialect/UB/IR/UBOps.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 namespace mlir { 21 #define GEN_PASS_DEF_CANONICALIZER 22 #include "mlir/Transforms/Passes.h.inc" 23 } // namespace mlir 24 25 using namespace mlir; 26 27 namespace { 28 /// Canonicalize operations in nested regions. 29 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { 30 Canonicalizer() = default; 31 Canonicalizer(const GreedyRewriteConfig &config, 32 ArrayRef<std::string> disabledPatterns, 33 ArrayRef<std::string> enabledPatterns) 34 : config(config) { 35 this->topDownProcessingEnabled = config.useTopDownTraversal; 36 this->enableRegionSimplification = config.enableRegionSimplification; 37 this->maxIterations = config.maxIterations; 38 this->maxNumRewrites = config.maxNumRewrites; 39 this->disabledPatterns = disabledPatterns; 40 this->enabledPatterns = enabledPatterns; 41 } 42 43 /// Initialize the canonicalizer by building the set of patterns used during 44 /// execution. 45 LogicalResult initialize(MLIRContext *context) override { 46 // Set the config from possible pass options set in the meantime. 47 config.useTopDownTraversal = topDownProcessingEnabled; 48 config.enableRegionSimplification = enableRegionSimplification; 49 config.maxIterations = maxIterations; 50 config.maxNumRewrites = maxNumRewrites; 51 52 RewritePatternSet owningPatterns(context); 53 for (auto *dialect : context->getLoadedDialects()) 54 dialect->getCanonicalizationPatterns(owningPatterns); 55 for (RegisteredOperationName op : context->getRegisteredOperations()) 56 op.getCanonicalizationPatterns(owningPatterns, context); 57 58 patterns = std::make_shared<FrozenRewritePatternSet>( 59 std::move(owningPatterns), disabledPatterns, enabledPatterns); 60 return success(); 61 } 62 void runOnOperation() override { 63 LogicalResult converged = 64 applyPatternsGreedily(getOperation(), *patterns, config); 65 // Canonicalization is best-effort. Non-convergence is not a pass failure. 66 if (testConvergence && failed(converged)) 67 signalPassFailure(); 68 } 69 GreedyRewriteConfig config; 70 std::shared_ptr<const FrozenRewritePatternSet> patterns; 71 }; 72 } // namespace 73 74 /// Create a Canonicalizer pass. 75 std::unique_ptr<Pass> mlir::createCanonicalizerPass() { 76 return std::make_unique<Canonicalizer>(); 77 } 78 79 /// Creates an instance of the Canonicalizer pass with the specified config. 80 std::unique_ptr<Pass> 81 mlir::createCanonicalizerPass(const GreedyRewriteConfig &config, 82 ArrayRef<std::string> disabledPatterns, 83 ArrayRef<std::string> enabledPatterns) { 84 return std::make_unique<Canonicalizer>(config, disabledPatterns, 85 enabledPatterns); 86 } 87