xref: /llvm-project/mlir/lib/Transforms/Canonicalizer.cpp (revision 35df525fd00c2037ef144189ee818b7d612241ff)
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/UB/IR/UBOps.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CANONICALIZER
22 #include "mlir/Transforms/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 
27 namespace {
28 /// Canonicalize operations in nested regions.
29 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
30   Canonicalizer() = default;
31   Canonicalizer(const GreedyRewriteConfig &config,
32                 ArrayRef<std::string> disabledPatterns,
33                 ArrayRef<std::string> enabledPatterns)
34       : config(config) {
35     this->topDownProcessingEnabled = config.useTopDownTraversal;
36     this->enableRegionSimplification = config.enableRegionSimplification;
37     this->maxIterations = config.maxIterations;
38     this->maxNumRewrites = config.maxNumRewrites;
39     this->disabledPatterns = disabledPatterns;
40     this->enabledPatterns = enabledPatterns;
41   }
42 
43   /// Initialize the canonicalizer by building the set of patterns used during
44   /// execution.
45   LogicalResult initialize(MLIRContext *context) override {
46     // Set the config from possible pass options set in the meantime.
47     config.useTopDownTraversal = topDownProcessingEnabled;
48     config.enableRegionSimplification = enableRegionSimplification;
49     config.maxIterations = maxIterations;
50     config.maxNumRewrites = maxNumRewrites;
51 
52     RewritePatternSet owningPatterns(context);
53     for (auto *dialect : context->getLoadedDialects())
54       dialect->getCanonicalizationPatterns(owningPatterns);
55     for (RegisteredOperationName op : context->getRegisteredOperations())
56       op.getCanonicalizationPatterns(owningPatterns, context);
57 
58     patterns = std::make_shared<FrozenRewritePatternSet>(
59         std::move(owningPatterns), disabledPatterns, enabledPatterns);
60     return success();
61   }
62   void runOnOperation() override {
63     LogicalResult converged =
64         applyPatternsGreedily(getOperation(), *patterns, config);
65     // Canonicalization is best-effort. Non-convergence is not a pass failure.
66     if (testConvergence && failed(converged))
67       signalPassFailure();
68   }
69   GreedyRewriteConfig config;
70   std::shared_ptr<const FrozenRewritePatternSet> patterns;
71 };
72 } // namespace
73 
74 /// Create a Canonicalizer pass.
75 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
76   return std::make_unique<Canonicalizer>();
77 }
78 
79 /// Creates an instance of the Canonicalizer pass with the specified config.
80 std::unique_ptr<Pass>
81 mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
82                               ArrayRef<std::string> disabledPatterns,
83                               ArrayRef<std::string> enabledPatterns) {
84   return std::make_unique<Canonicalizer>(config, disabledPatterns,
85                                          enabledPatterns);
86 }
87