xref: /llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td (revision 990837f91de329b1e045f90fadb86ffe21611d9a)
1//===-- CommonTypeConstraints.td - Common Type Constraints--*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains commonly used type constraints.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef COMMON_TYPE_CONSTRAINTS_TD
14#define COMMON_TYPE_CONSTRAINTS_TD
15
16include "mlir/IR/Constraints.td"
17include "mlir/IR/DialectBase.td"
18
19//===----------------------------------------------------------------------===//
20// Common predicates
21//===----------------------------------------------------------------------===//
22
// Matches a VectorType whose rank is at least 1. 0-D vectors are rejected on
// purpose for now, until op coverage for them is good enough.
def IsVectorOfNonZeroRankTypePred : And<[
  CPred<[{::llvm::isa<::mlir::VectorType>($_self)}]>,
  CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank() > 0}]>
]>;
// Matches a FixedVectorType whose rank is at least 1.
def IsFixedVectorOfNonZeroRankTypePred : And<[
  CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>,
  CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank() > 0}]>
]>;
29
// Vector predicate that also admits 0-D vectors. It is a temporary clone used
// to migrate ops to 0-D vector support gradually.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred
    : CPred<[{::llvm::isa<::mlir::VectorType>($_self)}]>;

// Matches a fixed-length VectorType of any rank.
def IsFixedVectorOfAnyRankTypePred
    : CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">;

// Matches a scalable VectorType (one with any scalable dimension).
def IsVectorTypeWithAnyDimScalablePred
    : CPred<"::llvm::isa<::mlir::ScalableVectorType>($_self)">;
40
// Whether a type is a scalable VectorType whose only scalable dimension is the
// trailing one.
// Examples:
// Valid:
//   - vector<[4]xf32>, vector<2x3x[2]xi64>, vector<32x[8]xi32>
// Invalid:
//   - vector<[4]x8xi32>, vector<[2]x[2]xf64>, vector<2x[8]x4xi32>
//
// The vector/non-zero-rank check is listed first: the remaining conditions
// call `getScalableDims().back()`, which needs at least one dimension.
def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
  IsVectorOfNonZeroRankTypePred,
  CPred<"::llvm::cast<::mlir::VectorType>($_self).getScalableDims().back()">,
  CPred<"!::llvm::is_contained(::llvm::cast<::mlir::VectorType>($_self).getScalableDims().drop_back(), true)">
]>;
53
// Matches a VectorType of rank >= 1 in which every dimension is scalable.
def IsVectorTypeWithAllDimsScalablePred
    : And<[IsVectorOfNonZeroRankTypePred,
           CPred<"::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()">]>;
59
// Matches a TensorType.
def IsTensorTypePred : CPred<[{::llvm::isa<::mlir::TensorType>($_self)}]>;

// Matches a MemRefType.
def IsMemRefTypePred : CPred<[{::llvm::isa<::mlir::MemRefType>($_self)}]>;

// Matches an UnrankedMemRefType.
def IsUnrankedMemRefTypePred
    : CPred<[{::llvm::isa<::mlir::UnrankedMemRefType>($_self)}]>;

// Matches an UnrankedTensorType.
def IsUnrankedTensorTypePred
    : CPred<[{::llvm::isa<::mlir::UnrankedTensorType>($_self)}]>;

// Matches a RankedTensorType.
def IsRankedTensorTypePred
    : CPred<[{::llvm::isa<::mlir::RankedTensorType>($_self)}]>;

// Matches a BaseMemRefType.
def IsBaseMemRefTypePred
    : CPred<[{::llvm::isa<::mlir::BaseMemRefType>($_self)}]>;

// Matches any ShapedType.
def IsShapedTypePred : CPred<[{::llvm::isa<::mlir::ShapedType>($_self)}]>;

// Checks that a ShapedType has a fully static shape. `::llvm::cast` requires
// `$_self` to already be a ShapedType, so pair this with IsShapedTypePred.
def HasStaticShapePred
    : CPred<[{::llvm::cast<::mlir::ShapedType>($_self).hasStaticShape()}]>;

// Matches a TupleType.
def IsTupleTypePred : CPred<[{::llvm::isa<::mlir::TupleType>($_self)}]>;

// Matches a type that carries the ValueSemantics trait.
def HasValueSemanticsPred
    : CPred<[{$_self.hasTrait<::mlir::ValueSemantics>()}]>;
94
95//===----------------------------------------------------------------------===//
96// Type definitions
97//===----------------------------------------------------------------------===//
98
// Base class for type constraints: a predicate over `$_self`, a summary, and
// the C++ class that values satisfying the constraint map to.
class Type<Pred condition, string descr = "",
           string cppType = "::mlir::Type">
    : TypeConstraint<condition, descr, cppType> {
  // Long-form documentation for the constraint.
  string description = "";
  // C++ expression constructing an instance of this type; left empty when the
  // type is not buildable (see SameBuildabilityAs).
  string builderCall = "";
}

// Re-exposes an existing type constraint `t` under a new name, optionally with
// a different summary. Predicate, C++ type, description and builder are copied.
class TypeAlias<Type t, string summary = t.summary>
    : Type<t.predicate, summary, t.cppType> {
  let builderCall = t.builderCall;
  let description = t.description;
}

// A type constraint that belongs to dialect `d`.
class DialectType<Dialect d, Pred condition, string descr = "",
                  string cppType = "::mlir::Type">
    : Type<condition, descr, cppType> {
  Dialect dialect = d;
}
120
// Variadic operand/result constraint: matches zero or more values, each of
// which must satisfy `type`.
class Variadic<Type type>
    : TypeConstraint<type.predicate, "variadic of " # type.summary,
                     type.cppType> {
  Type baseType = type;
  // Minimum size of the variadic group; 0 by default.
  int minSize = 0;
}

// Nested variadic constraint: zero or more groups, each group itself a
// variadic range of `type` (for operands/results). `variadicSegmentAttrName`
// names the DenseI32ArrayAttr argument that holds the size of each inner group.
class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
    : Variadic<type> {
  string segmentAttrName = variadicSegmentAttrName;
}

// Optional operand/result constraint: matches zero or one value satisfying
// `type`.
class Optional<Type type>
    : TypeConstraint<type.predicate, type.summary, type.cppType> {
  Type baseType = type;
}
146
// Mixin marking a type as constructible through an MLIR Builder.
// It deliberately does not derive from Type. Otherwise every Type subclass
// would need a buildable and a non-buildable twin to avoid diamond
// "inheritance".
// TODO: we may extend this to a more general 'Buildable' trait, making some
// Types and some Attrs buildable.
class BuildableType<code builder> {
  // Expression emitted (when non-empty) to construct the type.
  code builderCall = builder;
}

// Mixin for container-like types that are buildable only when `type` (e.g.
// the element type) is buildable. It yields `builder` when `type` has a
// builder call, and an empty builder otherwise.
class SameBuildabilityAs<Type type, code builder> {
  code builderCall = !if(!not(!empty(type.builderCall)), builder, "");
}
164
// Matches every type.
def AnyType : Type<CPred<"true">, "any type">;

// The builtin NoneType; buildable.
def NoneType
    : Type<CPred<[{::llvm::isa<::mlir::NoneType>($_self)}]>, "none type",
           "::mlir::NoneType">,
      BuildableType<"$_builder.getType<::mlir::NoneType>()">;
172
// Matches a type that satisfies at least one constraint in `allowedTypeList`.
// When no `summary` is given, the members' summaries are joined with " or ".
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type">
    : Type<Or<!foreach(t, allowedTypeList, t.predicate)>,
           !if(!empty(summary),
               !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
               summary),
           cppType> {
  list<Type> allowedTypes = allowedTypeList;
}

// Matches a type that satisfies every constraint in `allowedTypeList`.
// When no `summary` is given, the members' summaries are joined with " and ".
class AllOfType<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type">
    : Type<And<!foreach(t, allowedTypeList, t.predicate)>,
           !if(!empty(summary),
               !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
               summary),
           cppType> {
  list<Type> allowedTypes = allowedTypeList;
}
196
// Refines `type` with extra `predicates`. A type matches only if it satisfies
// `type`'s predicate and every entry of `predicates`. The base constraint and
// the extra predicates remain accessible as `baseType` / `predicateList`.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
                   string cppType = type.cppType> : Type<
    // `predicates` is already a list<Pred>; no element-wise copy is needed.
    And<!listconcat([type.predicate], predicates)>,
    summary, cppType> {
  Type baseType = type;
  list<Pred> predicateList = predicates;
}
205
// Integer types.

// Matches every IntegerType, whatever its width or signedness semantics.
def AnyInteger : Type<CPred<[{::llvm::isa<::mlir::IntegerType>($_self)}]>,
                      "integer", "::mlir::IntegerType">;

// Matches an integer of exactly `width` bits, regardless of signedness.
class AnyI<int width>
    : Type<CPred<"$_self.isInteger(" # width # ")">,
           width # "-bit integer"> {
  int bitwidth = width;
}

// Matches an integer (any signedness) whose width is one of `widths`.
class AnyIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, AnyI<bits>),
                !interleave(widths, "/") # "-bit integer",
                "::mlir::IntegerType">;

def AnyI1  : AnyI<1>;
def AnyI8  : AnyI<8>;
def AnyI16 : AnyI<16>;
def AnyI32 : AnyI<32>;
def AnyI64 : AnyI<64>;
228
// Matches an IntegerType with signless semantics, of any width.
def AnySignlessInteger
    : Type<CPred<[{$_self.isSignlessInteger()}]>, "signless integer",
           "::mlir::IntegerType">;

// Matches a signless integer of exactly `width` bits; buildable.
class I<int width>
    : Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
           width # "-bit signless integer", "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width # ")"> {
  int bitwidth = width;
}

// Matches a signless integer whose width is one of `widths`.
class SignlessIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, I<bits>),
                !interleave(widths, "/") # "-bit signless integer">;

def I1   : I<1>;
def I8   : I<8>;
def I16  : I<16>;
def I32  : I<32>;
def I64  : I<64>;
def I128 : I<128>;
252
// Matches an IntegerType with signed semantics, of any width.
def AnySignedInteger
    : Type<CPred<[{$_self.isSignedInteger()}]>, "signed integer">;

// Matches a signed integer of exactly `width` bits; buildable.
class SI<int width>
    : Type<CPred<"$_self.isSignedInteger(" # width # ")">,
           width # "-bit signed integer", "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width #
                    ", /*isSigned=*/true)"> {
  int bitwidth = width;
}

// Matches a signed integer whose width is one of `widths`.
class SignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, SI<bits>),
                !interleave(widths, "/") # "-bit signed integer">;

def SI1  : SI<1>;
def SI8  : SI<8>;
def SI16 : SI<16>;
def SI32 : SI<32>;
def SI64 : SI<64>;
275
// Matches an IntegerType with unsigned semantics, of any width.
def AnyUnsignedInteger
    : Type<CPred<[{$_self.isUnsignedInteger()}]>, "unsigned integer">;

// Matches an unsigned integer of exactly `width` bits; buildable.
class UI<int width>
    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
           width # "-bit unsigned integer", "::mlir::IntegerType">,
      BuildableType<"$_builder.getIntegerType(" # width #
                    ", /*isSigned=*/false)"> {
  int bitwidth = width;
}

// Matches an unsigned integer whose width is one of `widths`.
class UnsignedIntOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, UI<bits>),
                !interleave(widths, "/") # "-bit unsigned integer">;

def UI1  : UI<1>;
def UI8  : UI<8>;
def UI16 : UI<16>;
def UI32 : UI<32>;
def UI64 : UI<64>;
298
// The builtin index type; buildable.
def Index
    : Type<CPred<[{::llvm::isa<::mlir::IndexType>($_self)}]>, "index",
           "::mlir::IndexType">,
      BuildableType<"$_builder.getIndexType()">;

// Matches either a signless integer (any width) or index.
def AnySignlessIntegerOrIndex
    : Type<CPred<[{$_self.isSignlessIntOrIndex()}]>,
           "signless integer or index">;
307
// Floating point types.

// Matches every FloatType, whatever its width.
def AnyFloat : Type<CPred<[{::llvm::isa<::mlir::FloatType>($_self)}]>,
                    "floating-point", "::mlir::FloatType">;

// Matches the float type checked by `Type::isF<width>()`. It is buildable
// through `Builder::getF<width>Type()`.
class F<int width>
    : Type<CPred<"$_self.isF" # width # "()">,
           width # "-bit float", "::mlir::FloatType">,
      BuildableType<"$_builder.getF" # width # "Type()"> {
  int bitwidth = width;
}

// Matches one of the `F<w>` types for `w` in `widths`.
class FloatOfWidths<list<int> widths>
    : AnyTypeOf<!foreach(bits, widths, F<bits>),
                !interleave(widths, "/") # "-bit float">;

def F16  : F<16>;
def F32  : F<32>;
def F64  : F<64>;
def F80  : F<80>;
def F128 : F<128>;
331
// Individual builtin floating-point formats. Each matches exactly one builtin
// float type and is buildable through `Builder::getType<...>()`.
// The type names in the builder calls are fully qualified, as NoneType's
// already are. The generated C++ then also compiles in dialects whose code
// does not live in (or import) namespace `mlir`.
def BF16 : Type<CPred<"::llvm::isa<::mlir::BFloat16Type>($_self)">, "bfloat16 type">,
           BuildableType<"$_builder.getType<::mlir::BFloat16Type>()">;
def TF32 : Type<CPred<"::llvm::isa<::mlir::FloatTF32Type>($_self)">, "tf32 type">,
           BuildableType<"$_builder.getType<::mlir::FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"::llvm::isa<::mlir::Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
               BuildableType<"$_builder.getType<::mlir::Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"::llvm::isa<::mlir::Float8E5M2Type>($_self)">, "f8E5M2 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"::llvm::isa<::mlir::Float8E4M3Type>($_self)">, "f8E4M3 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
                 BuildableType<"$_builder.getType<::mlir::Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
                 BuildableType<"$_builder.getType<::mlir::Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"::llvm::isa<::mlir::Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
                 BuildableType<"$_builder.getType<::mlir::Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"::llvm::isa<::mlir::Float8E3M4Type>($_self)">, "f8E3M4 type">,
             BuildableType<"$_builder.getType<::mlir::Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"::llvm::isa<::mlir::Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
               BuildableType<"$_builder.getType<::mlir::Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"::llvm::isa<::mlir::Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
               BuildableType<"$_builder.getType<::mlir::Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"::llvm::isa<::mlir::Float6E3M2FNType>($_self)">, "f6E3M2FN type">,
               BuildableType<"$_builder.getType<::mlir::Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"::llvm::isa<::mlir::Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
                BuildableType<"$_builder.getType<::mlir::Float8E8M0FNUType>()">;
358
// Matches any builtin ComplexType.
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
                      "complex-type", "::mlir::ComplexType">;

// Matches a ComplexType whose element type satisfies `elType`.
//
// The builder reuses `elType.builderCall` to construct the element type. The
// previous form synthesized `$_builder.get<RecordName>Type()` from the record
// name of `elType`, which only compiles when Builder happens to have a getter
// of exactly that name. Buildable element types are not guaranteed to have
// one: for example, SI32's builder is
// `$_builder.getIntegerType(32, /*isSigned=*/true)`. When `elType` is not
// buildable, neither is the complex type.
class Complex<Type elType>
    : ConfinedType<AnyComplex, [
          SubstLeaves<"$_self",
                      "::llvm::cast<::mlir::ComplexType>($_self).getElementType()",
                      elType.predicate>],
          "complex type with " # elType.summary # " elements",
          "::mlir::ComplexType">,
      SameBuildabilityAs<elType,
          "::mlir::ComplexType::get(" # elType.builderCall # ")"> {
  Type elementType = elType;
}
373
// Matches an OpaqueType with namespace `dialect` and type data `name`. It is
// buildable from the same pair.
class OpaqueType<string dialect, string name, string summary>
    : Type<CPred<"isOpaqueTypeWithName($_self, \"" # dialect # "\", \"" #
                 name # "\")">,
           summary, "::mlir::OpaqueType">,
      BuildableType<"::mlir::OpaqueType::get(" #
                    "$_builder.getStringAttr(\"" # dialect # "\"), \"" #
                    name # "\")">;

// Function Type

// Matches any builtin FunctionType.
def FunctionType : Type<CPred<[{::llvm::isa<::mlir::FunctionType>($_self)}]>,
                        "function type", "::mlir::FunctionType">;
386
// A container type embeds another (element) type. A value matches when
// `containerPred` holds and `etype`'s predicate holds for the element type
// produced by the C++ expression `elementTypeCall`.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                    string descr, string cppType = "::mlir::Type">
    : Type<And<[
             // Container check first, so the element type is only extracted
             // from a matching container.
             containerPred,
             // Re-target the element predicate from `$_self` to the element.
             SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                         etype.predicate>]>,
           descr # " of " # etype.summary # " values", cppType>;

// A shaped container whose element type satisfies one of `allowedTypes`.
// The element check is wrapped in an immediately-invoked C++ lambda. That
// way `getElementType()` is evaluated once and bound to `elementType`, and
// every `$_self` in the element predicate is rewritten to that name.
class ShapedContainerType<list<Type> allowedTypes,
                          Pred containerPred, string descr,
                          string cppType = "::mlir::Type">
    : Type<And<[
             containerPred,
             Concat<"[](::mlir::Type elementType) { return ",
                    SubstLeaves<"$_self", "elementType",
                                AnyTypeOf<allowedTypes>.predicate>,
                    "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
           descr # " of " # AnyTypeOf<allowedTypes>.summary # " values",
           cppType>;
406
// Holds when a ShapedType has a rank (i.e. it is not unranked).
def HasRankPred
    : CPred<[{::llvm::cast<::mlir::ShapedType>($_self).hasRank()}]>;

// Holds when a ShapedType is ranked and its rank equals one of `ranks`.
class HasAnyRankOfPred<list<int> ranks> : And<[
    HasRankPred,
    Or<!foreach(r, ranks,
                CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank()
                         == }]
                      # r>)>]>;

// Holds when a ShapedType is ranked and its rank is greater than or equal to
// `rank`.
class HasRankGreaterOrEqualPred<int rank>
    : And<[HasRankPred,
           CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank>]>;
423
// A shaped container carrying the ValueSemantics trait, with an element type
// drawn from `allowedTypes`.
class ValueSemanticsContainerOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, HasValueSemanticsPred,
                          "container with value semantics">;

// Vector types.

// Vector of rank >= 1 with an element type in `allowedTypes`.
class VectorOfNonZeroRankOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred,
                          "vector", "::mlir::VectorType">;

// Fixed-length vector of rank >= 1 with an element type in `allowedTypes`.
class FixedVectorOfNonZeroRankOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsFixedVectorOfNonZeroRankTypePred,
                          "fixed-length vector", "::mlir::VectorType">;

// Vector of any rank (0-D included) with an element type in `allowedTypes`.
// This is a temporary clone that eases the gradual move to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred,
                          "vector", "::mlir::VectorType">;

// Fixed-length vector of any rank with an element type in `allowedTypes`.
class FixedVectorOfAnyRank<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
                          "fixed-length vector", "::mlir::VectorType">;

// Scalable vector with an element type in `allowedTypes`; no rank check is
// added.
class ScalableVectorOfAnyRank<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
                          "scalable vector", "::mlir::VectorType">;
452
// Vector with an element type in `allowedTypes` whose single scalable
// dimension is the trailing one.
//
// Note: this is similar to ScalableVectorOfAnyRank, with the extra requirement
// that only the trailing dim is scalable.
class VectorWithTrailingDimScalableOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes,
                          IsVectorTypeWithOnlyTrailingDimScalablePred,
                          "trailing scalable vector", "::mlir::VectorType">;
461
// Whether a type is a vector of rank >= 1 whose rank is one of `allowedRanks`.
// Rank 0 is rejected even if listed, because IsVectorOfNonZeroRankTypePred
// requires `getRank() > 0`.
class IsVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorOfNonZeroRankTypePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;

// Whether a type is a fixed-length vector whose rank is one of `allowedRanks`.
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
  And<[IsFixedVectorOfAnyRankTypePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;

// Whether a type is a scalable vector whose rank is one of `allowedRanks`.
class IsScalableVectorOfRankPred<list<int> allowedRanks> :
  And<[IsVectorTypeWithAnyDimScalablePred,
       Or<!foreach(allowedRank, allowedRanks,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                           == }]
                         # allowedRank>)>]>;
488
// Vector (rank >= 1) whose rank is one of `allowedRanks`. The summary is a
// suffix (" of ranks ...") that the *AndType classes append to a container
// summary.
class VectorOfRank<list<int> allowedRanks>
    : Type<IsVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;

// Fixed-length vector whose rank is one of `allowedRanks`.
class FixedVectorOfRank<list<int> allowedRanks>
    : Type<IsFixedVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;

// Scalable vector whose rank is one of `allowedRanks`.
class ScalableVectorOfRank<list<int> allowedRanks>
    : Type<IsScalableVectorOfRankPred<allowedRanks>,
           " of ranks " # !interleave(allowedRanks, "/"),
           "::mlir::VectorType">;
503
// Vector (rank >= 1) whose rank is in `allowedRanks` and whose element type is
// in `allowedTypes`.
class VectorOfRankAndType<list<int> allowedRanks,
                          list<Type> allowedTypes>
    : AllOfType<[VectorOfNonZeroRankOf<allowedTypes>,
                 VectorOfRank<allowedRanks>],
                VectorOfNonZeroRankOf<allowedTypes>.summary #
                    VectorOfRank<allowedRanks>.summary,
                "::mlir::VectorType">;
511
// Fixed-length vector whose rank is in `allowedRanks` and whose element type
// is in `allowedTypes`.
//
// The rank constraint is FixedVectorOfRank rather than VectorOfRank. This
// matches FixedVectorOfLengthAndType, which pairs FixedVectorOfAnyRank with
// FixedVectorOfLength. It also means a 0 in `allowedRanks` is honored:
// VectorOfRank is built on IsVectorOfNonZeroRankTypePred and would reject 0-D
// vectors even when rank 0 is explicitly listed.
class FixedVectorOfRankAndType<list<int> allowedRanks,
                               list<Type> allowedTypes> : AllOfType<
  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfRank<allowedRanks>],
  FixedVectorOfAnyRank<allowedTypes>.summary #
  FixedVectorOfRank<allowedRanks>.summary,
  "::mlir::VectorType">;
519
// Holds for a vector of rank >= 1 whose total element count is one of
// `allowedLengths`.
class IsVectorOfLengthPred<list<int> allowedLengths>
  : And<[IsVectorOfNonZeroRankTypePred,
         Or<!foreach(len, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # len>)>]>;

// Holds for a fixed-length vector whose total element count is one of
// `allowedLengths`.
class IsFixedVectorOfLengthPred<list<int> allowedLengths>
  : And<[IsFixedVectorOfAnyRankTypePred,
         Or<!foreach(len, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # len>)>]>;

// Holds for a scalable vector whose (minimum) element count reported by
// `getNumElements()` is one of `allowedLengths`.
class IsScalableVectorOfLengthPred<list<int> allowedLengths>
  : And<[IsVectorTypeWithAnyDimScalablePred,
         Or<!foreach(len, allowedLengths,
                   CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                           == }]
                         # len>)>]>;
546
// Maps a dimension index to a direction-independent value. A non-negative
// (forward) index `i` becomes `i + 1`, and a negative (reverse) index `-k`
// becomes `k`. For example, forward index 2 (third from the front) and
// reverse index -3 (third from the back) both map to 3. The result is the
// smallest rank for which the index is in bounds. That is what the first
// CPred of IsNthDimSizeIsOneOfPred compares against.
class NormalizeIndex<int value> {
  int ret = !if(!ge(value, 0),
                !add(value, 1) /* non-negative: value + 1 */,
                !sub(0, value) /* negative: -value */);
}
557
// Whether the n-th dim of a ranked shape is contained within `allowedSizes`.
// Negative values for `n` index in reverse (-1 is the last dim).
//
// Examples:
// IsNthDimSizeIsOneOfPred<0, [2, 3, 4]>
//  - Accepts any shape where the first dim is 2, 3, or 4.
//    * This means shapes like: 2x8x9x5, 4, 3x1, 4x?, etc
// IsNthDimSizeIsOneOfPred<-1, [16]>
//  - Accepts any shape where the last dim is 16.
//    * This means shapes like 2x16, 16, 1x2x3x4x16, etc
// IsNthDimSizeIsOneOfPred<-2, [10, 5]>
//  - Accepts any shape where the second to last dim is 10 or 5.
//    * This means shapes like: 1x10x2, 2x1x4x5x6, 8x10x?, etc
//
// The predicate rejects unranked types and ranks too small to contain dim
// `n`. HasRankPred comes first because `getRank()`/`getDimSize()` are only
// valid on ranked shaped types, and the conjunction short-circuits.
// `ArrayRef` is spelled `::llvm::ArrayRef` so the generated C++ does not
// depend on the enclosing namespace re-exporting it.
class IsNthDimSizeIsOneOfPred<int n, list<int> allowedSizes>
  : And<[
      HasRankPred,
      CPred<"::llvm::cast<::mlir::ShapedType>($_self).getRank() >= " # NormalizeIndex<n>.ret>,
      CPred<"::llvm::is_contained(::llvm::ArrayRef<int64_t>({" # !interleave(allowedSizes, ", ") # "}), "
        # "::llvm::cast<::mlir::ShapedType>($_self).getDimSize("
        #   !if(!lt(n, 0),
              "::llvm::cast<::mlir::ShapedType>($_self).getRank() + " # n,
              "" # n)
        # "))">]>;
580
// Whether the shape of a vector exactly matches the given `shape` list.
// `$_self` must already be a VectorType (the cast is unguarded here), so
// combine this with a vector type predicate. `ArrayRef` is fully qualified so
// the generated C++ does not depend on the enclosing namespace.
class IsVectorOfShape<list<int> shape>
  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ::llvm::ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
584
// Vector (rank >= 1) whose total element count is one of `allowedLengths`.
// The summary (" of length ...") is a suffix meant to be appended to a
// container summary.
class VectorOfLength<list<int> allowedLengths>
    : Type<IsVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;

// Fixed-length vector whose total element count is one of `allowedLengths`.
class FixedVectorOfLength<list<int> allowedLengths>
    : Type<IsFixedVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;

// Scalable vector whose element count is one of `allowedLengths`.
class ScalableVectorOfLength<list<int> allowedLengths>
    : Type<IsScalableVectorOfLengthPred<allowedLengths>,
           " of length " # !interleave(allowedLengths, "/"),
           "::mlir::VectorType">;
605
// Vector (rank >= 1) whose element count is in `allowedLengths` and whose
// element type is in `allowedTypes`.
class VectorOfLengthAndType<list<int> allowedLengths,
                            list<Type> allowedTypes>
    : AllOfType<[VectorOfNonZeroRankOf<allowedTypes>,
                 VectorOfLength<allowedLengths>],
                VectorOfNonZeroRankOf<allowedTypes>.summary #
                    VectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Fixed-length vector whose element count is in `allowedLengths` and whose
// element type is in `allowedTypes`.
class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                 list<Type> allowedTypes>
    : AllOfType<[FixedVectorOfAnyRank<allowedTypes>,
                 FixedVectorOfLength<allowedLengths>],
                FixedVectorOfAnyRank<allowedTypes>.summary #
                    FixedVectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Scalable vector whose element count is in `allowedLengths` and whose element
// type is in `allowedTypes`.
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                    list<Type> allowedTypes>
    : AllOfType<[ScalableVectorOfAnyRank<allowedTypes>,
                 ScalableVectorOfLength<allowedLengths>],
                ScalableVectorOfAnyRank<allowedTypes>.summary #
                    ScalableVectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;

// Scalable vector whose rank is in `allowedRanks`, whose element count is in
// `allowedLengths`, and whose element type is in `allowedTypes`.
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                           list<int> allowedLengths,
                                           list<Type> allowedTypes>
    : AllOfType<[ScalableVectorOfRank<allowedRanks>,
                 ScalableVectorOfAnyRank<allowedTypes>,
                 ScalableVectorOfLength<allowedLengths>],
                ScalableVectorOfRank<allowedRanks>.summary #
                    ScalableVectorOfAnyRank<allowedTypes>.summary #
                    ScalableVectorOfLength<allowedLengths>.summary,
                "::mlir::VectorType">;
645
// ShapedType whose dimension `n` has a size listed in `allowedSizes`. A
// negative `n` counts from the back. The summary is a suffix meant for
// composition with other constraints.
class ShapedTypeWithNthDimOfSize<int n, list<int> allowedSizes>
    : Type<IsNthDimSizeIsOneOfPred<n, allowedSizes>,
           " with dim " # n # " having a size of {" #
               !interleave(allowedSizes, ", ") # "}",
           "::mlir::ShapedType">;
652
// Scalable vector with a single trailing scalable dimension. The size of that
// trailing dimension must be in `allowedTrailingSizes`, and the element type
// must be in `allowedTypes`.
class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
                                                 list<Type> allowedTypes>
    : AllOfType<[VectorWithTrailingDimScalableOf<allowedTypes>,
                 ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>],
                VectorWithTrailingDimScalableOf<allowedTypes>.summary #
                    ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
                "::mlir::VectorType">;
663
// Catch-all vector constraints accepting any element type. The "NonZeroRank"
// variants explicitly exclude 0-D vectors; the "AnyRank" variants add no rank
// check of their own.
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
def AnyFixedVectorOfNonZeroRank : FixedVectorOfNonZeroRankOf<[AnyType]>;
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;

// Shaped types.

// Any ShapedType, with any element type.
def AnyShaped : ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                    "::mlir::ShapedType">;
679
680//===----------------------------------------------------------------------===//
681// Tensor types.
682
683// Unranked tensor type whose element type is from the given `allowedTypes`
684// list, and which additionally satisfies an optional list of predicates.
685class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
686                       string summary = "unranked tensor">
687  : ShapedContainerType<
688      allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>,
689      summary, "::mlir::UnrankedTensorType">;
690
691// Ranked tensor type whose element type is from the given `allowedTypes` list,
692// and which additionally satisfies an optional list of predicates.
class RankedTensorOf<list<Type> allowedTypes,
                     list<Pred> preds = [],
                     string summary = "ranked tensor">
    : ShapedContainerType<
        allowedTypes,
        // The ranked-tensor check is prepended to the caller's predicates.
        And<!listconcat([IsRankedTensorTypePred], preds)>,
        summary,
        "::mlir::RankedTensorType">;
698
699// Any tensor type whose element type is from the given `allowedTypes`
700// list, and which additionally satisfies an optional list of predicates.
701//
702// TODO: use `Constraint` instead of `Pred`, so we can generate a better
703// default summary (a la `ConfinedAttr`).
class TensorOf<list<Type> allowedTypes,
               list<Pred> preds = [],
               string summary = "tensor">
    : ShapedContainerType<
        allowedTypes,
        // The tensor check (ranked or unranked) is prepended to the caller's
        // predicates.
        And<!listconcat([IsTensorTypePred], preds)>,
        summary,
        "::mlir::TensorType">;
711
def AnyTensor : TensorOf<[AnyType]>;

// Tensors whose element type is one specific integer or index type.
def I1Tensor    : TensorOf<[I1]>;
def I8Tensor    : TensorOf<[I8]>;
def I16Tensor   : TensorOf<[I16]>;
def I32Tensor   : TensorOf<[I32]>;
def I64Tensor   : TensorOf<[I64]>;
def IndexTensor : TensorOf<[Index]>;

// Tensors whose element type is one specific floating-point type.
def BF16Tensor : TensorOf<[BF16]>;
def F16Tensor  : TensorOf<[F16]>;
def F32Tensor  : TensorOf<[F32]>;
def F64Tensor  : TensorOf<[F64]>;
725
// A tensor whose rank is at least 1 (0-D tensors are rejected).
class Non0RankedTensorOf<list<Type> allowedTypes>
    : TensorOf<allowedTypes,
               [HasRankGreaterOrEqualPred<1>],
               "non-0-ranked.tensor">;
729
def AnyRankedTensor     : RankedTensorOf<[AnyType]>;
def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>;
def AnyUnrankedTensor   : UnrankedTensorOf<[AnyType]>;

// Either an unranked tensor or a ranked tensor of rank >= 1.
def AnyNon0RankedOrUnrankedTensor
    : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
                "non-0-ranked or unranked tensor",
                "::mlir::TensorType">;
737
738// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
    : RankedTensorOf<
        allowedTypes,
        [HasAnyRankOfPred<ranks>],
        // Summary such as "1D/2D tensor": each rank gets a "D" suffix and the
        // results are joined with "/".
        !strconcat(!interleave(!foreach(rank, ranks, rank # "D"), "/"),
                   " tensor")>;
743
// Shorthands for ranked tensors of exactly one rank.
class 0DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes>
    : TensorRankOf<allowedTypes, [4]>;
749
// A ranked tensor whose shape satisfies HasStaticShapePred.
class StaticShapeTensorOf<list<Type> allowedTypes>
    : RankedTensorOf<allowedTypes,
                     [HasStaticShapePred],
                     "statically shaped tensor">;

def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
755
756//===----------------------------------------------------------------------===//
757// Memref type.
758
759// Any unranked memref whose element type is from the given `allowedTypes` list.
class UnrankedMemRefOf<list<Type> allowedTypes>
    : ShapedContainerType<
        allowedTypes,
        IsUnrankedMemRefTypePred,
        "unranked.memref",
        "::mlir::UnrankedMemRefType">;

def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
766
767// Any ranked memref whose element type is from the given `allowedTypes` list.
class MemRefOf<list<Type> allowedTypes>
    : ShapedContainerType<
        allowedTypes,
        IsMemRefTypePred,
        "memref",
        "::mlir::MemRefType">;
771
// A ranked memref of rank >= 1 (0-D memrefs are rejected). Its summary is the
// MemRefOf summary prefixed with "non-0-ranked.".
class Non0RankedMemRefOf<list<Type> allowedTypes>
    : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasRankGreaterOrEqualPred<1>],
        !strconcat("non-0-ranked.", MemRefOf<allowedTypes>.summary),
        "::mlir::MemRefType">;

def AnyMemRef           : MemRefOf<[AnyType]>;
def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>;
779
780// Any memref (ranked or unranked) whose element type is from the given
781// `allowedTypes` list, and which additionally satisfies an optional list of
782// predicates.
class RankedOrUnrankedMemRefOf<list<Type> allowedTypes,
                               list<Pred> preds = [],
                               string summary = "ranked or unranked memref">
    : ShapedContainerType<
        allowedTypes,
        // The base-memref check is prepended to the caller's predicates.
        And<!listconcat([IsBaseMemRefTypePred], preds)>,
        summary,
        "::mlir::BaseMemRefType">;

def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>;

// Either an unranked memref or a ranked memref of rank >= 1.
// NOTE(review): unlike AnyNon0RankedOrUnrankedTensor, no explicit summary or
// C++ type is passed here, so AnyTypeOf's defaults are used.
def AnyNon0RankedOrUnrankedMemRef
    : AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>;
794
// These memref constraints accept any memref regardless of rank, size (static
// or dynamic), layout, or memory space; only the element type is constrained.
// Integer element types.
def I1MemRef  : MemRefOf<[I1]>;
def I8MemRef  : MemRefOf<[I8]>;
def I16MemRef : MemRefOf<[I16]>;
def I32MemRef : MemRefOf<[I32]>;
def I64MemRef : MemRefOf<[I64]>;

// Floating-point element types.
def BF16MemRef : MemRefOf<[BF16]>;
def F16MemRef  : MemRefOf<[F16]>;
def F32MemRef  : MemRefOf<[F32]>;
def F64MemRef  : MemRefOf<[F64]>;
807
808// TODO: Have an easy way to add another constraint to a type.
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks>
    : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasAnyRankOfPred<ranks>],
        // Summary: the ranks rendered as e.g. "1D/2D", a space, then the
        // MemRefOf summary.
        !strconcat(!interleave(!foreach(rank, ranks, rank # "D"), "/"), " ",
                   MemRefOf<allowedTypes>.summary),
        "::mlir::MemRefType">;
814
// A ranked memref whose shape satisfies HasStaticShapePred.
class StaticShapeMemRefOf<list<Type> allowedTypes>
    : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasStaticShapePred],
        !strconcat("statically shaped ", MemRefOf<allowedTypes>.summary),
        "::mlir::MemRefType">;

def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
821
822// For a MemRefType, verify that it has strides.
// `::llvm::cast` asserts if `$_self` is not a MemRefType, so combine this
// predicate with a memref check.
def HasStridesPred
    : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;
824
// A ranked memref that additionally satisfies HasStridesPred.
class StridedMemRefOf<list<Type> allowedTypes>
    : ConfinedType<
        MemRefOf<allowedTypes>,
        [HasStridesPred],
        !strconcat("strided ", MemRefOf<allowedTypes>.summary)>;

def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
830
// A strided memref of any element type whose rank is exactly `rank`.
class AnyStridedMemRefOfRank<int rank>
    : AllOfType<
        [AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>],
        AnyStridedMemRef.summary # " of rank " # rank>;
834
// A strided memref whose element type is from `allowedTypes` and whose rank is
// one of `ranks`. The summary is "strided ", then the ranks (e.g. "1D/2D"),
// then the underlying MemRefOf summary.
//
// `HasStridesPred` is what distinguishes this constraint from `MemRefRankOf`.
// Without it, memrefs with non-strided layouts would be accepted despite the
// constraint's name. `StridedMemRefOf` and `AnyStridedMemRefOfRank` both
// require strides as well.
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
    ConfinedType<MemRefOf<allowedTypes>,
         [HasAnyRankOfPred<ranks>, HasStridesPred],
         "strided " # !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
         MemRefOf<allowedTypes>.summary>;
839
840// This represents a generic tuple without any constraints on element type.
def AnyTuple
    : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
842
843// A container type that has other types embedded in it, but (unlike
844// ContainerType) can hold elements with a mix of types. Requires a call that
845// produces a list of all elements' types.
class MixedContainerType<Type etype, Pred containerPred, code elementTypesCall,
                         string descr>
    : Type<
        And<[
          // The value must satisfy the container predicate...
          containerPred,
          // ...and every element type produced by `elementTypesCall` must be
          // non-null and satisfy `etype`'s predicate, with `$_self` in that
          // predicate rebound to the element type `t`.
          Concat<"::llvm::all_of(" # elementTypesCall #
                     ", [](::mlir::Type t) { return t && (",
                 SubstLeaves<"$_self", "t", etype.predicate>,
                 "); })">
        ]>,
        !strconcat(descr, " with any combination of ", etype.summary,
                   " values")> {
  // Constraint that each element type in the container must satisfy.
  Type elementType = etype;

  // C++ expression yielding the container's element types (typically written
  // in terms of `$_self`); it is spliced into the `::llvm::all_of` call above.
  code getElementTypesCall = elementTypesCall;
}
865
866// A Tuple that holds a mix of elements of the allowed types.
class TupleOf<list<Type> allowedTypes>
    : MixedContainerType<
        AnyTypeOf<allowedTypes>,
        IsTupleTypePred,
        // Only the tuple's direct element types are checked.
        "::llvm::cast<::mlir::TupleType>($_self).getTypes()",
        "tuple">;
871
872// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
873// types.
// NOTE(review): relies on an unqualified `getFlattenedTypes` helper being
// visible wherever this predicate is emitted.
class NestedTupleOf<list<Type> allowedTypes>
    : MixedContainerType<
        AnyTypeOf<allowedTypes>,
        IsTupleTypePred,
        "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
        "nested tuple">;
878
879//===----------------------------------------------------------------------===//
880// Common type constraints
881//===----------------------------------------------------------------------===//
882// Type constraint for types that are "like" some type or set of types T, that is
883// they're either a T, a vector of Ts, or a tensor of Ts.
// NOTE(review): this predicate is identical to TypeOrValueSemanticsContainer's.
class TypeOrContainer<Type allowedType, string name>
    : TypeConstraint<
        Or<[allowedType.predicate,
            ValueSemanticsContainerOf<[allowedType]>.predicate]>,
        name>;
888
// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T or a value-semantics (mappable) container of Ts.
class TypeOrValueSemanticsContainer<Type allowedType, string name>
    : TypeConstraint<
        // Accept the element type itself, or a value-semantics container of it.
        Or<[allowedType.predicate,
            ValueSemanticsContainerOf<[allowedType]>.predicate]>,
        name>;
896
897// Temporary constraint to allow gradual transition to supporting 0-D vectors.
898// TODO: Remove this when all ops support 0-D vectors.
class TypeOrContainerOfAnyRank<Type allowedType, string name>
    : TypeConstraint<
        // Accept the type itself, a vector of it of any rank (0-D included),
        // or a tensor of it.
        Or<[allowedType.predicate,
            VectorOfAnyRankOf<[allowedType]>.predicate,
            TensorOf<[allowedType]>.predicate]>,
        name>;
903
904
905// Type constraint for bool-like types: bools, vectors of bools, tensors of
906// bools.
def BoolLike          : TypeOrContainer<I1, "bool-like">;

// Built with TypeOrContainerOfAnyRank, whose vector case is VectorOfAnyRankOf.
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
910
911// Type constraint for signless-integer-like types: signless integers,
912// vectors of signless integers or tensors of signless integers.
def SignlessIntegerLike
    : TypeOrValueSemanticsContainer<AnySignlessInteger, "signless-integer">;
915
// Type constraint for signless-integer-or-index-like types: signless integers,
// indices, and vectors or tensors of signless integers or indices.
def SignlessIntegerOrIndexLike
    : TypeOrValueSemanticsContainer<AnySignlessIntegerOrIndex,
                                    "signless-integer-like">;

// Same element types, but built with TypeOrContainerOfAnyRank, whose vector
// case is VectorOfAnyRankOf.
def SignlessIntegerOrIndexLikeOfAnyRank
    : TypeOrContainerOfAnyRank<AnySignlessIntegerOrIndex,
                               "signless-integer-like">;
924
925// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike
    : TypeOrContainer<AnyFloat, "floating-point-like">;
927
// Type constraint for signless-integer-like or float-like types (index is not included).
def SignlessIntegerOrFloatLike
    : TypeConstraint<
        Or<[SignlessIntegerLike.predicate, FloatLike.predicate]>,
        "signless-integer-like or floating-point-like">;
932
933// Type constraint for signless-integer-or-index-like or float-like types.
def SignlessIntegerOrIndexOrFloatLike
    : TypeConstraint<
        Or<[SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
        "signless-integer-or-index-like or floating-point-like">;
937
938#endif // COMMON_TYPE_CONSTRAINTS_TD
939