1c484c7ddSChia-hung Duan //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// 2c484c7ddSChia-hung Duan // 3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information. 5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c484c7ddSChia-hung Duan // 7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 8c484c7ddSChia-hung Duan // 9c484c7ddSChia-hung Duan // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to 10c484c7ddSChia-hung Duan // run any optimization pass within it and only replaces the output module with 11c484c7ddSChia-hung Duan // the transformed version if it is smaller and interesting. 12c484c7ddSChia-hung Duan // 13c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 14c484c7ddSChia-hung Duan 15c484c7ddSChia-hung Duan #include "mlir/Pass/PassManager.h" 16c484c7ddSChia-hung Duan #include "mlir/Pass/PassRegistry.h" 17c484c7ddSChia-hung Duan #include "mlir/Reducer/Passes.h" 18c484c7ddSChia-hung Duan #include "mlir/Reducer/Tester.h" 19c484c7ddSChia-hung Duan #include "llvm/Support/Debug.h" 20c484c7ddSChia-hung Duan 21*67d0d7acSMichele Scuttari namespace mlir { 22*67d0d7acSMichele Scuttari #define GEN_PASS_DEF_OPTREDUCTION 23*67d0d7acSMichele Scuttari #include "mlir/Reducer/Passes.h.inc" 24*67d0d7acSMichele Scuttari } // namespace mlir 25*67d0d7acSMichele Scuttari 26c484c7ddSChia-hung Duan #define DEBUG_TYPE "mlir-reduce" 27c484c7ddSChia-hung Duan 28c484c7ddSChia-hung Duan using namespace mlir; 29c484c7ddSChia-hung Duan 30c484c7ddSChia-hung Duan namespace { 31c484c7ddSChia-hung Duan 32*67d0d7acSMichele Scuttari class OptReductionPass : public impl::OptReductionBase<OptReductionPass> { 33c484c7ddSChia-hung Duan public: 34c484c7ddSChia-hung Duan /// Runs the pass instance in the pass pipeline. 35c484c7ddSChia-hung Duan void runOnOperation() override; 36c484c7ddSChia-hung Duan }; 37c484c7ddSChia-hung Duan 38be0a7e9fSMehdi Amini } // namespace 39c484c7ddSChia-hung Duan 40c484c7ddSChia-hung Duan /// Runs the pass instance in the pass pipeline. runOnOperation()41c484c7ddSChia-hung Duanvoid OptReductionPass::runOnOperation() { 42c484c7ddSChia-hung Duan LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); 43c484c7ddSChia-hung Duan 44c484c7ddSChia-hung Duan Tester test(testerName, testerArgs); 45c484c7ddSChia-hung Duan 46c484c7ddSChia-hung Duan ModuleOp module = this->getOperation(); 47c484c7ddSChia-hung Duan ModuleOp moduleVariant = module.clone(); 48c484c7ddSChia-hung Duan 490f304ef0SRiver Riddle OpPassManager passManager("builtin.module"); 50c484c7ddSChia-hung Duan if (failed(parsePassPipeline(optPass, passManager))) { 511a001dedSChia-hung Duan module.emitError() << "\nfailed to parse pass pipeline"; 521a001dedSChia-hung Duan return signalPassFailure(); 53c484c7ddSChia-hung Duan } 54c484c7ddSChia-hung Duan 55c484c7ddSChia-hung Duan std::pair<Tester::Interestingness, int> original = test.isInteresting(module); 56c484c7ddSChia-hung Duan if (original.first != Tester::Interestingness::True) { 571a001dedSChia-hung Duan module.emitError() << "\nthe original input is not interested"; 581a001dedSChia-hung Duan return signalPassFailure(); 59c484c7ddSChia-hung Duan } 60c484c7ddSChia-hung Duan 610f304ef0SRiver Riddle // Temporarily push the variant under the main module and execute the pipeline 620f304ef0SRiver Riddle // on it. 630f304ef0SRiver Riddle module.getBody()->push_back(moduleVariant); 640f304ef0SRiver Riddle LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); 650f304ef0SRiver Riddle moduleVariant->remove(); 660f304ef0SRiver Riddle 670f304ef0SRiver Riddle if (failed(pipelineResult)) { 681a001dedSChia-hung Duan module.emitError() << "\nfailed to run pass pipeline"; 691a001dedSChia-hung Duan return signalPassFailure(); 70c484c7ddSChia-hung Duan } 71c484c7ddSChia-hung Duan 72c484c7ddSChia-hung Duan std::pair<Tester::Interestingness, int> reduced = 73c484c7ddSChia-hung Duan test.isInteresting(moduleVariant); 74c484c7ddSChia-hung Duan 75c484c7ddSChia-hung Duan if (reduced.first == Tester::Interestingness::True && 76c484c7ddSChia-hung Duan reduced.second < original.second) { 77c484c7ddSChia-hung Duan module.getBody()->clear(); 78c484c7ddSChia-hung Duan module.getBody()->getOperations().splice( 79c484c7ddSChia-hung Duan module.getBody()->begin(), moduleVariant.getBody()->getOperations()); 80c484c7ddSChia-hung Duan LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n"); 81c484c7ddSChia-hung Duan } else { 82c484c7ddSChia-hung Duan LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n"); 83c484c7ddSChia-hung Duan } 84c484c7ddSChia-hung Duan 85c484c7ddSChia-hung Duan moduleVariant->destroy(); 86c484c7ddSChia-hung Duan 87c484c7ddSChia-hung Duan LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n"); 88c484c7ddSChia-hung Duan } 89c484c7ddSChia-hung Duan createOptReductionPass()90c484c7ddSChia-hung Duanstd::unique_ptr<Pass> mlir::createOptReductionPass() { 91c484c7ddSChia-hung Duan return std::make_unique<OptReductionPass>(); 92c484c7ddSChia-hung Duan } 93