//===-- CommonTypeConstraints.td - Common Type Constraints--*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains commonly used type constraints.
//
//===----------------------------------------------------------------------===//

#ifndef COMMON_TYPE_CONSTRAINTS_TD
#define COMMON_TYPE_CONSTRAINTS_TD

include "mlir/IR/Constraints.td"
include "mlir/IR/DialectBase.td"

//===----------------------------------------------------------------------===//
// Common predicates
//===----------------------------------------------------------------------===//

// Whether a type is a VectorType.
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                         CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
// Whether a type is a fixed-length VectorType of non-zero rank.
def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;

// Whether a type is a fixed-length VectorType.
def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;

// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
        : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;

// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
//   Valid:
//    - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
//   Invalid:
//    - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
  CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
  // The rank check comes before `.back()` so the scalable-dims array is
  // guaranteed to be non-empty.
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
  // `::llvm` is fully qualified, like every other predicate in this file, so
  // the generated C++ cannot resolve to a nested namespace named `llvm`.
  CPred<"!::llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
]>;

// Whether a type is a VectorType and all dimensions are scalable.
def IsVectorTypeWithAllDimsScalablePred : And<[
  IsVectorOfNonZeroRankTypePred,
  CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
]>;

// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;

// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;

// Whether a type is an UnrankedMemRefType.
def IsUnrankedMemRefTypePred
        : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;

// Whether a type is an UnrankedTensorType.
def IsUnrankedTensorTypePred
        : CPred<"::llvm::isa<::mlir::UnrankedTensorType>($_self)">;

// Whether a type is a RankedTensorType.
def IsRankedTensorTypePred
        : CPred<"::llvm::isa<::mlir::RankedTensorType>($_self)">;

// Whether a type is a BaseMemRefType.
def IsBaseMemRefTypePred
        : CPred<"::llvm::isa<::mlir::BaseMemRefType>($_self)">;

// Whether a type is a ShapedType.
def IsShapedTypePred : CPred<"::llvm::isa<::mlir::ShapedType>($_self)">;

// For a ShapedType, verify that it has a static shape.
// The predicate casts unconditionally, so it must be combined with a check
// that `$_self` is a ShapedType.
def HasStaticShapePred :
        CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasStaticShape()">;

// Whether a type is a TupleType.
def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">;

// Whether a type has a ValueSemantics trait.
def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;

//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//

// A type constraint: a predicate over `$_self` plus metadata.
//   condition - predicate a type must satisfy.
//   descr     - human-readable summary, forwarded to TypeConstraint (read back
//               as `.summary` by the classes below).
//   cppType   - C++ class name recorded for values constrained by this type.
// `builderCall` is a C++ expression constructing the type; an empty string
// means "not buildable" (see SameBuildabilityAs).
class Type<Pred condition, string descr = "",
           string cppType = "::mlir::Type"> :
    TypeConstraint<condition, descr, cppType> {
  string description = "";
  string builderCall = "";
}

// Allows providing an alternative name and summary to an existing type def.
// The predicate, C++ type, description and builder are copied from `t`.
class TypeAlias<Type t, string summary = t.summary> :
    Type<t.predicate, summary, t.cppType> {
  let description = t.description;
  let builderCall = t.builderCall;
}

// A type of a specific dialect; `dialect` records the owning Dialect.
class DialectType<Dialect d, Pred condition, string descr = "",
                  string cppType = "::mlir::Type"> :
    Type<condition, descr, cppType> {
  Dialect dialect = d;
}

// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results.
// The summary is the base type's summary prefixed with "variadic of ".
class Variadic<Type type> : TypeConstraint<type.predicate,
                                           "variadic of " # type.summary,
                                           type.cppType> {
  Type baseType = type;
  // NOTE(review): presumably the minimum number of values; 0 by default.
  int minSize = 0;
}

// A nested variadic type constraint. It expands to zero or more variadic ranges
// of the base type. This class is used for supporting variadic operands and
// results. `variadicSegmentAttrName` should correspond to the name of a
// DenseI32ArrayAttr argument that provides the sizes of the inner variadic
// operand groups.
class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
    : Variadic<type> {
  string segmentAttrName = variadicSegmentAttrName;
}

// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
                                           type.cppType> {
  Type baseType = type;
}

// A type that can be constructed using MLIR::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
// diamond "inheritance".
// TODO: we may extend this to a more general 'Buildable' trait, making some
// Types and some Attrs buildable.
class BuildableType<code builder> {
  // The builder call to invoke (if specified) to construct the BuildableType.
  code builderCall = builder;
}

// A type that's buildable iff the type passed as an argument is buildable.
// This is intended for use by types like container types, which are only
// buildable if the type of their elements is buildable.
// `builder` is used only when `type.builderCall` is non-empty; otherwise the
// resulting builderCall is empty (not buildable).
class SameBuildabilityAs<Type type, code builder> {
  code builderCall = !if(!empty(type.builderCall), "", builder);
}

// Any type at all.
def AnyType : Type<CPred<"true">, "any type">;

// None type
def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
                    "::mlir::NoneType">,
               BuildableType<"$_builder.getType<::mlir::NoneType>()">;

// Any type from the given list.
// When `summary` is empty, it defaults to the members' summaries joined with
// " or ".
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type"> : Type<
    // Satisfy any of the allowed types' conditions.
    Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
    !if(!eq(summary, ""),
        !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
        summary),
    cppType> {
  list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies the constraints of all given types.
// When `summary` is empty, it defaults to the members' summaries joined with
// " and ".
class AllOfType<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type"> : Type<
    // Satisfy all of the allowed types' conditions.
    And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
    !if(!eq(summary, ""),
        !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
        summary),
    cppType> {
  list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies additional predicates.
// The base type's predicate is checked first, followed by `predicates` in
// order. Note that `summary` defaults to the empty string, not to the base
// type's summary.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
                   string cppType = type.cppType> : Type<
    // `predicates` is already a list<Pred>; no element-wise mapping needed.
    And<!listconcat([type.predicate], predicates)>,
    summary, cppType> {
  Type baseType = type;
  list<Pred> predicateList = predicates;
}

// Integer types.

// Any integer type irrespective of its width and signedness semantics.
def AnyInteger : Type<CPred<"::llvm::isa<::mlir::IntegerType>($_self)">, "integer",
                      "::mlir::IntegerType">;

// Any integer type (regardless of signedness semantics) of a specific width.
class AnyI<int width>
    : Type<CPred<"$_self.isInteger(" # width # ")">, width # "-bit integer"> {
  int bitwidth = width;
}

// Any integer type (regardless of signedness) whose width is in `widths`.
class AnyIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, AnyI<w>),
              !interleave(widths, "/") # "-bit integer",
              "::mlir::IntegerType">;

def AnyI1 : AnyI<1>;
def AnyI8 : AnyI<8>;
def AnyI16 : AnyI<16>;
def AnyI32 : AnyI<32>;
def AnyI64 : AnyI<64>;

// Any signless integer type irrespective of its width.
def AnySignlessInteger : Type<
  CPred<"$_self.isSignlessInteger()">, "signless integer",
        "::mlir::IntegerType">;

// Signless integer type of a specific width.
// Matches `$_self.isSignlessInteger(width)` and is buildable via
// `$_builder.getIntegerType(width)`.
class I<int width>
    : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
           width # "-bit signless integer", "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width # ")"> {
  int bitwidth = width;
}

// Any signless integer whose width is in `widths`.
// NOTE(review): unlike AnyIntOfWidths, no cppType is passed, so the C++ type
// is the AnyTypeOf default "::mlir::Type".
class SignlessIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, I<w>),
              !interleave(widths, "/") # "-bit signless integer">;

def I1 : I<1>;
def I8 : I<8>;
def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;
def I128 : I<128>;

// Any signed integer type irrespective of its width.
// NOTE(review): no cppType is given here (defaults to "::mlir::Type"), unlike
// AnySignlessInteger, which uses "::mlir::IntegerType".
def AnySignedInteger : Type<
  CPred<"$_self.isSignedInteger()">, "signed integer">;

// Signed integer type of a specific width.
// Buildable via `$_builder.getIntegerType(width, /*isSigned=*/true)`.
class SI<int width>
    : Type<CPred<"$_self.isSignedInteger(" # width # ")">,
           width # "-bit signed integer", "::mlir::IntegerType">,
      BuildableType<
        "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
  int bitwidth = width;
}

// Any signed integer whose width is in `widths`.
class SignedIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, SI<w>),
              !interleave(widths, "/") # "-bit signed integer">;

def SI1 : SI<1>;
def SI8 : SI<8>;
def SI16 : SI<16>;
def SI32 : SI<32>;
def SI64 : SI<64>;

// Any unsigned integer type irrespective of its width.
// NOTE(review): no cppType is given here (defaults to "::mlir::Type").
def AnyUnsignedInteger : Type<
  CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;

// Unsigned integer type of a specific width.
// Buildable via `$_builder.getIntegerType(width, /*isSigned=*/false)`.
class UI<int width>
    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
           width # "-bit unsigned integer", "::mlir::IntegerType">,
      BuildableType<
        "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
  int bitwidth = width;
}

// Any unsigned integer whose width is in `widths`.
class UnsignedIntOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, UI<w>),
              !interleave(widths, "/") # "-bit unsigned integer">;

def UI1 : UI<1>;
def UI8 : UI<8>;
def UI16 : UI<16>;
def UI32 : UI<32>;
def UI64 : UI<64>;

// Index type.
def Index : Type<CPred<"::llvm::isa<::mlir::IndexType>($_self)">, "index",
                 "::mlir::IndexType">,
            BuildableType<"$_builder.getIndexType()">;

// Any signless integer type or index type.
def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
                                     "signless integer or index">;

// Floating point types.

// Any float type irrespective of its width.
def AnyFloat : Type<CPred<"::llvm::isa<::mlir::FloatType>($_self)">, "floating-point",
                    "::mlir::FloatType">;

// Float type of a specific width.
// The predicate calls `$_self.isF<width>()` and the builder calls
// `$_builder.getF<width>Type()`. Only widths with those C++ methods can be
// used.
class F<int width>
    : Type<CPred<"$_self.isF" # width # "()">,
           width # "-bit float", "::mlir::FloatType">,
      BuildableType<"$_builder.getF" # width # "Type()"> {
  int bitwidth = width;
}

// Any float whose width is in `widths`.
class FloatOfWidths<list<int> widths> :
    AnyTypeOf<!foreach(w, widths, F<w>),
              !interleave(widths, "/") # "-bit float">;

def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;

// The builder calls below spell each type with a fully-qualified `::mlir::`
// prefix, exactly like their predicates and the NoneType builder. The
// generated C++ then compiles regardless of which namespace it is emitted into.
def BF16 : Type<CPred<"::llvm::isa<::mlir::BFloat16Type>($_self)">, "bfloat16 type">,
           BuildableType<"$_builder.getType<::mlir::BFloat16Type>()">;
def TF32 : Type<CPred<"::llvm::isa<::mlir::FloatTF32Type>($_self)">, "tf32 type">,
           BuildableType<"$_builder.getType<::mlir::FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"::llvm::isa<::mlir::Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
               BuildableType<"$_builder.getType<::mlir::Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"::llvm::isa<::mlir::Float8E5M2Type>($_self)">, "f8E5M2 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"::llvm::isa<::mlir::Float8E4M3Type>($_self)">, "f8E4M3 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
                 BuildableType<"$_builder.getType<::mlir::Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
                    BuildableType<"$_builder.getType<::mlir::Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
                 BuildableType<"$_builder.getType<::mlir::Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"::llvm::isa<::mlir::Float8E3M4Type>($_self)">, "f8E3M4 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"::llvm::isa<::mlir::Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
               BuildableType<"$_builder.getType<::mlir::Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"::llvm::isa<::mlir::Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
               BuildableType<"$_builder.getType<::mlir::Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"::llvm::isa<::mlir::Float6E3M2FNType>($_self)">, "f6E3M2FN type">,
               BuildableType<"$_builder.getType<::mlir::Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"::llvm::isa<::mlir::Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
                BuildableType<"$_builder.getType<::mlir::Float8E8M0FNUType>()">;

def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                      "complex-type", "::mlir::ComplexType">;

// A complex type whose element type satisfies `elType`.
// It is buildable iff `elType` is buildable. The element is built with
// `elType`'s own builderCall rather than a `$_builder.get<RecordName>Type()`
// call synthesized from the record name. The old form had no matching Builder
// method for e.g. SI*/UI*/F8* element types or TypeAlias records. For
// F16/F32/F64/Index the generated code is unchanged.
class Complex<Type elType>
    : ConfinedType<AnyComplex, [
          SubstLeaves<"$_self",
                      "::llvm::cast<::mlir::ComplexType>($_self).getElementType()",
                      elType.predicate>],
          "complex type with " # elType.summary # " elements",
          "::mlir::ComplexType">,
      SameBuildabilityAs<elType,
          "::mlir::ComplexType::get(" # elType.builderCall # ")"> {
  Type elementType = elType;
}

// A type registered as an OpaqueType of the given dialect namespace and name.
// NOTE(review): `isOpaqueTypeWithName` is emitted unqualified; confirm the
// namespace it is declared in.
class OpaqueType<string dialect, string name, string summary>
  : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
         summary, "::mlir::OpaqueType">,
    BuildableType<"::mlir::OpaqueType::get("
                  "$_builder.getStringAttr(\"" # dialect # "\"), \""
                  # name # "\")">;

// Function Type

// Any function type.
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
                        "function type", "::mlir::FunctionType">;

// A container type is a type that has another type embedded within it.
// `elementTypeCall` is a C++ expression (in terms of `$_self`) yielding the
// element type. `etype`'s predicate is applied to that element type.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                    string descr, string cppType = "::mlir::Type"> :
    // First, check the container predicate. Then, substitute the extracted
    // element into the element type checker.
    Type<And<[containerPred,
              SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                          etype.predicate>]>,
         descr # " of " # etype.summary # " values", cppType>;

// A shaped container whose element type is one of `allowedTypes`.
// The element type is read through `::llvm::cast<::mlir::ShapedType>`, so
// `containerPred` must only accept ShapedTypes.
class ShapedContainerType<list<Type> allowedTypes,
                          Pred containerPred, string descr,
                          string cppType = "::mlir::Type"> :
    Type<And<[containerPred,
              Concat<"[](::mlir::Type elementType) { return ",
                SubstLeaves<"$_self", "elementType",
                AnyTypeOf<allowedTypes>.predicate>,
                "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;

// Whether a shaped type is ranked.
// Casts unconditionally: `$_self` must already be known to be a ShapedType.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;

// Whether a shaped type has one of the specified ranks.
class HasAnyRankOfPred<list<int> ranks> : And<[
    HasRankPred,
    Or<!foreach(rank, ranks,
                CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank()
                         == }]
                      # rank>)>]>;

// Whether a shaped type has a rank greater than or equal to the specified rank.
class HasRankGreaterOrEqualPred<int rank> : And<[
    HasRankPred,
    CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>
]>;

// Container with value semantics.
// ShapedContainerType casts `$_self` to ShapedType to read the element type,
// so a type carrying the ValueSemantics trait that is not a ShapedType must be
// rejected first rather than hitting the cast. For shaped types the result is
// unchanged.
class ValueSemanticsContainerOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes,
                      And<[IsShapedTypePred, HasValueSemanticsPred]>,
                      "container with value semantics">;

// Vector types.

// Any vector of rank > 0 whose element type is one of `allowedTypes`.
class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
                      "::mlir::VectorType">;

// Any fixed-length vector of rank > 0 whose element type is one of
// `allowedTypes`.
class FixedVectorOfNonZeroRankOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsFixedVectorOfNonZeroRankTypePred,
                      "fixed-length vector", "::mlir::VectorType">;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                      "::mlir::VectorType">;

// Any fixed-length vector (0-D allowed) whose element type is one of
// `allowedTypes`.
class FixedVectorOfAnyRank<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
                      "fixed-length vector", "::mlir::VectorType">;

// Any scalable vector whose element type is one of `allowedTypes`.
class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
                      "scalable vector", "::mlir::VectorType">;

// Any vector with a single trailing scalable dimension, with an element type in
// the `allowedTypes` list.
//
// Note: This is similar to ScalableVectorOfAnyRank, with the extra requirement
// that only the trailing dim is scalable.
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorTypeWithOnlyTrailingDimScalablePred,
                      "trailing scalable vector", "::mlir::VectorType">;

// Whether a vector has non-zero rank and its rank is in the given
// `allowedRanks` list.
class IsVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorOfNonZeroRankTypePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;

// Whether a fixed-length vector (0-D allowed) has a rank in the given
// `allowedRanks` list.
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
  And<[IsFixedVectorOfAnyRankTypePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;

// Whether a scalable vector has a rank in the given `allowedRanks` list.
class IsScalableVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorTypeWithAnyDimScalablePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;

// Any vector where the rank is from the given `allowedRanks` list.
// The summary starts with " of ranks" so it can be appended to another summary.
class VectorOfRank<list<int> allowedRanks> : Type<
  IsVectorOfRankPred<allowedRanks>,
  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;

// Any fixed-length vector where the rank is from the given `allowedRanks` list.
class FixedVectorOfRank<list<int> allowedRanks> : Type<
  IsFixedVectorOfRankPred<allowedRanks>,
  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;

// Any scalable vector where the rank is from the given `allowedRanks` list.
class ScalableVectorOfRank<list<int> allowedRanks> : Type<
  IsScalableVectorOfRankPred<allowedRanks>,
  " of ranks " # !interleave(allowedRanks, "/"),
  "::mlir::VectorType">;

// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list.
class VectorOfRankAndType<list<int> allowedRanks,
                          list<Type> allowedTypes> : AllOfType<
  [VectorOfNonZeroRankOf<allowedTypes>, VectorOfRank<allowedRanks>],
  VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
  "::mlir::VectorType">;

// Fixed-width vector where the rank is from the given `allowedRanks` list and
// the type is from the given `allowedTypes` list.
// This uses FixedVectorOfRank, mirroring FixedVectorOfLengthAndType. The
// previous VectorOfRank additionally required rank > 0, so a `0` entry in
// `allowedRanks` could never match. The summary text is unchanged.
class FixedVectorOfRankAndType<list<int> allowedRanks,
                               list<Type> allowedTypes> : AllOfType<
  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfRank<allowedRanks>],
  FixedVectorOfAnyRank<allowedTypes>.summary # FixedVectorOfRank<allowedRanks>.summary,
  "::mlir::VectorType">;

// Whether the number of elements of a vector is from the given
// `allowedLengths` list.
class IsVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsVectorOfNonZeroRankTypePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list.
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsFixedVectorOfAnyRankTypePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Whether the number of elements of a scalable vector is from the given
// `allowedLengths` list.
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
  And<[IsVectorTypeWithAnyDimScalablePred,
       Or<!foreach(allowedlength, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # allowedlength>)>]>;

// Normalizes an index so the indices in both directions have the same value.
// For example, when indexing forwards index 2 is the third element. When
// indexing in reverse the third element is -3. This helper would map both of
// these to the "normalized" index of 3. This makes the bounds checking in
// IsNthDimSizeIsOneOfPred simpler (see the rank CPred).
class NormalizeIndex<int value> {
  int ret = !if(!lt(value, 0),
                !sub(0, value) /* -value if negative */,
                !add(value, 1) /* value + 1 if positive*/);
}

// Whether the n-th dim of the shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse.
//
// Examples:
//  IsNthDimSizeIsOneOfPred<0, [2, 3, 4]>
//    - Accepts any shape where the first dim is 2, 3, or 4.
//      * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
//  IsNthDimSizeIsOneOfPred<-1, [16]>
//    - Accepts any shape where the last dim is 16.
//      * This means shapes like 2x16, 16, 1x2x3x4x16, etc
//  IsNthDimSizeIsOneOfPred<-2, [10, 5]>
//    - Accepts any shape where the second to last dim is 10 or 5.
//      * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
//
// Non-shaped and unranked types are rejected up front. Otherwise the
// unconditional ShapedType cast and `getRank()` calls below would be reached
// for them. `ArrayRef` is fully qualified so the generated C++ does not
// depend on a `using` declaration in the emitting namespace.
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
  : And<[
      IsShapedTypePred,
      HasRankPred,
      CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
      CPred<"::llvm::is_contained(::llvm::ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
        # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
        #   !if(!lt(n, 0),
              "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
              "" # n)
        # "))">]>;

// Whether the shape of a vector matches the given `shape` list.
// The predicate casts `$_self` to VectorType unconditionally, so it must be
// combined with a vector check. `ArrayRef` is fully qualified so the generated
// C++ does not depend on a `using` declaration in the emitting namespace.
class IsVectorOfShape<list<int> shape>
  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ::llvm::ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;

// Any vector where the number of elements is from the given
// `allowedLengths` list.
class VectorOfLength<list<int> allowedLengths> : Type<
  IsVectorOfLengthPred<allowedLengths>,
  " of length " # !interleave(allowedLengths, "/"),
  "::mlir::VectorType">;

// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list.
class FixedVectorOfLength<list<int> allowedLengths> : Type<
  IsFixedVectorOfLengthPred<allowedLengths>,
  " of length " # !interleave(allowedLengths, "/"),
  "::mlir::VectorType">;

// Any scalable vector where the number of elements is from the given
// `allowedLengths` list.
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
  IsScalableVectorOfLengthPred<allowedLengths>,
  " of length " # !interleave(allowedLengths, "/"),
  "::mlir::VectorType">;

// Any vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes`
// list.
class VectorOfLengthAndType<list<int> allowedLengths,
                            list<Type> allowedTypes> : AllOfType<
  [VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
  VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list.
class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                 list<Type> allowedTypes> : AllOfType<
  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
  FixedVectorOfAnyRank<allowedTypes>.summary #
  FixedVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list.
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                    list<Type> allowedTypes> : AllOfType<
  [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
  ScalableVectorOfAnyRank<allowedTypes>.summary #
  ScalableVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any scalable vector where the rank is from the given `allowedRanks` list and
// the number of elements is from the given `allowedLengths` list and the type
// is from the given `allowedTypes` list.
// The summary leads with the element-type summary ("scalable vector of ...
// values"), followed by the " of ranks ..." and " of length ..." suffixes. This
// is the same order the sibling *AndType classes use. Putting the rank suffix
// first produced text like " of ranks 1scalable vector of ...".
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                           list<int> allowedLengths,
                                           list<Type> allowedTypes> : AllOfType<
  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
   ScalableVectorOfLength<allowedLengths>],
  ScalableVectorOfAnyRank<allowedTypes>.summary #
  ScalableVectorOfRank<allowedRanks>.summary #
  ScalableVectorOfLength<allowedLengths>.summary,
  "::mlir::VectorType">;

// Any ShapedType where the size of the n-th dim is contained in `allowedSizes`.
// Negative values for `n` index in reverse.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes> : Type<
  IsNthDimSizeIsOneOfPred<n, allowedSizes>,
  " with dim " # n # " having a size of {" # !interleave(allowedSizes, ", ") # "}",
  "::mlir::ShapedType">;

// Any scalable vector with a single trailing scalable dimension, where the
// size of the trailing dimension is in `allowedTrailingSizes` list, and the
// type is in the `allowedTypes` list.
// The trailing-dim size check is ShapedTypeWithNthDimOfSize with n = -1 (the
// last dimension).
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
                                                 list<Type> allowedTypes> : AllOfType<
  [VectorWithTrailingDimScalableOf<allowedTypes>,
   ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
  VectorWithTrailingDimScalableOf<allowedTypes>.summary #
  ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
  "::mlir::VectorType">;

// AnyVectorOfNonZeroRank and AnyFixedVectorOfNonZeroRank exclude 0-D vectors
// (their predicates require rank > 0). The *OfAnyRank definitions below
// accept 0-D vectors.
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;

def AnyFixedVectorOfNonZeroRank : FixedVectorOfNonZeroRankOf<[AnyType]>;

def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;

def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;

def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;

// Shaped types.

// Any ShapedType with any element type.
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                   "::mlir::ShapedType">;

//===----------------------------------------------------------------------===//
// Tensor types.

// Unranked tensor type whose element type is from the given `allowedTypes`
// list, and which additionally satisfies an optional list of predicates.
// `preds` are checked after the unranked-tensor check.
class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
                       string summary = "unranked tensor">
  : ShapedContainerType<
      allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>,
      summary, "::mlir::UnrankedTensorType">;

// Ranked tensor type whose element type is from the given `allowedTypes` list,
// and which additionally satisfies an optional list of predicates.
// `preds` are checked after the ranked-tensor check.
class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
                     string summary = "ranked tensor">
  : ShapedContainerType<
      allowedTypes, And<!listconcat([IsRankedTensorTypePred], preds)>,
      summary, "::mlir::RankedTensorType">;

// Any tensor type whose element type is from the given `allowedTypes`
// list, and which additionally satisfies an optional list of predicates.
//
// TODO: use `Constraint` instead of `Pred`, so we can generate a better
//       default summary (a la `ConfinedAttr`).
class TensorOf<
    list<Type> allowedTypes,
    list<Pred> preds = [],
    string summary = "tensor">
  : ShapedContainerType<allowedTypes,
      And<!listconcat([IsTensorTypePred], preds)>,
      summary, "::mlir::TensorType">;

def AnyTensor : TensorOf<[AnyType]>;

// Tensors (ranked or unranked) of a single element type.
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
def I16Tensor : TensorOf<[I16]>;
def I32Tensor : TensorOf<[I32]>;
def I64Tensor : TensorOf<[I64]>;
def IndexTensor: TensorOf<[Index]>;

def BF16Tensor : TensorOf<[BF16]>;
def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;

// A ranked tensor of rank >= 1 (HasRankGreaterOrEqualPred also requires the
// type to be ranked).
// NOTE(review): the summary "non-0-ranked.tensor" contains a '.'; it is a
// user-visible diagnostic string, so check existing diagnostics tests before
// changing it.
class Non0RankedTensorOf<list<Type> allowedTypes>
  : TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
      "non-0-ranked.tensor">;

def AnyRankedTensor : RankedTensorOf<[AnyType]>;
def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>;
def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>;

def AnyNon0RankedOrUnrankedTensor
  : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
              "non-0-ranked or unranked tensor", "::mlir::TensorType">;

// Ranked tensor type with one of the specified types and ranks.
// The summary lists the ranks as "<r>D" joined by "/", e.g. "1D/2D tensor".
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
  : RankedTensorOf<allowedTypes,
      [HasAnyRankOfPred<ranks>],
      !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;

class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;

// Ranked tensor whose shape is fully static.
class StaticShapeTensorOf<list<Type> allowedTypes>
  : RankedTensorOf<allowedTypes, [HasStaticShapePred],
                   "statically shaped tensor">;

def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;

//===----------------------------------------------------------------------===//
// Memref type.

// Any unranked memref whose element type is from the given `allowedTypes` list.
// NOTE(review): "unranked.memref" is a user-visible diagnostic string; keep it
// stable unless the dependent diagnostics tests are updated too.
class UnrankedMemRefOf<list<Type> allowedTypes> :
    ShapedContainerType<allowedTypes,
                        IsUnrankedMemRefTypePred, "unranked.memref",
                        "::mlir::UnrankedMemRefType">;

def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;

// Any ranked memref whose element type is from the given `allowedTypes` list.
class MemRefOf<list<Type> allowedTypes> :
    ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
                        "::mlir::MemRefType">;

// A ranked memref of rank >= 1; the summary is MemRefOf's prefixed with
// "non-0-ranked.".
class Non0RankedMemRefOf<list<Type> allowedTypes> :
    ConfinedType<MemRefOf<allowedTypes>, [HasRankGreaterOrEqualPred<1>],
         "non-0-ranked." # MemRefOf<allowedTypes>.summary,
         "::mlir::MemRefType">;

def AnyMemRef : MemRefOf<[AnyType]>;
def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;

// Any memref (ranked or unranked) whose element type is from the given
// `allowedTypes` list, and which additionally satisfies an optional list of
// predicates.
class RankedOrUnrankedMemRefOf<list<Type> allowedTypes,
                               list<Pred> preds = [],
                               string summary = "ranked or unranked memref">
  : ShapedContainerType<allowedTypes,
                        And<!listconcat([IsBaseMemRefTypePred], preds)>,
                        summary, "::mlir::BaseMemRefType">;

// Ranked or unranked memref with an arbitrary element type.
def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>;

// Either an unranked memref or a ranked memref of rank >= 1. No summary or C++
// type is passed, so whatever `AnyTypeOf` defaults to is used.
def AnyNon0RankedOrUnrankedMemRef
  : AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;

// Ranked memrefs of a single element type. They accept any rank, size (static
// or dynamic), layout, or memory space.
def I1MemRef  : MemRefOf<[I1]>;
def I8MemRef  : MemRefOf<[I8]>;
def I16MemRef : MemRefOf<[I16]>;
def I32MemRef : MemRefOf<[I32]>;
def I64MemRef : MemRefOf<[I64]>;

def BF16MemRef : MemRefOf<[BF16]>;
def F16MemRef  : MemRefOf<[F16]>;
def F32MemRef  : MemRefOf<[F32]>;
def F64MemRef  : MemRefOf<[F64]>;

// Ranked memref whose rank is one of `ranks`. The summary starts with the
// accepted ranks (e.g. "1D/2D ") followed by the `MemRefOf` summary.
// TODO: Have an easy way to add another constraint to a type.
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks>
  : ConfinedType<MemRefOf<allowedTypes>, [HasAnyRankOfPred<ranks>],
                 !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
                     MemRefOf<allowedTypes>.summary,
                 "::mlir::MemRefType">;

// Ranked memref with a fully static shape (checked with `HasStaticShapePred`).
class StaticShapeMemRefOf<list<Type> allowedTypes>
  : ConfinedType<MemRefOf<allowedTypes>, [HasStaticShapePred],
                 "statically shaped " # MemRefOf<allowedTypes>.summary,
                 "::mlir::MemRefType">;

// Statically shaped memref with an arbitrary element type.
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;

// For a MemRefType, verify that it has strides.
// NOTE: the predicate casts to MemRefType unconditionally, so it must only be
// combined with a constraint that has already established the type is a
// MemRefType (e.g. as an extra predicate on `MemRefOf` via `ConfinedType`).
def HasStridesPred
    : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;

// Ranked memref with a strided layout (`HasStridesPred`). Its summary is the
// `MemRefOf` summary prefixed with "strided ".
class StridedMemRefOf<list<Type> allowedTypes>
  : ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
                 "strided " # MemRefOf<allowedTypes>.summary>;

// Strided memref with an arbitrary element type.
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;

// Strided memref of any element type whose rank is exactly `rank`.
class AnyStridedMemRefOfRank<int rank>
  : AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
              AnyStridedMemRef.summary # " of rank " # rank>;

// Strided memref whose element type is one of `allowedTypes` and whose rank is
// one of `ranks`. Both the layout (`HasStridesPred`, as in `StridedMemRefOf`)
// and the rank (`HasAnyRankOfPred`) are checked. Without the stride check this
// class would be indistinguishable from `MemRefRankOf` and would accept
// non-strided layouts despite its name. The summary lists the accepted ranks,
// then "strided ", then the `MemRefOf` summary (e.g. "1D/2D strided ...").
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks>
  : ConfinedType<MemRefOf<allowedTypes>,
                 [HasStridesPred, HasAnyRankOfPred<ranks>],
                 !interleave(!foreach(rank, ranks, rank # "D"), "/") #
                     " strided " # MemRefOf<allowedTypes>.summary>;

// This represents a generic tuple without any constraints on element type.
def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;

// A container type that has other types embedded in it, but (unlike
// ContainerType) can hold elements with a mix of types. Requires a call that
// produces a list of all elements' types; every element type must be non-null
// and satisfy `etype`'s predicate.
class MixedContainerType<Type etype, Pred containerPred, code elementTypesCall,
                         string descr>
  : Type<
      And<[
        containerPred,
        Concat<
          "::llvm::all_of(" # elementTypesCall # ", [](::mlir::Type t) { "
          "return t && (",
          SubstLeaves<"$_self", "t", etype.predicate>,
          "); })"
        >
      ]>,
      descr # " with any combination of " # etype.summary # " values"> {
  // The type of elements in the container.
  Type elementType = etype;

  // C++ expression (in terms of `$_self`) that yields the list of element
  // types checked by the predicate above.
  code getElementTypesCall = elementTypesCall;
}

// A Tuple that holds a mix of elements of the allowed types.
class TupleOf<list<Type> allowedTypes>
  : MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
                       "::llvm::cast<::mlir::TupleType>($_self).getTypes()",
                       "tuple">;

// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
// types.
class NestedTupleOf<list<Type> allowedTypes>
  : MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
                       "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
                       "nested tuple">;

//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//

// Type constraint for types that are "like" some type T: either a T itself, or
// a container with value semantics whose elements are Ts (see
// `ValueSemanticsContainerOf`; vectors and tensors are the usual examples).
// `name` becomes the constraint's summary.
class TypeOrValueSemanticsContainer<Type allowedType, string name>
  : TypeConstraint<Or<[allowedType.predicate,
                       ValueSemanticsContainerOf<[allowedType]>.predicate]>,
                   name>;

// Exactly the same constraint as `TypeOrValueSemanticsContainer` (the two used
// to be spelled out with identical bodies); kept as a separate name for its
// existing users.
class TypeOrContainer<Type allowedType, string name>
  : TypeOrValueSemanticsContainer<allowedType, name>;

// Temporary constraint to allow gradual transition to supporting 0-D vectors:
// accepts a T, a vector of Ts (`VectorOfAnyRankOf`), or a tensor of Ts.
// TODO: Remove this when all ops support 0-D vectors.
class TypeOrContainerOfAnyRank<Type allowedType, string name>
  : TypeConstraint<Or<[allowedType.predicate,
                       VectorOfAnyRankOf<[allowedType]>.predicate,
                       TensorOf<[allowedType]>.predicate]>,
                   name>;

// Type constraint for bool-like types: bools, vectors of bools, tensors of
// bools.
def BoolLike : TypeOrContainer<I1, "bool-like">;

// Bool-like constraint built on `TypeOrContainerOfAnyRank`.
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;

// Type constraint for signless-integer-like types: signless integers,
// vectors of signless integers or tensors of signless integers.
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
    AnySignlessInteger, "signless-integer">;

// Type constraint for signless-integer-or-index-like types: signless integers,
// indices, and vectors or tensors of signless integers or indices.
def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
    AnySignlessIntegerOrIndex, "signless-integer-like">;

// Signless-integer-or-index-like constraint built on
// `TypeOrContainerOfAnyRank`.
def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank<
    AnySignlessIntegerOrIndex,
    "signless-integer-like">;

// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeOrContainer<AnyFloat, "floating-point-like">;

// Type constraint for signless-integer-like or float-like types. It is built
// on `SignlessIntegerLike` (i.e. `AnySignlessInteger`, not
// `AnySignlessIntegerOrIndex`); use `SignlessIntegerOrIndexOrFloatLike` when
// index-like types must also be accepted.
def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
    SignlessIntegerLike.predicate, FloatLike.predicate]>,
    "signless-integer-like or floating-point-like">;

// Type constraint for signless-integer-or-index-like or float-like types.
def SignlessIntegerOrIndexOrFloatLike : TypeConstraint<Or<[
    SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
    "signless-integer-or-index-like or floating-point-like">;

#endif // COMMON_TYPE_CONSTRAINTS_TD