xref: /llvm-project/mlir/include/mlir/Transforms/InliningUtils.h (revision b39c5cb6977f35ad727d86b2dd6232099734ffd3)
1 //===- InliningUtils.h - Inliner utilities ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines interfaces for various inlining utility methods.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_TRANSFORMS_INLININGUTILS_H
14 #define MLIR_TRANSFORMS_INLININGUTILS_H
15 
16 #include "mlir/IR/BuiltinAttributes.h"
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/IR/Region.h"
20 #include "mlir/IR/ValueRange.h"
21 #include <optional>
22 
23 namespace mlir {
24 
25 class Block;
26 class IRMapping;
27 class CallableOpInterface;
28 class CallOpInterface;
29 class OpBuilder;
30 class Operation;
31 class Region;
32 class TypeRange;
33 class Value;
34 class ValueRange;
35 
36 //===----------------------------------------------------------------------===//
37 // InlinerInterface
38 //===----------------------------------------------------------------------===//
39 
40 /// This is the interface that must be implemented by the dialects of operations
41 /// to be inlined. This interface should only handle the operations of the
42 /// given dialect.
43 class DialectInlinerInterface
44     : public DialectInterface::Base<DialectInlinerInterface> {
45 public:
46   DialectInlinerInterface(Dialect *dialect) : Base(dialect) {}
47 
48   //===--------------------------------------------------------------------===//
49   // Analysis Hooks
50   //===--------------------------------------------------------------------===//
51 
52   /// Returns true if the given operation 'callable', that implements the
53   /// 'CallableOpInterface', can be inlined into the position given call
54   /// operation 'call', that is registered to the current dialect and implements
55   /// the `CallOpInterface`. 'wouldBeCloned' is set to true if the region of the
56   /// given 'callable' is set to be cloned during the inlining process, or false
57   /// if the region is set to be moved in-place(i.e. no duplicates would be
58   /// created).
59   virtual bool isLegalToInline(Operation *call, Operation *callable,
60                                bool wouldBeCloned) const {
61     return false;
62   }
63 
64   /// Returns true if the given region 'src' can be inlined into the region
65   /// 'dest' that is attached to an operation registered to the current dialect.
66   /// 'wouldBeCloned' is set to true if the given 'src' region is set to be
67   /// cloned during the inlining process, or false if the region is set to be
68   /// moved in-place(i.e. no duplicates would be created). 'valueMapping'
69   /// contains any remapped values from within the 'src' region. This can be
70   /// used to examine what values will replace entry arguments into the 'src'
71   /// region for example.
72   virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
73                                IRMapping &valueMapping) const {
74     return false;
75   }
76 
77   /// Returns true if the given operation 'op', that is registered to this
78   /// dialect, can be inlined into the given region, false otherwise.
79   /// 'wouldBeCloned' is set to true if the given 'op' is set to be cloned
80   /// during the inlining process, or false if the operation is set to be moved
81   /// in-place(i.e. no duplicates would be created). 'valueMapping' contains any
82   /// remapped values from within the 'src' region. This can be used to examine
83   /// what values may potentially replace the operands to 'op'.
84   virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
85                                IRMapping &valueMapping) const {
86     return false;
87   }
88 
89   /// This hook is invoked on an operation that contains regions. It should
90   /// return true if the analyzer should recurse within the regions of this
91   /// operation when computing legality and cost, false otherwise. The default
92   /// implementation returns true.
93   virtual bool shouldAnalyzeRecursively(Operation *op) const { return true; }
94 
95   //===--------------------------------------------------------------------===//
96   // Transformation Hooks
97   //===--------------------------------------------------------------------===//
98 
99   /// Handle the given inlined terminator by replacing it with a new operation
100   /// as necessary. This overload is called when the inlined region has more
101   /// than one block. The 'newDest' block represents the new final branching
102   /// destination of blocks within this region, i.e. operations that release
103   /// control to the parent operation will likely now branch to this block.
104   /// Its block arguments correspond to any values that need to be replaced by
105   /// terminators within the inlined region.
106   virtual void handleTerminator(Operation *op, Block *newDest) const {
107     llvm_unreachable("must implement handleTerminator in the case of multiple "
108                      "inlined blocks");
109   }
110 
111   /// Handle the given inlined terminator by replacing it with a new operation
112   /// as necessary. This overload is called when the inlined region only
113   /// contains one block. 'valuesToReplace' contains the previously returned
114   /// values of the call site before inlining. These values must be replaced by
115   /// this callback if they had any users (for example for traditional function
116   /// calls, these are directly replaced with the operands of the `return`
117   /// operation). The given 'op' will be removed by the caller, after this
118   /// function has been called.
119   virtual void handleTerminator(Operation *op,
120                                 ValueRange valuesToReplace) const {
121     llvm_unreachable(
122         "must implement handleTerminator in the case of one inlined block");
123   }
124 
125   /// Attempt to materialize a conversion for a type mismatch between a call
126   /// from this dialect, and a callable region. This method should generate an
127   /// operation that takes 'input' as the only operand, and produces a single
128   /// result of 'resultType'. If a conversion can not be generated, nullptr
129   /// should be returned. For example, this hook may be invoked in the following
130   /// scenarios:
131   ///   func @foo(i32) -> i32 { ... }
132   ///
133   ///   // Mismatched input operand
134   ///   ... = foo.call @foo(%input : i16) -> i32
135   ///
136   ///   // Mismatched result type.
137   ///   ... = foo.call @foo(%input : i32) -> i16
138   ///
139   /// NOTE: This hook may be invoked before the 'isLegal' checks above.
140   virtual Operation *materializeCallConversion(OpBuilder &builder, Value input,
141                                                Type resultType,
142                                                Location conversionLoc) const {
143     return nullptr;
144   }
145 
146   /// Hook to transform the call arguments before using them to replace the
147   /// callee arguments. Returns a value of the same type or the `argument`
148   /// itself if nothing changed. The `argumentAttrs` dictionary is non-null even
149   /// if no attribute is present. The hook is called after converting the
150   /// callsite argument types using the materializeCallConversion callback, and
151   /// right before inlining the callee region. Any operations created using the
152   /// provided `builder` are inserted right before the inlined callee region. An
153   /// example use case is the insertion of copies for by value arguments.
154   virtual Value handleArgument(OpBuilder &builder, Operation *call,
155                                Operation *callable, Value argument,
156                                DictionaryAttr argumentAttrs) const {
157     return argument;
158   }
159 
160   /// Hook to transform the callee results before using them to replace the call
161   /// results. Returns a value of the same type or the `result` itself if
162   /// nothing changed. The `resultAttrs` dictionary is non-null even if no
163   /// attribute is present. The hook is called right before handling
164   /// terminators, and obtains the callee result before converting its type
165   /// using the `materializeCallConversion` callback. Any operations created
166   /// using the provided `builder` are inserted right after the inlined callee
167   /// region. An example use case is the insertion of copies for by value
168   /// results. NOTE: This hook is invoked after inlining the `callable` region.
169   virtual Value handleResult(OpBuilder &builder, Operation *call,
170                              Operation *callable, Value result,
171                              DictionaryAttr resultAttrs) const {
172     return result;
173   }
174 
175   /// Process a set of blocks that have been inlined for a call. This callback
176   /// is invoked before inlined terminator operations have been processed.
177   virtual void processInlinedCallBlocks(
178       Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {}
179 
180   /// Returns true if the inliner can assume a fast path of not creating a new
181   /// block, if there is only one block.
182   virtual bool allowSingleBlockOptimization(
183       iterator_range<Region::iterator> inlinedBlocks) const {
184     return true;
185   }
186 };
187 
188 /// This interface provides the hooks into the inlining interface.
189 /// Note: this class automatically collects 'DialectInlinerInterface' objects
190 /// registered to each dialect within the given context.
191 class InlinerInterface
192     : public DialectInterfaceCollection<DialectInlinerInterface> {
193 public:
194   using Base::Base;
195 
196   /// Process a set of blocks that have been inlined. This callback is invoked
197   /// *before* inlined terminator operations have been processed.
198   virtual void
199   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) {}
200 
201   /// These hooks mirror the hooks for the DialectInlinerInterface, with default
202   /// implementations that call the hook on the handler for the dialect 'op' is
203   /// registered to.
204 
205   //===--------------------------------------------------------------------===//
206   // Analysis Hooks
207   //===--------------------------------------------------------------------===//
208 
209   virtual bool isLegalToInline(Operation *call, Operation *callable,
210                                bool wouldBeCloned) const;
211   virtual bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
212                                IRMapping &valueMapping) const;
213   virtual bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
214                                IRMapping &valueMapping) const;
215   virtual bool shouldAnalyzeRecursively(Operation *op) const;
216 
217   //===--------------------------------------------------------------------===//
218   // Transformation Hooks
219   //===--------------------------------------------------------------------===//
220 
221   virtual void handleTerminator(Operation *op, Block *newDest) const;
222   virtual void handleTerminator(Operation *op, ValueRange valuesToRepl) const;
223 
224   virtual Value handleArgument(OpBuilder &builder, Operation *call,
225                                Operation *callable, Value argument,
226                                DictionaryAttr argumentAttrs) const;
227   virtual Value handleResult(OpBuilder &builder, Operation *call,
228                              Operation *callable, Value result,
229                              DictionaryAttr resultAttrs) const;
230 
231   virtual void processInlinedCallBlocks(
232       Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
233 
234   virtual bool allowSingleBlockOptimization(
235       iterator_range<Region::iterator> inlinedBlocks) const;
236 };
237 
238 //===----------------------------------------------------------------------===//
239 // Inline Methods.
240 //===----------------------------------------------------------------------===//
241 
242 /// This function inlines a region, 'src', into another. This function returns
243 /// failure if it is not possible to inline this function. If the function
244 /// returned failure, then no changes to the module have been made.
245 ///
246 /// The provided 'inlinePoint' must be within a region, and corresponds to the
247 /// location where the 'src' region should be inlined. 'mapping' contains any
248 /// remapped operands that are used within the region, and *must* include
249 /// remappings for the entry arguments to the region. 'resultsToReplace'
250 /// corresponds to any results that should be replaced by terminators within the
251 /// inlined region. 'regionResultTypes' specifies the expected return types of
252 /// the terminators in the region. 'inlineLoc' is an optional Location that, if
253 /// provided, will be used to update the inlined operations' location
254 /// information. 'shouldCloneInlinedRegion' corresponds to whether the source
255 /// region should be cloned into the 'inlinePoint' or spliced directly.
256 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
257                            Operation *inlinePoint, IRMapping &mapper,
258                            ValueRange resultsToReplace,
259                            TypeRange regionResultTypes,
260                            std::optional<Location> inlineLoc = std::nullopt,
261                            bool shouldCloneInlinedRegion = true);
262 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
263                            Block *inlineBlock, Block::iterator inlinePoint,
264                            IRMapping &mapper, ValueRange resultsToReplace,
265                            TypeRange regionResultTypes,
266                            std::optional<Location> inlineLoc = std::nullopt,
267                            bool shouldCloneInlinedRegion = true);
268 
269 /// This function is an overload of the above 'inlineRegion' that allows for
270 /// providing the set of operands ('inlinedOperands') that should be used
271 /// in-favor of the region arguments when inlining.
272 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
273                            Operation *inlinePoint, ValueRange inlinedOperands,
274                            ValueRange resultsToReplace,
275                            std::optional<Location> inlineLoc = std::nullopt,
276                            bool shouldCloneInlinedRegion = true);
277 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
278                            Block *inlineBlock, Block::iterator inlinePoint,
279                            ValueRange inlinedOperands,
280                            ValueRange resultsToReplace,
281                            std::optional<Location> inlineLoc = std::nullopt,
282                            bool shouldCloneInlinedRegion = true);
283 
284 /// This function inlines a given region, 'src', of a callable operation,
285 /// 'callable', into the location defined by the given call operation. This
286 /// function returns failure if inlining is not possible, success otherwise. On
287 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
288 /// corresponds to whether the source region should be cloned into the 'call' or
289 /// spliced directly.
290 LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
291                          CallableOpInterface callable, Region *src,
292                          bool shouldCloneInlinedRegion = true);
293 
294 } // namespace mlir
295 
296 #endif // MLIR_TRANSFORMS_INLININGUTILS_H
297