xref: /llvm-project/mlir/lib/Transforms/Canonicalizer.cpp (revision 35df525fd00c2037ef144189ee818b7d612241ff)
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass converts operations into their canonical forms by
// folding constants, applying operation identity transformations, and similar
// rewrites.
//
//===----------------------------------------------------------------------===//
139e3b928eSChris Lattner 
1467d0d7acSMichele Scuttari #include "mlir/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
16*35df525fSDiego Caballero #include "mlir/Dialect/UB/IR/UBOps.h"
1748ccae24SRiver Riddle #include "mlir/Pass/Pass.h"
18b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1967d0d7acSMichele Scuttari 
2067d0d7acSMichele Scuttari namespace mlir {
2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_CANONICALIZER
2267d0d7acSMichele Scuttari #include "mlir/Transforms/Passes.h.inc"
2367d0d7acSMichele Scuttari } // namespace mlir
24ad4b4acbSUday Bondhugula 
259e3b928eSChris Lattner using namespace mlir;
269e3b928eSChris Lattner 
279e3b928eSChris Lattner namespace {
282b61b797SRiver Riddle /// Canonicalize operations in nested regions.
2967d0d7acSMichele Scuttari struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
30039b969bSMichele Scuttari   Canonicalizer() = default;
31039b969bSMichele Scuttari   Canonicalizer(const GreedyRewriteConfig &config,
3242ac4f3dSMogball                 ArrayRef<std::string> disabledPatterns,
33cd7af14cSMehdi Amini                 ArrayRef<std::string> enabledPatterns)
34cd7af14cSMehdi Amini       : config(config) {
35ebad5fb3Srkayaith     this->topDownProcessingEnabled = config.useTopDownTraversal;
36ebad5fb3Srkayaith     this->enableRegionSimplification = config.enableRegionSimplification;
37ebad5fb3Srkayaith     this->maxIterations = config.maxIterations;
38391cb541SMatthias Springer     this->maxNumRewrites = config.maxNumRewrites;
3942ac4f3dSMogball     this->disabledPatterns = disabledPatterns;
4042ac4f3dSMogball     this->enabledPatterns = enabledPatterns;
4142ac4f3dSMogball   }
422f23f9e6SChris Lattner 
431ba5ea67SRiver Riddle   /// Initialize the canonicalizer by building the set of patterns used during
441ba5ea67SRiver Riddle   /// execution.
45b1aaed02SMehdi Amini   LogicalResult initialize(MLIRContext *context) override {
46cd7af14cSMehdi Amini     // Set the config from possible pass options set in the meantime.
47cd7af14cSMehdi Amini     config.useTopDownTraversal = topDownProcessingEnabled;
48cd7af14cSMehdi Amini     config.enableRegionSimplification = enableRegionSimplification;
49cd7af14cSMehdi Amini     config.maxIterations = maxIterations;
50cd7af14cSMehdi Amini     config.maxNumRewrites = maxNumRewrites;
51cd7af14cSMehdi Amini 
52dc4e913bSChris Lattner     RewritePatternSet owningPatterns(context);
53108ca7a7SMatthias Springer     for (auto *dialect : context->getLoadedDialects())
54108ca7a7SMatthias Springer       dialect->getCanonicalizationPatterns(owningPatterns);
55edc6c0ecSRiver Riddle     for (RegisteredOperationName op : context->getRegisteredOperations())
56edc6c0ecSRiver Riddle       op.getCanonicalizationPatterns(owningPatterns, context);
570289a269SRiver Riddle 
58cd7af14cSMehdi Amini     patterns = std::make_shared<FrozenRewritePatternSet>(
59cd7af14cSMehdi Amini         std::move(owningPatterns), disabledPatterns, enabledPatterns);
60b1aaed02SMehdi Amini     return success();
619e3b928eSChris Lattner   }
621ba5ea67SRiver Riddle   void runOnOperation() override {
63e7790fbeSMatthias Springer     LogicalResult converged =
6409dfc571SJacques Pienaar         applyPatternsGreedily(getOperation(), *patterns, config);
65afc800b1SMatthias Springer     // Canonicalization is best-effort. Non-convergence is not a pass failure.
66e7790fbeSMatthias Springer     if (testConvergence && failed(converged))
67e7790fbeSMatthias Springer       signalPassFailure();
681ba5ea67SRiver Riddle   }
69cd7af14cSMehdi Amini   GreedyRewriteConfig config;
70cd7af14cSMehdi Amini   std::shared_ptr<const FrozenRewritePatternSet> patterns;
712b61b797SRiver Riddle };
72be0a7e9fSMehdi Amini } // namespace
739e3b928eSChris Lattner 
749e3b928eSChris Lattner /// Create a Canonicalizer pass.
752b61b797SRiver Riddle std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
76039b969bSMichele Scuttari   return std::make_unique<Canonicalizer>();
77c6c53449SRiver Riddle }
782f23f9e6SChris Lattner 
792f23f9e6SChris Lattner /// Creates an instance of the Canonicalizer pass with the specified config.
802f23f9e6SChris Lattner std::unique_ptr<Pass>
81db68e6abSMogball mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
82db68e6abSMogball                               ArrayRef<std::string> disabledPatterns,
83db68e6abSMogball                               ArrayRef<std::string> enabledPatterns) {
84039b969bSMichele Scuttari   return std::make_unique<Canonicalizer>(config, disabledPatterns,
8542ac4f3dSMogball                                          enabledPatterns);
862f23f9e6SChris Lattner }
87