19e3b928eSChris Lattner //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===// 29e3b928eSChris Lattner // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 69e3b928eSChris Lattner // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 89e3b928eSChris Lattner // 99e3b928eSChris Lattner // This transformation pass converts operations into their canonical forms by 109e3b928eSChris Lattner // folding constants, applying operation identity transformations etc. 119e3b928eSChris Lattner // 129e3b928eSChris Lattner //===----------------------------------------------------------------------===// 139e3b928eSChris Lattner 1467d0d7acSMichele Scuttari #include "mlir/Transforms/Passes.h" 1567d0d7acSMichele Scuttari 16*35df525fSDiego Caballero #include "mlir/Dialect/UB/IR/UBOps.h" 1748ccae24SRiver Riddle #include "mlir/Pass/Pass.h" 18b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 1967d0d7acSMichele Scuttari 2067d0d7acSMichele Scuttari namespace mlir { 2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_CANONICALIZER 2267d0d7acSMichele Scuttari #include "mlir/Transforms/Passes.h.inc" 2367d0d7acSMichele Scuttari } // namespace mlir 24ad4b4acbSUday Bondhugula 259e3b928eSChris Lattner using namespace mlir; 269e3b928eSChris Lattner 279e3b928eSChris Lattner namespace { 282b61b797SRiver Riddle /// Canonicalize operations in nested regions. 2967d0d7acSMichele Scuttari struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { 30039b969bSMichele Scuttari Canonicalizer() = default; 31039b969bSMichele Scuttari Canonicalizer(const GreedyRewriteConfig &config, 3242ac4f3dSMogball ArrayRef<std::string> disabledPatterns, 33cd7af14cSMehdi Amini ArrayRef<std::string> enabledPatterns) 34cd7af14cSMehdi Amini : config(config) { 35ebad5fb3Srkayaith this->topDownProcessingEnabled = config.useTopDownTraversal; 36ebad5fb3Srkayaith this->enableRegionSimplification = config.enableRegionSimplification; 37ebad5fb3Srkayaith this->maxIterations = config.maxIterations; 38391cb541SMatthias Springer this->maxNumRewrites = config.maxNumRewrites; 3942ac4f3dSMogball this->disabledPatterns = disabledPatterns; 4042ac4f3dSMogball this->enabledPatterns = enabledPatterns; 4142ac4f3dSMogball } 422f23f9e6SChris Lattner 431ba5ea67SRiver Riddle /// Initialize the canonicalizer by building the set of patterns used during 441ba5ea67SRiver Riddle /// execution. 45b1aaed02SMehdi Amini LogicalResult initialize(MLIRContext *context) override { 46cd7af14cSMehdi Amini // Set the config from possible pass options set in the meantime. 47cd7af14cSMehdi Amini config.useTopDownTraversal = topDownProcessingEnabled; 48cd7af14cSMehdi Amini config.enableRegionSimplification = enableRegionSimplification; 49cd7af14cSMehdi Amini config.maxIterations = maxIterations; 50cd7af14cSMehdi Amini config.maxNumRewrites = maxNumRewrites; 51cd7af14cSMehdi Amini 52dc4e913bSChris Lattner RewritePatternSet owningPatterns(context); 53108ca7a7SMatthias Springer for (auto *dialect : context->getLoadedDialects()) 54108ca7a7SMatthias Springer dialect->getCanonicalizationPatterns(owningPatterns); 55edc6c0ecSRiver Riddle for (RegisteredOperationName op : context->getRegisteredOperations()) 56edc6c0ecSRiver Riddle op.getCanonicalizationPatterns(owningPatterns, context); 570289a269SRiver Riddle 58cd7af14cSMehdi Amini patterns = std::make_shared<FrozenRewritePatternSet>( 59cd7af14cSMehdi Amini std::move(owningPatterns), disabledPatterns, enabledPatterns); 60b1aaed02SMehdi Amini return success(); 619e3b928eSChris Lattner } 621ba5ea67SRiver Riddle void runOnOperation() override { 63e7790fbeSMatthias Springer LogicalResult converged = 6409dfc571SJacques Pienaar applyPatternsGreedily(getOperation(), *patterns, config); 65afc800b1SMatthias Springer // Canonicalization is best-effort. Non-convergence is not a pass failure. 66e7790fbeSMatthias Springer if (testConvergence && failed(converged)) 67e7790fbeSMatthias Springer signalPassFailure(); 681ba5ea67SRiver Riddle } 69cd7af14cSMehdi Amini GreedyRewriteConfig config; 70cd7af14cSMehdi Amini std::shared_ptr<const FrozenRewritePatternSet> patterns; 712b61b797SRiver Riddle }; 72be0a7e9fSMehdi Amini } // namespace 739e3b928eSChris Lattner 749e3b928eSChris Lattner /// Create a Canonicalizer pass. 752b61b797SRiver Riddle std::unique_ptr<Pass> mlir::createCanonicalizerPass() { 76039b969bSMichele Scuttari return std::make_unique<Canonicalizer>(); 77c6c53449SRiver Riddle } 782f23f9e6SChris Lattner 792f23f9e6SChris Lattner /// Creates an instance of the Canonicalizer pass with the specified config. 802f23f9e6SChris Lattner std::unique_ptr<Pass> 81db68e6abSMogball mlir::createCanonicalizerPass(const GreedyRewriteConfig &config, 82db68e6abSMogball ArrayRef<std::string> disabledPatterns, 83db68e6abSMogball ArrayRef<std::string> enabledPatterns) { 84039b969bSMichele Scuttari return std::make_unique<Canonicalizer>(config, disabledPatterns, 8542ac4f3dSMogball enabledPatterns); 862f23f9e6SChris Lattner } 87