xref: /llvm-project/mlir/lib/Reducer/OptReductionPass.cpp (revision 67d0d7ac0acb0665d6a09f61278fbcf51f0114c2)
1c484c7ddSChia-hung Duan //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2c484c7ddSChia-hung Duan //
3c484c7ddSChia-hung Duan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c484c7ddSChia-hung Duan // See https://llvm.org/LICENSE.txt for license information.
5c484c7ddSChia-hung Duan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c484c7ddSChia-hung Duan //
7c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
8c484c7ddSChia-hung Duan //
9c484c7ddSChia-hung Duan // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
10c484c7ddSChia-hung Duan // run any optimization pass within it and only replaces the output module with
11c484c7ddSChia-hung Duan // the transformed version if it is smaller and interesting.
12c484c7ddSChia-hung Duan //
13c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===//
14c484c7ddSChia-hung Duan 
15c484c7ddSChia-hung Duan #include "mlir/Pass/PassManager.h"
16c484c7ddSChia-hung Duan #include "mlir/Pass/PassRegistry.h"
17c484c7ddSChia-hung Duan #include "mlir/Reducer/Passes.h"
18c484c7ddSChia-hung Duan #include "mlir/Reducer/Tester.h"
19c484c7ddSChia-hung Duan #include "llvm/Support/Debug.h"
20c484c7ddSChia-hung Duan 
21*67d0d7acSMichele Scuttari namespace mlir {
22*67d0d7acSMichele Scuttari #define GEN_PASS_DEF_OPTREDUCTION
23*67d0d7acSMichele Scuttari #include "mlir/Reducer/Passes.h.inc"
24*67d0d7acSMichele Scuttari } // namespace mlir
25*67d0d7acSMichele Scuttari 
26c484c7ddSChia-hung Duan #define DEBUG_TYPE "mlir-reduce"
27c484c7ddSChia-hung Duan 
28c484c7ddSChia-hung Duan using namespace mlir;
29c484c7ddSChia-hung Duan 
30c484c7ddSChia-hung Duan namespace {
31c484c7ddSChia-hung Duan 
32*67d0d7acSMichele Scuttari class OptReductionPass : public impl::OptReductionBase<OptReductionPass> {
33c484c7ddSChia-hung Duan public:
34c484c7ddSChia-hung Duan   /// Runs the pass instance in the pass pipeline.
35c484c7ddSChia-hung Duan   void runOnOperation() override;
36c484c7ddSChia-hung Duan };
37c484c7ddSChia-hung Duan 
38be0a7e9fSMehdi Amini } // namespace
39c484c7ddSChia-hung Duan 
40c484c7ddSChia-hung Duan /// Runs the pass instance in the pass pipeline.
runOnOperation()41c484c7ddSChia-hung Duan void OptReductionPass::runOnOperation() {
42c484c7ddSChia-hung Duan   LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
43c484c7ddSChia-hung Duan 
44c484c7ddSChia-hung Duan   Tester test(testerName, testerArgs);
45c484c7ddSChia-hung Duan 
46c484c7ddSChia-hung Duan   ModuleOp module = this->getOperation();
47c484c7ddSChia-hung Duan   ModuleOp moduleVariant = module.clone();
48c484c7ddSChia-hung Duan 
490f304ef0SRiver Riddle   OpPassManager passManager("builtin.module");
50c484c7ddSChia-hung Duan   if (failed(parsePassPipeline(optPass, passManager))) {
511a001dedSChia-hung Duan     module.emitError() << "\nfailed to parse pass pipeline";
521a001dedSChia-hung Duan     return signalPassFailure();
53c484c7ddSChia-hung Duan   }
54c484c7ddSChia-hung Duan 
55c484c7ddSChia-hung Duan   std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
56c484c7ddSChia-hung Duan   if (original.first != Tester::Interestingness::True) {
571a001dedSChia-hung Duan     module.emitError() << "\nthe original input is not interested";
581a001dedSChia-hung Duan     return signalPassFailure();
59c484c7ddSChia-hung Duan   }
60c484c7ddSChia-hung Duan 
610f304ef0SRiver Riddle   // Temporarily push the variant under the main module and execute the pipeline
620f304ef0SRiver Riddle   // on it.
630f304ef0SRiver Riddle   module.getBody()->push_back(moduleVariant);
640f304ef0SRiver Riddle   LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
650f304ef0SRiver Riddle   moduleVariant->remove();
660f304ef0SRiver Riddle 
670f304ef0SRiver Riddle   if (failed(pipelineResult)) {
681a001dedSChia-hung Duan     module.emitError() << "\nfailed to run pass pipeline";
691a001dedSChia-hung Duan     return signalPassFailure();
70c484c7ddSChia-hung Duan   }
71c484c7ddSChia-hung Duan 
72c484c7ddSChia-hung Duan   std::pair<Tester::Interestingness, int> reduced =
73c484c7ddSChia-hung Duan       test.isInteresting(moduleVariant);
74c484c7ddSChia-hung Duan 
75c484c7ddSChia-hung Duan   if (reduced.first == Tester::Interestingness::True &&
76c484c7ddSChia-hung Duan       reduced.second < original.second) {
77c484c7ddSChia-hung Duan     module.getBody()->clear();
78c484c7ddSChia-hung Duan     module.getBody()->getOperations().splice(
79c484c7ddSChia-hung Duan         module.getBody()->begin(), moduleVariant.getBody()->getOperations());
80c484c7ddSChia-hung Duan     LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
81c484c7ddSChia-hung Duan   } else {
82c484c7ddSChia-hung Duan     LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
83c484c7ddSChia-hung Duan   }
84c484c7ddSChia-hung Duan 
85c484c7ddSChia-hung Duan   moduleVariant->destroy();
86c484c7ddSChia-hung Duan 
87c484c7ddSChia-hung Duan   LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
88c484c7ddSChia-hung Duan }
89c484c7ddSChia-hung Duan 
createOptReductionPass()90c484c7ddSChia-hung Duan std::unique_ptr<Pass> mlir::createOptReductionPass() {
91c484c7ddSChia-hung Duan   return std::make_unique<OptReductionPass>();
92c484c7ddSChia-hung Duan }
93