xref: /llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a set of interfaces that can be used to define information
10// related to type inference.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_INFERTYPEOPINTERFACE
15#define MLIR_INFERTYPEOPINTERFACE
16
17include "mlir/IR/OpBase.td"
18
19// OpInterface to compute the return type of an operation. The arguments match
20// those in Operation::create with the exception that the location is optional
21// (if no location is provided, then the method will not emit an error on
22// mismatch).
23def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
24  let description = [{
25    Interface to infer the return types for an operation that could be used
26    during op construction, verification or type inference.
27  }];
28  let cppNamespace = "::mlir";
29
30  let methods = [
31    StaticInterfaceMethod<
32      /*desc=*/[{Infer the return types that an op would generate.
33
34      The method takes an optional location which, if set, will be used to
35      report errors on. The operands and attributes correspond to those with
36      which an Operation would be created (e.g., as used in Operation::create)
37      and the regions of the op. Be aware that this method is supposed to be
38      called with valid arguments, e.g., operands are verified, or it may result
39      in undefined behavior.
40      }],
41      /*retTy=*/"::llvm::LogicalResult",
42      /*methodName=*/"inferReturnTypes",
43      /*args=*/(ins "::mlir::MLIRContext *":$context,
44                    "::std::optional<::mlir::Location>":$location,
45                    "::mlir::ValueRange":$operands,
46                    "::mlir::DictionaryAttr":$attributes,
47                    "::mlir::OpaqueProperties":$properties,
48                    "::mlir::RegionRange":$regions,
49                    "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
50    >,
51    StaticInterfaceMethod<
52      /*desc=*/[{Refine the return types that an op would generate.
53
54      This method computes the return types as `inferReturnTypes` does but
55      additionally takes the existing result types as input. The existing
56      result types can be checked as part of inference to provide more
57      op-specific error messages as well as part of inference to merge
58      additional information, attributes, during inference. It is called during
59      verification for ops implementing this trait with default behavior
60      reporting mismatch with current and inferred types printed.
61
62      The operands and attributes correspond to those with which an Operation
63      would be created (e.g., as used in Operation::create) and the regions of
64      the op. The method takes an optional location which, if set, will be used
65      to report errors on.
66
67      The return types may be elided or specific elements be null for elements
68      that should just be returned but not verified.
69
70      Because this method can be called from within different stages of IR
71      verification, implementations should not assume the arguments to
72      represent fully valid IR and are responsible for checking inputs for
73      validity to the degree necessary to perform the return type inference.
74      }],
75      /*retTy=*/"::llvm::LogicalResult",
76      /*methodName=*/"refineReturnTypes",
77      /*args=*/(ins "::mlir::MLIRContext *":$context,
78                    "::std::optional<::mlir::Location>":$location,
79                    "::mlir::ValueRange":$operands,
80                    "::mlir::DictionaryAttr":$attributes,
81                    "::mlir::OpaqueProperties":$properties,
82                    "::mlir::RegionRange":$regions,
83                    "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
84      /*methodBody=*/[{}],
85      /*defaultImplementation=*/[{
86          ::llvm::SmallVector<::mlir::Type, 4> inferredReturnTypes;
87          if (::mlir::failed(ConcreteOp::inferReturnTypes(
88                  context, location, operands, attributes, properties, regions,
89                  inferredReturnTypes)))
90            return ::mlir::failure();
91          if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
92                                                   returnTypes)) {
93            return ::mlir::emitOptionalError(
94                location, "'", ConcreteOp::getOperationName(),
95                "' op inferred type(s) ", inferredReturnTypes,
96                " are incompatible with return type(s) of operation ",
97                returnTypes);
98          }
99          return ::mlir::success();
100      }]
101    >,
102    StaticInterfaceMethod<
103      /*desc=*/"Returns whether two arrays of types are compatible result types"
104               " for an op.",
105      /*retTy=*/"bool",
106      /*methodName=*/"isCompatibleReturnTypes",
107      /*args=*/(ins "::mlir::TypeRange":$lhs, "::mlir::TypeRange":$rhs),
108      /*methodBody=*/[{
109        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
110      }],
111      /*defaultImplementation=*/[{
112        /// Returns whether two arrays are equal as strongest check for
113        /// compatibility by default.
114        return lhs == rhs;
115      }]
116    >,
117  ];
118
119  // Inferring result types may need to access the region operations.
120  let verifyWithRegions = 1;
121  let verify = [{
122    return ::mlir::detail::verifyInferredResultTypes($_op);
123  }];
124}
125
126def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
127  let description = [{
128    Interface to infer the components of a ShapedType returned by an operation
129    that could be used during op construction, verification or shape inference.
130
131    The components consist of element type, shape and raw attribute.
132  }];
133  let cppNamespace = "::mlir";
134
135  let methods = [
136    StaticInterfaceMethod<
137      /*desc=*/[{Infer the components of return type of shape container.
138
139      The method takes an optional location which, if set, will be used to
140      report errors on. The operands and attributes correspond to those with
141      which an Operation would be created (e.g., as used in Operation::create)
142      and the regions of the op.
143
144      Unknown (e.g., unranked) shape and nullptrs for element type and attribute
145      may be returned by this function while returning success. E.g., partial
146      population of components is not an error condition.
147
148      Because this method can be called from within different stages of IR
149      verification, implementations should not assume the arguments to
150      represent fully valid IR and are responsible for checking inputs for
151      validity to the degree necessary to perform the return type inference.
152      }],
153      /*retTy=*/"::llvm::LogicalResult",
154      /*methodName=*/"inferReturnTypeComponents",
155      /*args=*/(ins "::mlir::MLIRContext*":$context,
156                    "::std::optional<::mlir::Location>":$location,
157                    "::mlir::ValueShapeRange":$operands,
158                    "::mlir::DictionaryAttr":$attributes,
159                    "::mlir::OpaqueProperties":$properties,
160                    "::mlir::RegionRange":$regions,
161                    "::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
162                      $inferredReturnShapes),
163      /*methodBody=*/[{}],
164      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
165    >,
166    InterfaceMethod<
167      /*desc=*/[{Reify the shape computation for the operation.
168
169      Insert operations using the given OpBuilder that computes the
170      result shape. This interface is supposed to be workable during dialect
171      conversion (e.g. convert from tensor world to buffer world),
172      where `getOperand` may be invalid. For example, some ops (e.g.
173      dynamic_reshape(input, target_shape)) may depend on their operands
174      to calculate the result shape. When the `matchAndRewrite` method
175      of a conversion pattern is called, the operands of the op to convert
176      may have been converted into other types, which makes it invalid to
177      call the `getOperand` method of such op directly inside the
178      conversion pattern.  To solve this problem, this interface follows
179      the design of the conversion pattern, that is, accepting passed in
180      operands to avoid calling `getOperand` directly inside the interface
181      implementation.
182      }],
183      /*retTy=*/"::llvm::LogicalResult",
184      /*methodName=*/"reifyReturnTypeShapes",
185      /*args=*/(ins "::mlir::OpBuilder&":$builder,
186          "::mlir::ValueRange":$operands,
187          "::llvm::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
188      /*methodBody=*/[{}],
189      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
190    >
191  ];
192}
193
194// Trait list that declares an Adaptor-taking `inferReturnTypes` overload and
195// defines the interface method as a wrapper that builds the Adaptor for it.
196class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
197  [
198    // Declares the InferTypeOpInterface methods on the op.
199    DeclareOpInterfaceMethods<InferTypeOpInterface>,
200    NativeOpTrait<
201      /*name=*/"InferTypeOpAdaptor",
202      /*traits=*/[],
203      /*extraOpDeclaration=*/[{
204        static ::llvm::LogicalResult
205        inferReturnTypes(::mlir::MLIRContext *context,
206                                std::optional<::mlir::Location> location,
207                                Adaptor adaptor,
208                                ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
209      }] # additionalDecls,
210      /*extraOpDefinition=*/[{
211        ::llvm::LogicalResult
212        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
213                          std::optional<::mlir::Location> location,
214                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
215                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
216                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
217          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
218          return $cppClass::inferReturnTypes(context,
219            location, adaptor, inferredReturnTypes);
220        }
221      }]
222    >
223  ]>;
224
225def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
226def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
227  [{
228    static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
229  }]
230>;
231
232// Convenience trait list defining an `inferReturnTypeComponents` wrapper that
233// forwards to an Op Adaptor overload. Only uses the current operand types.
234class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
235  [
236    // Op implements the infer shaped type op interface.
237    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
238    NativeOpTrait<
239      /*name=*/"InferShapedTypeOpAdaptor",
240      /*traits=*/[],
241      /*extraOpDeclaration=*/[{
242        static ::llvm::LogicalResult
243        inferReturnTypeComponents(::mlir::MLIRContext *context,
244                                std::optional<::mlir::Location> location,
245                                Adaptor adaptor,
246                                ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
247      }],
248      /*extraOpDefinition=*/[{
249        ::llvm::LogicalResult
250        $cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
251                          std::optional<::mlir::Location> location,
252                          ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
253                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
254                          ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
255          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
256          return $cppClass::inferReturnTypeComponents(context,
257            location, adaptor, inferredReturnShapes);
258        }
259      }]
260    >
261  ]>;
262
263def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
264  "inferReturnTypeComponents"]>;
265def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
266  "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
267
268// Convenience class grouping together type and shaped type op interfaces for
269// ops that have tensor return types.
270class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
271  [
272    // Op implements infer type op interface.
273    DeclareOpInterfaceMethods<InferTypeOpInterface>,
274    // The op will have methods implementing the ShapedType type inference
275    // interface.
276    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
277    // The op produces tensors and will use the ShapedType type infer interface
278    // along with knowledge that it is producing Tensors to infer the type.
279    NativeOpTrait<
280      /*name=*/"InferTensorType",
281      /*traits=*/[],
282      /*extraOpDeclaration=*/[{}],
283      /*extraOpDefinition=*/[{
284        ::llvm::LogicalResult
285        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
286                          std::optional<::mlir::Location> location,
287                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
288                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
289                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
290          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
291          if (::mlir::failed($cppClass::inferReturnTypeComponents(
292                  context, location, operands, attributes, properties, regions,
293                  retComponents)))
294            return ::mlir::failure();
295          return ::mlir::detail::inferReturnTensorTypes(retComponents,
296                                    inferredReturnTypes);
297        }
298      }]
299    >
300  ]>;
301
302def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
303def InferTensorTypeWithReify: InferTensorTypeBase<[
304    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
305
306// Convenience class like InferTensorTypeBase (tensor return types), except the
307// shaped type interface is implemented through the Op Adaptor wrapper.
308class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
309  [
310    // Op implements infer type op interface.
311    DeclareOpInterfaceMethods<InferTypeOpInterface>,
312    // The op implements the ShapedType type inference interface through the
313    // Adaptor-based wrapper trait list.
314    InferShapedTypeOpAdaptorBase<overridenMethods>,
315    // The op produces tensors and will use the ShapedType type infer interface
316    // along with knowledge that it is producing Tensors to infer the type.
317    NativeOpTrait<
318      /*name=*/"InferTensorType",
319      /*traits=*/[],
320      /*extraOpDeclaration=*/[{}],
321      /*extraOpDefinition=*/[{
322        ::llvm::LogicalResult
323        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
324                          std::optional<::mlir::Location> location,
325                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
326                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
327                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
328          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
329          if (::mlir::failed($cppClass::inferReturnTypeComponents(
330                  context, location, operands, attributes, properties, regions,
331                  retComponents)))
332            return ::mlir::failure();
333          return ::mlir::detail::inferReturnTensorTypes(retComponents,
334                                    inferredReturnTypes);
335        }
336      }]
337    >
338  ]>;
339
340def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
341def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
342    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
343
344def ReifyRankedShapedTypeOpInterface :
345    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
346  let description = [{
347    Interface to compute the shape of the result of an operation when
348    the result is a ranked shaped type, i.e. `RankedTensorType` or
349    `MemRefType`.
350  }];
351  let cppNamespace = "::mlir";
352
353  let methods = [
354    InterfaceMethod<
355      /*desc=*/[{
356        Reify the shape of the result of an operation (typically in terms of the
357        shape of its operands).
358
359        `reifiedReturnShapes` is populated with one vector per op result. Each
360        of those vectors contains an OpFoldResult for each dimension of the
361        shaped type. In case a dimension in the type is static, the
362        corresponding entry is an IntegerAttr. Otherwise, it is a Value. The
363        given builder may be used to insert ops that compute result shapes.
364
365        If the shape of a particular result cannot be computed it must be empty.
366      }],
367      /*retTy=*/"::llvm::LogicalResult",
368      /*methodName=*/"reifyResultShapes",
369      /*args=*/(ins "::mlir::OpBuilder &":$builder,
370        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
371    >
372  ];
373}
374
375// Op has the same type for all of its operands and results.
376// TODO: Change from hard-coded to utilizing the type inference trait.
377def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
378
379// Op has the same rank for all operand and result types, if known.
380def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;
381
382#endif // MLIR_INFERTYPEOPINTERFACE
383