xref: /llvm-project/mlir/lib/Reducer/OptReductionPass.cpp (revision 67d0d7ac0acb0665d6a09f61278fbcf51f0114c2)
1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the Opt Reduction Pass Wrapper. It creates an MLIR pass
// that runs an arbitrary optimization pass pipeline on a copy of the module and
// only replaces the output module with the transformed version if it is smaller
// and still interesting.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Pass/PassRegistry.h"
17 #include "mlir/Reducer/Passes.h"
18 #include "mlir/Reducer/Tester.h"
19 #include "llvm/Support/Debug.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_OPTREDUCTION
23 #include "mlir/Reducer/Passes.h.inc"
24 } // namespace mlir
25 
26 #define DEBUG_TYPE "mlir-reduce"
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 class OptReductionPass : public impl::OptReductionBase<OptReductionPass> {
33 public:
34   /// Runs the pass instance in the pass pipeline.
35   void runOnOperation() override;
36 };
37 
38 } // namespace
39 
40 /// Runs the pass instance in the pass pipeline.
runOnOperation()41 void OptReductionPass::runOnOperation() {
42   LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
43 
44   Tester test(testerName, testerArgs);
45 
46   ModuleOp module = this->getOperation();
47   ModuleOp moduleVariant = module.clone();
48 
49   OpPassManager passManager("builtin.module");
50   if (failed(parsePassPipeline(optPass, passManager))) {
51     module.emitError() << "\nfailed to parse pass pipeline";
52     return signalPassFailure();
53   }
54 
55   std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
56   if (original.first != Tester::Interestingness::True) {
57     module.emitError() << "\nthe original input is not interested";
58     return signalPassFailure();
59   }
60 
61   // Temporarily push the variant under the main module and execute the pipeline
62   // on it.
63   module.getBody()->push_back(moduleVariant);
64   LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
65   moduleVariant->remove();
66 
67   if (failed(pipelineResult)) {
68     module.emitError() << "\nfailed to run pass pipeline";
69     return signalPassFailure();
70   }
71 
72   std::pair<Tester::Interestingness, int> reduced =
73       test.isInteresting(moduleVariant);
74 
75   if (reduced.first == Tester::Interestingness::True &&
76       reduced.second < original.second) {
77     module.getBody()->clear();
78     module.getBody()->getOperations().splice(
79         module.getBody()->begin(), moduleVariant.getBody()->getOperations());
80     LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
81   } else {
82     LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
83   }
84 
85   moduleVariant->destroy();
86 
87   LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
88 }
89 
createOptReductionPass()90 std::unique_ptr<Pass> mlir::createOptReductionPass() {
91   return std::make_unique<OptReductionPass>();
92 }
93