1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to 10 // run any optimization pass within it and only replaces the output module with 11 // the transformed version if it is smaller and interesting. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Pass/PassManager.h" 16 #include "mlir/Pass/PassRegistry.h" 17 #include "mlir/Reducer/PassDetail.h" 18 #include "mlir/Reducer/Passes.h" 19 #include "mlir/Reducer/Tester.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "mlir-reduce" 23 24 using namespace mlir; 25 26 namespace { 27 28 class OptReductionPass : public OptReductionBase<OptReductionPass> { 29 public: 30 /// Runs the pass instance in the pass pipeline. 31 void runOnOperation() override; 32 }; 33 34 } // namespace 35 36 /// Runs the pass instance in the pass pipeline. 37 void OptReductionPass::runOnOperation() { 38 LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); 39 40 Tester test(testerName, testerArgs); 41 42 ModuleOp module = this->getOperation(); 43 ModuleOp moduleVariant = module.clone(); 44 45 OpPassManager passManager("builtin.module"); 46 if (failed(parsePassPipeline(optPass, passManager))) { 47 module.emitError() << "\nfailed to parse pass pipeline"; 48 return signalPassFailure(); 49 } 50 51 std::pair<Tester::Interestingness, int> original = test.isInteresting(module); 52 if (original.first != Tester::Interestingness::True) { 53 module.emitError() << "\nthe original input is not interested"; 54 return signalPassFailure(); 55 } 56 57 // Temporarily push the variant under the main module and execute the pipeline 58 // on it. 59 module.getBody()->push_back(moduleVariant); 60 LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); 61 moduleVariant->remove(); 62 63 if (failed(pipelineResult)) { 64 module.emitError() << "\nfailed to run pass pipeline"; 65 return signalPassFailure(); 66 } 67 68 std::pair<Tester::Interestingness, int> reduced = 69 test.isInteresting(moduleVariant); 70 71 if (reduced.first == Tester::Interestingness::True && 72 reduced.second < original.second) { 73 module.getBody()->clear(); 74 module.getBody()->getOperations().splice( 75 module.getBody()->begin(), moduleVariant.getBody()->getOperations()); 76 LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n"); 77 } else { 78 LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n"); 79 } 80 81 moduleVariant->destroy(); 82 83 LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n"); 84 } 85 86 std::unique_ptr<Pass> mlir::createOptReductionPass() { 87 return std::make_unique<OptReductionPass>(); 88 } 89