//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to define information
// related to type inference.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INFERTYPEOPINTERFACE
#define MLIR_INFERTYPEOPINTERFACE

include "mlir/IR/OpBase.td"

// OpInterface to compute the return type of an operation. The arguments match
// those in Operation::create with the exception that the location is optional
// (if no location is provided, then the method will not emit an error on
// mismatch).
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
  let description = [{
    Interface to infer the return types for an operation that could be used
    during op construction, verification or type inference.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the return types that an op would generate.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op. Be aware that this method is supposed to be
      called with valid arguments (e.g., verified operands); otherwise it may
      result in undefined behavior.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"inferReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::std::optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::OpaqueProperties":$properties,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
    >,
    StaticInterfaceMethod<
      /*desc=*/[{Refine the return types that an op would generate.

      This method computes the return types as `inferReturnTypes` does but
      additionally takes the existing result types as input. The existing
      result types can be checked as part of inference, both to provide more
      op-specific error messages and to merge additional information (e.g.,
      attributes) during inference. It is called during verification for ops
      implementing this interface; the default implementation reports a
      mismatch, printing both the inferred and the current types.

      The operands and attributes correspond to those with which an Operation
      would be created (e.g., as used in Operation::create) and the regions of
      the op. The method takes an optional location which, if set, will be used
      to report errors on.

      The return types may be elided or specific elements be null for elements
      that should just be returned but not verified.

      Because this method can be called from within different stages of IR
      verification, implementations should not assume the arguments to
      represent fully valid IR and are responsible for checking inputs for
      validity to the degree necessary to perform the return type inference.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"refineReturnTypes",
      /*args=*/(ins "::mlir::MLIRContext *":$context,
                    "::std::optional<::mlir::Location>":$location,
                    "::mlir::ValueRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::OpaqueProperties":$properties,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
          // Infer the types from scratch via the op's `inferReturnTypes`, then
          // check them against the types passed in by the caller.
          llvm::SmallVector<Type, 4> inferredReturnTypes;
          if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
                                                  attributes, properties, regions,
                                                  inferredReturnTypes)))
            return failure();
          if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
                                                   returnTypes)) {
            return emitOptionalError(
                location, "'", ConcreteOp::getOperationName(),
                "' op inferred type(s) ", inferredReturnTypes,
                " are incompatible with return type(s) of operation ",
                returnTypes);
          }
          return success();
      }]
    >,
    StaticInterfaceMethod<
      /*desc=*/"Returns whether two arrays of types are compatible result types"
               " for an op.",
      /*retTy=*/"bool",
      /*methodName=*/"isCompatibleReturnTypes",
      /*args=*/(ins "::mlir::TypeRange":$lhs, "::mlir::TypeRange":$rhs),
      /*methodBody=*/[{
        return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
      }],
      /*defaultImplementation=*/[{
        /// By default, exact equality of the two arrays is required; this is
        /// the strongest possible compatibility check.
        return lhs == rhs;
      }]
    >,
  ];

  // Inferring result types may need to access the region operations.
  let verifyWithRegions = 1;
  let verify = [{
    return detail::verifyInferredResultTypes($_op);
  }];
}

def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
  let description = [{
    Interface to infer the components of a ShapedType returned by an operation
    that could be used during op construction, verification or shape inference.

    The components consist of the element type, the shape, and a raw attribute.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    StaticInterfaceMethod<
      /*desc=*/[{Infer the components of the return type of a shape container.

      The method takes an optional location which, if set, will be used to
      report errors on. The operands and attributes correspond to those with
      which an Operation would be created (e.g., as used in Operation::create)
      and the regions of the op.

      Unknown (e.g., unranked) shape and nullptrs for element type and attribute
      may be returned by this function while returning success. I.e., partial
      population of components is not an error condition.

      Because this method can be called from within different stages of IR
      verification, implementations should not assume the arguments to
      represent fully valid IR and are responsible for checking inputs for
      validity to the degree necessary to perform the return type inference.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"inferReturnTypeComponents",
      /*args=*/(ins "::mlir::MLIRContext*":$context,
                    "::std::optional<::mlir::Location>":$location,
                    "::mlir::ValueShapeRange":$operands,
                    "::mlir::DictionaryAttr":$attributes,
                    "::mlir::OpaqueProperties":$properties,
                    "::mlir::RegionRange":$regions,
                    "::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
                      $inferredReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{Reify the shape computation for the operation.

      Insert operations using the given OpBuilder that compute the
      result shape. This interface is supposed to be workable during dialect
      conversion (e.g. convert from tensor world to buffer world),
      where `getOperand` may be invalid. For example, some ops (e.g.
      dynamic_reshape(input, target_shape)) may depend on their operands
      to calculate the result shape. When the `matchAndRewrite` method
      of a conversion pattern is called, the operands of the op to convert
      may have been converted into other types, which makes it invalid to
      call the `getOperand` method of such op directly inside the
      conversion pattern. To solve this problem, this interface follows
      the design of the conversion pattern, that is, accepting passed in
      operands to avoid calling `getOperand` directly inside the interface
      implementation.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"reifyReturnTypeShapes",
      /*args=*/(ins "::mlir::OpBuilder&":$builder,
                    "::mlir::ValueRange":$operands,
                    "::llvm::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
    >
  ];
}

// Convenience trait to define a wrapper to inferReturnTypes that passes in the
// Op Adaptor directly. `additionalDecls` is appended verbatim to the op's
// extra class declarations.
class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
  [
    // Op implements infer type op interface.
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    NativeOpTrait<
      /*name=*/"InferTypeOpAdaptor",
      /*traits=*/[],
      /*extraOpDeclaration=*/[{
        static ::llvm::LogicalResult
        inferReturnTypes(::mlir::MLIRContext *context,
                         std::optional<::mlir::Location> location,
                         Adaptor adaptor,
                         ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
      }] # additionalDecls,
      /*extraOpDefinition=*/[{
        ::llvm::LogicalResult
        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
                          std::optional<::mlir::Location> location,
                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
          return $cppClass::inferReturnTypes(context,
            location, adaptor, inferredReturnTypes);
        }
      }]
    >
  ]>;

def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
  [{
    static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
  }]
>;

// Convenience trait to define a wrapper to inferReturnTypeComponents that
// passes in the Op Adaptor directly. Only uses the current types of the
// operands.
class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
  [
    // Op implements the infer shaped type op interface.
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
    NativeOpTrait<
      /*name=*/"InferShapedTypeOpAdaptor",
      /*traits=*/[],
      /*extraOpDeclaration=*/[{
        static ::llvm::LogicalResult
        inferReturnTypeComponents(::mlir::MLIRContext *context,
                                  std::optional<::mlir::Location> location,
                                  Adaptor adaptor,
                                  ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
      }],
      /*extraOpDefinition=*/[{
        ::llvm::LogicalResult
        $cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
                                std::optional<::mlir::Location> location,
                                ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
                                ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
                                ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
          return $cppClass::inferReturnTypeComponents(context,
            location, adaptor, inferredReturnShapes);
        }
      }]
    >
  ]>;

def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
  "inferReturnTypeComponents"]>;
def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
  "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;

// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
  [
    // Op implements infer type op interface.
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    // The op will have methods implementing the ShapedType type inference
    // interface.
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
    // The op produces tensors and will use the ShapedType type infer interface
    // along with knowledge that it is producing Tensors to infer the type.
    NativeOpTrait<
      /*name=*/"InferTensorType",
      /*traits=*/[],
      /*extraOpDeclaration=*/[{}],
      /*extraOpDefinition=*/[{
        ::llvm::LogicalResult
        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
                          std::optional<::mlir::Location> location,
                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
          if (failed($cppClass::inferReturnTypeComponents(context, location,
                                    operands, attributes, properties, regions,
                                    retComponents)))
            return ::mlir::failure();
          return ::mlir::detail::inferReturnTensorTypes(retComponents,
                                    inferredReturnTypes);
        }
      }]
    >
  ]>;

def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>;
def InferTensorTypeWithReify: InferTensorTypeBase<[
  "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;

// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types. Unlike InferTensorTypeBase, the shaped
// type inference is written against the op Adaptor (see
// InferShapedTypeOpAdaptorBase).
class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
  [
    // Op implements infer type op interface.
    DeclareOpInterfaceMethods<InferTypeOpInterface>,
    // The op will have methods implementing the ShapedType type inference
    // interface.
    InferShapedTypeOpAdaptorBase<overridenMethods>,
    // The op produces tensors and will use the ShapedType type infer interface
    // along with knowledge that it is producing Tensors to infer the type.
    NativeOpTrait<
      /*name=*/"InferTensorType",
      /*traits=*/[],
      /*extraOpDeclaration=*/[{}],
      /*extraOpDefinition=*/[{
        // Names are fully qualified because this body is emitted into the
        // dialect's generated op definitions, whose enclosing namespace and
        // `using` directives are not controlled by this file.
        ::llvm::LogicalResult
        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
                          std::optional<::mlir::Location> location,
                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
          ::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
          if (failed($cppClass::inferReturnTypeComponents(context, location,
                                    operands, attributes, properties, regions,
                                    retComponents)))
            return ::mlir::failure();
          return ::mlir::detail::inferReturnTensorTypes(retComponents,
                                    inferredReturnTypes);
        }
      }]
    >
  ]>;

def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
  "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;

def ReifyRankedShapedTypeOpInterface :
    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
  let description = [{
    Interface to compute the shape of the result of an operation when
    the result is a ranked shaped type, i.e. `RankedTensorType` or
    `MemRefType`.
  }];
  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Reify the shape of the result of an operation (typically in terms of the
        shape of its operands).

        `reifiedReturnShapes` is populated with one vector per op result. Each
        of those vectors contains an OpFoldResult for each dimension of the
        shaped type. In case a dimension in the type is static, the
        corresponding entry is an IntegerAttr. Otherwise, it is a Value. The
        given builder may be used to insert ops that compute result shapes.

        If the shape of a particular result cannot be computed it must be empty.
      }],
      /*retTy=*/"::llvm::LogicalResult",
      /*methodName=*/"reifyResultShapes",
      /*args=*/(ins "::mlir::OpBuilder &":$builder,
        "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
    >
  ];
}

// Op has the same operand and result type.
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;

// Op has the same ranks for all operands and results types, if known.
def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;

#endif // MLIR_INFERTYPEOPINTERFACE