xref: /llvm-project/mlir/lib/Transforms/InlinerPass.cpp (revision cd9ca423b7400000b4e0199450283439fcc1bbd9)
1 //===- InlinerPass.cpp - Pass to inline function calls --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the call graph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Transforms/Passes.h"
17 
18 #include "mlir/Analysis/CallGraph.h"
19 #include "mlir/Pass/PassManager.h"
20 #include "mlir/Transforms/Inliner.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_INLINER
24 #include "mlir/Transforms/Passes.h.inc"
25 } // namespace mlir
26 
27 #define DEBUG_TYPE "inliner-pass"
28 
29 using namespace mlir;
30 
31 /// This function implements the inliner optimization pipeline.
32 static void defaultInlinerOptPipeline(OpPassManager &pm) {
33   pm.addPass(createCanonicalizerPass());
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // InlinerPass
38 //===----------------------------------------------------------------------===//
39 
40 namespace {
41 class InlinerPass : public impl::InlinerBase<InlinerPass> {
42 public:
43   InlinerPass();
44   InlinerPass(const InlinerPass &) = default;
45   InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
46   InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
47               llvm::StringMap<OpPassManager> opPipelines);
48   void runOnOperation() override;
49 
50   /// A callback provided to the inliner driver to execute
51   /// the specified pass pipeline on the given operation
52   /// within the context of the current inliner pass,
53   /// which is passed as the first argument.
54   /// runPipeline API is protected within the Pass class,
55   /// so this helper is required to call it from the foreign
56   /// inliner driver.
57   static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
58                                          Operation *op) {
59     return mlir::cast<InlinerPass>(pass).runPipeline(pipeline, op);
60   }
61 
62 private:
63   /// Attempt to initialize the options of this pass from the given string.
64   /// Derived classes may override this method to hook into the point at which
65   /// options are initialized, but should generally always invoke this base
66   /// class variant.
67   LogicalResult initializeOptions(
68       StringRef options,
69       function_ref<LogicalResult(const Twine &)> errorHandler) override;
70 
71   /// Inliner configuration parameters created from the pass options.
72   InlinerConfig config;
73 };
74 } // namespace
75 
76 InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
77 
78 InlinerPass::InlinerPass(
79     std::function<void(OpPassManager &)> defaultPipelineArg)
80     : InlinerPass(std::move(defaultPipelineArg),
81                   llvm::StringMap<OpPassManager>{}) {}
82 
83 InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
84                          llvm::StringMap<OpPassManager> opPipelines)
85     : config(std::move(defaultPipeline), maxInliningIterations) {
86   if (opPipelines.empty())
87     return;
88 
89   // Update the option for the op specific optimization pipelines.
90   for (auto &it : opPipelines)
91     opPipelineList.addValue(it.second);
92   config.setOpPipelines(std::move(opPipelines));
93 }
94 
95 // Return true if the inlining ratio does not exceed the threshold.
96 static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
97                                  unsigned inliningThreshold) {
98   // Return early, ratio <= 0U will always be false.
99   if (inliningThreshold == 0U)
100     return false;
101   // Return early, ratio <= -1U will always be true.
102   if (inliningThreshold == -1U)
103     return true;
104 
105   Region *callerRegion = resolvedCall.sourceNode->getCallableRegion();
106   Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();
107 
108   assert(calleeRegion && callerRegion && "unexpected external node");
109 
110   auto countOps = [](Region *region) {
111     unsigned count = 0;
112     region->walk([&](Operation *) { ++count; });
113     return count;
114   };
115 
116   unsigned callerOps = countOps(callerRegion);
117 
118   // Always inline empty callees (if it is possible at all).
119   if (callerOps == 0)
120     return true;
121 
122   unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
123   LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
124                           << inliningThreshold << "%): " << ratio << "%\n");
125   return ratio <= inliningThreshold;
126 }
127 
128 void InlinerPass::runOnOperation() {
129   CallGraph &cg = getAnalysis<CallGraph>();
130 
131   // The inliner should only be run on operations that define a symbol table,
132   // as the callgraph will need to resolve references.
133   Operation *op = getOperation();
134   if (!op->hasTrait<OpTrait::SymbolTable>()) {
135     op->emitOpError() << " was scheduled to run under the inliner, but does "
136                          "not define a symbol table";
137     return signalPassFailure();
138   }
139 
140   // By default, assume that any inlining is profitable.
141   auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
142     return isProfitableToInline(call, inliningThreshold);
143   };
144 
145   // Get an instance of the inliner.
146   Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
147                   config, profitabilityCb);
148 
149   // Run the inlining.
150   if (failed(inliner.doInlining()))
151     signalPassFailure();
152 }
153 
154 LogicalResult InlinerPass::initializeOptions(
155     StringRef options,
156     function_ref<LogicalResult(const Twine &)> errorHandler) {
157   if (failed(Pass::initializeOptions(options, errorHandler)))
158     return failure();
159 
160   // Initialize the pipeline builder for operations without the dedicated
161   // optimization pipeline in opPipelineList to use the option string.
162   // TODO: Use a generic pass manager for the pre-inline pipeline, and remove
163   // this.
164   if (!defaultPipelineStr.empty()) {
165     std::string defaultPipelineCopy = defaultPipelineStr;
166     config.setDefaultPipeline([=](OpPassManager &pm) {
167       (void)parsePassPipeline(defaultPipelineCopy, pm);
168     });
169   } else if (defaultPipelineStr.getNumOccurrences()) {
170     config.setDefaultPipeline(nullptr);
171   }
172 
173   // Initialize the op specific pass pipelines.
174   llvm::StringMap<OpPassManager> pipelines;
175   for (OpPassManager pipeline : opPipelineList)
176     if (!pipeline.empty())
177       pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
178   config.setOpPipelines(std::move(pipelines));
179 
180   config.setMaxInliningIterations(maxInliningIterations);
181 
182   return success();
183 }
184 
185 std::unique_ptr<Pass> mlir::createInlinerPass() {
186   return std::make_unique<InlinerPass>();
187 }
188 std::unique_ptr<Pass>
189 mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
190   return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
191                                        std::move(opPipelines));
192 }
193 std::unique_ptr<Pass> mlir::createInlinerPass(
194     llvm::StringMap<OpPassManager> opPipelines,
195     std::function<void(OpPassManager &)> defaultPipelineBuilder) {
196   return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
197                                        std::move(opPipelines));
198 }
199