1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to 10 // run any optimization pass within it and only replaces the output module with 11 // the transformed version if it is smaller and interesting. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Pass/PassManager.h" 16 #include "mlir/Pass/PassRegistry.h" 17 #include "mlir/Reducer/Passes.h" 18 #include "mlir/Reducer/Tester.h" 19 #include "llvm/Support/Debug.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_OPTREDUCTION 23 #include "mlir/Reducer/Passes.h.inc" 24 } // namespace mlir 25 26 #define DEBUG_TYPE "mlir-reduce" 27 28 using namespace mlir; 29 30 namespace { 31 32 class OptReductionPass : public impl::OptReductionBase<OptReductionPass> { 33 public: 34 /// Runs the pass instance in the pass pipeline. 35 void runOnOperation() override; 36 }; 37 38 } // namespace 39 40 /// Runs the pass instance in the pass pipeline. runOnOperation()41void OptReductionPass::runOnOperation() { 42 LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); 43 44 Tester test(testerName, testerArgs); 45 46 ModuleOp module = this->getOperation(); 47 ModuleOp moduleVariant = module.clone(); 48 49 OpPassManager passManager("builtin.module"); 50 if (failed(parsePassPipeline(optPass, passManager))) { 51 module.emitError() << "\nfailed to parse pass pipeline"; 52 return signalPassFailure(); 53 } 54 55 std::pair<Tester::Interestingness, int> original = test.isInteresting(module); 56 if (original.first != Tester::Interestingness::True) { 57 module.emitError() << "\nthe original input is not interested"; 58 return signalPassFailure(); 59 } 60 61 // Temporarily push the variant under the main module and execute the pipeline 62 // on it. 63 module.getBody()->push_back(moduleVariant); 64 LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); 65 moduleVariant->remove(); 66 67 if (failed(pipelineResult)) { 68 module.emitError() << "\nfailed to run pass pipeline"; 69 return signalPassFailure(); 70 } 71 72 std::pair<Tester::Interestingness, int> reduced = 73 test.isInteresting(moduleVariant); 74 75 if (reduced.first == Tester::Interestingness::True && 76 reduced.second < original.second) { 77 module.getBody()->clear(); 78 module.getBody()->getOperations().splice( 79 module.getBody()->begin(), moduleVariant.getBody()->getOperations()); 80 LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n"); 81 } else { 82 LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n"); 83 } 84 85 moduleVariant->destroy(); 86 87 LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n"); 88 } 89 createOptReductionPass()90std::unique_ptr<Pass> mlir::createOptReductionPass() { 91 return std::make_unique<OptReductionPass>(); 92 } 93