xref: /llvm-project/mlir/include/mlir/IR/OpBase.td (revision d0b7633d7ad566579bfb794f95cce9aef294c92b)
1//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the base operation definition file.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef OP_BASE
14#define OP_BASE
15
16include "mlir/IR/Constraints.td"
17include "mlir/IR/DialectBase.td"
18include "mlir/IR/Interfaces.td"
19include "mlir/IR/Properties.td"
20include "mlir/IR/Traits.td"
21include "mlir/IR/Utils.td"
22include "mlir/IR/AttrTypeBase.td"
23
24//===----------------------------------------------------------------------===//
25// OpTrait definitions
26//===----------------------------------------------------------------------===//
27
// Marker class. Traits that describe the structure of an operation are tagged
// with `StructuralOpTrait`, and such traits are verified first.
class StructuralOpTrait;

// Op-specific wrappers around the generic trait classes. Each wrapper adds a
// `dependentTraits` field listing the traits that must be verified before the
// wrapped trait itself.

// A native op trait. Declarations and definitions specific to the op can be
// supplied through the `extraOpDeclaration` and `extraOpDefinition` template
// arguments.
class NativeOpTrait<string name,
                    list<Trait> traits = [],
                    code extraOpDeclaration = [{}],
                    code extraOpDefinition = [{}]>
    : NativeTrait<name, "Op", extraOpDeclaration, extraOpDefinition> {
  // Traits to verify before this NativeOpTrait.
  list<Trait> dependentTraits = traits;
}

// A parameterized native op trait; `params` is forwarded to ParamNativeTrait.
class ParamNativeOpTrait<string prop, string params, list<Trait> traits = []>
    : ParamNativeTrait<prop, params, "Op"> {
  // Traits to verify before this ParamNativeOpTrait.
  list<Trait> dependentTraits = traits;
}

// An op trait backed by GenInternalTrait.
class GenInternalOpTrait<string prop, list<Trait> traits = []>
    : GenInternalTrait<prop, "Op"> {
  // Traits to verify before this GenInternalOpTrait.
  list<Trait> dependentTraits = traits;
}

// An op trait expressed as the predicate `pred` and described by `descr`.
class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
    : PredTrait<descr, pred> {
  // Traits to verify before this PredOpTrait.
  list<Trait> dependentTraits = traits;
}
63
// The op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// The op defines an automatic allocation scope.
def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
// The op supports operand broadcasting behavior.
def ResultsBroadcastableShape : NativeOpTrait<"ResultsBroadcastableShape">;

// Algebraic properties.
// Commutative: X op Y == Y op X.
def Commutative : NativeOpTrait<"IsCommutative">;
// Idempotent: op op X == op X (unary) / X op X == X (binary).
// FIXME: Idempotent should depend on SameOperandsAndResultType
def Idempotent : NativeOpTrait<"IsIdempotent">;
// Involution: op op X == X.
// FIXME: Involution should depend on SameOperandsAndResultType
def Involution : NativeOpTrait<"IsInvolution">;

// The op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// The op is isolated from above.
def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
// The op's results are floats or vectors/tensors of floats.
def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;

// Agreement of types and shapes between operands and results.
// All operands have the same type.
def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// All operands have the same shape.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// All operands and results have the same shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// All operands have the same element type (or the same type, if scalar).
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// All operands and results have the same element type (or the same type, if
// scalar).
def SameOperandsAndResultElementType
    : NativeOpTrait<"SameOperandsAndResultElementType">;

// The op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
// The op can be safely normalized in the presence of MemRefs with non-identity
// maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;

// Elementwise traits. Scalarizable, Vectorizable and Tensorizable each list
// Elementwise as a dependent trait.
// The op is elementwise on its tensor/vector operands and results.
def Elementwise : NativeOpTrait<"Elementwise">;
// The elementwise op can be applied to scalars instead of tensor/vector
// operands.
def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>;
// The elementwise op can be applied to all-vector operands.
def Vectorizable : NativeOpTrait<"Vectorizable", [Elementwise]>;
// The elementwise op can be applied to all-tensor operands.
def Tensorizable : NativeOpTrait<"Tensorizable", [Elementwise]>;

// Convenience bundle of `Elementwise`, `Scalarizable`, `Vectorizable` and
// `Tensorizable`.
def ElementwiseMappable
    : TraitList<[Elementwise, Scalarizable, Vectorizable, Tensorizable]>;
121
// Each region of the op holds a single block.
def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;

// Parameterized part of `SingleBlockImplicitTerminator`. It carries the name of
// the terminator op `op` and depends on `SingleBlock`.
class SingleBlockImplicitTerminatorImpl<string op>
    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op, [SingleBlock]>,
      StructuralOpTrait;

// Each region of the op holds a single block whose terminator is op `op`.
class SingleBlockImplicitTerminator<string op> : TraitList<[
    SingleBlock,
    SingleBlockImplicitTerminatorImpl<op>
]>;

// The op's regions do not have a terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;

// The op's parent operation is `op`.
class HasParent<string op>
    : ParamNativeOpTrait<"HasParent", op>,
      StructuralOpTrait;

// The op's parent operation is one of `ops`. This reuses the `HasParent` native
// trait, with the names joined by ", " as its parameter.
class ParentOneOf<list<string> ops>
    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
      StructuralOpTrait;
143
// The op's result type is derived from its first attribute. If that attribute
// is a subclass of `TypeAttrBase`, its value is used; otherwise the type of the
// attribute content is used.
def FirstAttrDerivedResultType
    : GenInternalOpTrait<"FirstAttrDerivedResultType">;

// TODO: Turn the following into normal traits and generate verification for
// them.

// All variadic operands of the op have the same number of values. A variadic
// operand holds an array of values whose size is only known at runtime. This
// trait requires every variadic operand of the op to have the same size.
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// All variadic results of the op have the same number of values. A variadic
// result holds an array of values whose size is only known at runtime. This
// trait requires every variadic result of the op to have the same size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;

// Uses an attribute named `operandSegmentSizes` to specify how many actual
// operands each ODS-declared operand (variadic or not) corresponds to. This
// trait is meant for ops that have several variadic operands whose size
// relationship is not known statically. The attribute must be a 1D vector with
// one element per ODS-declared operand, so even a non-variadic operand needs
// an entry for its size, which is always 1.
def AttrSizedOperandSegments
    : NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
// Like AttrSizedOperandSegments, but for results. The attribute should be
// named `resultSegmentSizes`.
def AttrSizedResultSegments
    : NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;

// The op's attached regions have no arguments.
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
180
181//===----------------------------------------------------------------------===//
182// Successor definitions
183//===----------------------------------------------------------------------===//
184
// A constraint on an op successor, given by the predicate `condition` and
// described by `descr`.
class Successor<Pred condition, string descr = "">
    : SuccessorConstraint<condition, descr>;

// Matches any successor; no predicate is set.
def AnySuccessor : Successor<?, "any successor">;

// A variadic successor. It expands to zero or more successors, each of which
// satisfies the base `successor` constraint.
class VariadicSuccessor<Successor successor>
    : Successor<successor.predicate, successor.summary>;
195
196//===----------------------------------------------------------------------===//
197// Region definitions
198//===----------------------------------------------------------------------===//
199
// A constraint on an op region, given by the predicate `condition` and
// described by `descr`.
class Region<Pred condition, string descr = "">
    : RegionConstraint<condition, descr>;

// Matches any region.
def AnyRegion : Region<CPred<"true">, "any region">;

// A region with exactly `numBlocks` blocks. The count is also recorded in the
// `blocks` field.
class SizedRegion<int numBlocks>
    : Region<CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
             "region with " # numBlocks # " blocks"> {
  int blocks = numBlocks;
}

// A region with at least `numBlocks` blocks.
class MinSizedRegion<int numBlocks>
    : Region<CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">,
             "region with at least " # numBlocks # " blocks">;

// A region with at most `numBlocks` blocks.
class MaxSizedRegion<int numBlocks>
    : Region<CPred<"::llvm::hasNItemsOrLess($_self, " # numBlocks # ")">,
             "region with at most " # numBlocks # " blocks">;

// A variadic region. It expands to zero or more regions, each of which
// satisfies the base `region` constraint.
class VariadicRegion<Region region>
    : Region<region.predicate, region.summary>;
226
227//===----------------------------------------------------------------------===//
228// Markers
229//===----------------------------------------------------------------------===//
230
// Dag operator that marks the region list, e.g. `dag regions = (region);` in
// the `Op` class.
def region {}

// Dag operator that marks the successor list, e.g.
// `dag successors = (successor);` in the `Op` class.
def successor {}
236
237//===----------------------------------------------------------------------===//
238// Op definitions
239//===----------------------------------------------------------------------===//
240
// Describes a custom builder for an op.
//
// TableGen generates several generic builders for every op by default (see the
// comment on the `Op` class). When those do not cover a use case, custom
// builders can be declared with this class.
//
// A builder always has the signature
//
// ```c++
// static void build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
//                   <other-parameters>...) {
//   <body>...
// }
// ```
//
// The parameter list (*excluding* the leading `OpBuilder &builder,
// OperationState &state` part) and the body are passed as separate template
// arguments. The parameter list is a TableGen DAG with the `ins` operator and
// named arguments, each of which is either:
//   - a string initializer ("Type":$name) for a typed parameter, or
//   - a CArg initializer (CArg<"Type", "default">:$name) for a typed parameter
//     with a default value.
// The type string is emitted verbatim, so it must be a valid C++ type. It is
// used inside the C++ namespace of the parent op's dialect, so explicit
// qualification such as `::mlir` may be needed when ops are not placed in the
// `mlir` namespace. The default value string is also emitted verbatim and must
// be a valid C++ initializer for the given type. For example, the signature
//
// ```
// OpBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ```
//
// has an integer parameter and a float parameter with a default value.
//
// If `body` is an empty string, *only* the builder declaration is generated.
// This allows complicated builders to be defined entirely in C++.
class OpBuilder<dag p, code b = ""> {
  // The builder's parameter list, as an `(ins ...)` DAG.
  dag dagParams = p;
  // The builder's C++ body; empty means declaration only.
  code body = b;
}

// Like OpBuilder, but the emitted `build` method is marked deprecated in C++.
// Using it makes the C++ compiler emit a warning with the given `reason`.
class DeprecatedOpBuilder<string reason, dag p, code b = "">
    : OpBuilder<p, b>, CppDeprecated<reason>;
289
// Base class for optional decorators that can be attached to an OpVariable.
class OpVariableDecorator;

// Extra information about an operation variable, i.e. one of its arguments or
// results.
class OpVariable<Constraint varConstraint, string desc = "",
                 list<OpVariableDecorator> varDecorators = []> {
  // The attribute or type constraint on the variable.
  Constraint constraint = varConstraint;

  // A one-line, human-readable description of the variable.
  string summary = desc;

  // Decorators applied to the variable, e.g. side effects.
  list<OpVariableDecorator> decorators = varDecorators;
}

// An op argument annotated with a description and decorators.
class Arg<Constraint constraint, string desc = "",
          list<OpVariableDecorator> decorators = []>
    : OpVariable<constraint, desc, decorators>;

// An op result annotated with a description and decorators.
class Res<Constraint constraint, string desc = "",
          list<OpVariableDecorator> decorators = []>
    : OpVariable<constraint, desc, decorators>;
312
// Groups ops together for documentation purposes.
class OpDocGroup {
  // A single-line summary of the group.
  string summary;
  // A longer description of the group.
  string description;
}
321
// Base class for all ops. `dialect` and `mnemonic` identify the op; `props` is
// the list of traits attached to it.
class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
  // The dialect of the op.
  Dialect opDialect = dialect;

  // The mnemonic of the op.
  string opName = mnemonic;

  // The C++ namespace to use for this op. Defaults to the dialect's namespace.
  string cppNamespace = dialect.cppNamespace;

  // One-line human-readable description of what the op does.
  string summary = "";

  // Additional, longer human-readable description of what the op does.
  string description = "";

  // Optional. The documentation group this op belongs to.
  OpDocGroup opDocGroup = ?;

  // Dag containing the arguments of the op. Defaults to 0 arguments.
  dag arguments = (ins);

  // The list of results of the op. Defaults to 0 results.
  dag results = (outs);

  // The list of regions of the op. Defaults to 0 regions.
  dag regions = (region);

  // The list of successors of the op. Defaults to 0 successors.
  dag successors = (successor);

  // Attribute getters can be added to the op by adding an Attr member
  // with the name and type of the attribute. E.g., adding an int attribute
  // with name "value" and type "i32":
  //   I32Attr value;

  // Define the hooks used for building, parsing, printing, verification.

  // Custom builders.
  // In addition to the custom builders provided here, and unless
  // skipDefaultBuilders is set, two default builders are generated, with the
  // following signatures:
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   Type <result0-name>, Type <result1-name>, ...,
  //                   Value <arg0-name>, Value <arg1-name>, ...,
  //                   Attribute <attr0-name>, Attribute <attr1-name>, ...);
  // ```
  // * where the attributes follow the same declaration order as in the op.
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   TypeRange resultTypes,
  //                   ValueRange operands,
  //                   ArrayRef<NamedAttribute> attributes);
  // ```
  list<OpBuilder> builders = ?;

  // If set to `1`, the default build functions are not generated, so custom
  // builders must be provided.
  bit skipDefaultBuilders = 0;

  // Custom assembly format.
  // This field holds a declarative description of the assembly format for
  // this operation. If populated, the `hasCustomAssemblyFormat` field is
  // ignored.
  string assemblyFormat = ?;
  // This field indicates that the operation has a custom assembly format
  // implemented in C++. When set to `1`, `parse` and `print` methods are
  // generated on the operation class. The operation should implement these
  // methods to support its custom format. The methods have the form:
  //   * ParseResult parse(OpAsmParser &parser, OperationState &result)
  //   * void print(OpAsmPrinter &p)
  bit hasCustomAssemblyFormat = 0;

  // A bit indicating if the operation has additional invariants that need to
  // be verified (aside from those verified by other ODS constructs). If set to
  // `1`, an additional `LogicalResult verify()` declaration will be generated
  // on the operation class. The operation should implement this method and
  // verify the additional necessary invariants. This verifier shouldn't access
  // any nested operations because those operations may be ill-formed. Use the
  // `hasRegionVerifier` below instead.
  bit hasVerifier = 0;

  // A bit indicating if the operation has additional invariants that need to
  // be verified and that are associated with regions (aside from those
  // verified by the traits). If set to `1`, an additional
  // `LogicalResult verifyRegions()` declaration will be generated on the
  // operation class. The operation should implement this method and verify the
  // additional necessary invariants associated with regions. Note that this
  // method is invoked after all the region ops are verified.
  bit hasRegionVerifier = 0;

  // Whether this op has associated canonicalization patterns.
  bit hasCanonicalizer = 0;

  // Whether this op has a static "canonicalize" method to perform "match and
  // rewrite patterns".
  bit hasCanonicalizeMethod = 0;

  // Whether this op has a folder.
  bit hasFolder = 0;

  // Whether to let ops implement their custom `readProperties` and
  // `writeProperties` methods to emit bytecode.
  bit useCustomPropertiesEncoding = 0;

  // Op traits.
  // Note: The list of traits will be uniqued by ODS.
  list<Trait> traits = props;

  // Additional code that will be added to the public part of the generated
  // C++ code of the op declaration.
  code extraClassDeclaration = ?;

  // Additional code that will be added to the generated source file. The
  // generated code is placed inside the op's C++ namespace. `$cppClass` is
  // replaced by the op's C++ class name.
  code extraClassDefinition = ?;
}
444
// Holds an op's argument DAG in a field named `arguments`, the same name as
// the corresponding field of `Op`.
class Arguments<dag args> {
  dag arguments = args;
}

// Holds an op's result DAG in a field named `results`, the same name as the
// corresponding field of `Op`.
class Results<dag rets> {
  dag results = rets;
}
454
455//===----------------------------------------------------------------------===//
456// Common promised interface constraints
457//===----------------------------------------------------------------------===//
458
// Attribute constraint satisfied when the attribute implements the attribute
// interface `interface`, or promises to implement it. The interface name is
// qualified with its C++ namespace when one is set.
class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
  CPred<!strconcat("$_self.hasPromiseOrImplementsInterface<",
                   !if(!empty(interface.cppNamespace), "",
                       interface.cppNamespace # "::"),
                   interface.cppInterfaceName, ">()")>,
  !strconcat("promising or implementing the `", interface.cppInterfaceName,
             "` attr interface")>;

// Predicate that holds when the type implements the type interface
// `interface`, or promises to implement it.
class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
  CPred<!strconcat("$_self.hasPromiseOrImplementsInterface<",
                   !if(!empty(interface.cppNamespace), "",
                       interface.cppNamespace # "::"),
                   interface.cppInterfaceName, ">()")>;

// Type constraint satisfied when the type implements the type interface
// `interface`, or promises to implement it.
class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
  HasPromiseOrImplementsTypeInterface<interface>,
  !strconcat("promising or implementing the `", interface.cppInterfaceName,
             "` type interface")>;
478
479//===----------------------------------------------------------------------===//
480// Common op type constraints
481//===----------------------------------------------------------------------===//
482
483// These traits are for verifying properties of an op that require knowledge of
484// multiple arguments or results. For verifying properties of a single argument
485// or result, prefer operand type constraints.
486
487// These traits often require including "mlir/IR/TypeUtilities.h".
488
489// TODO: Improve the autogenerated error messages.
490
// StrFunc helpers that produce C++ expressions over the named operand or
// result `$name`. Rank, Shape and ElementCount cast its type to
// `::mlir::ShapedType`; ElementType uses `::mlir::getElementTypeOrSelf`.
// All C++ names are fully qualified, so the expressions do not depend on the
// namespace the generated code is placed in.
class Rank<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType()).getRank()">;

class Shape<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType()).getShape()">;

class ElementCount<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType())"
            ".getNumElements()">;

class ElementType<string name> :
    StrFunc<"::mlir::getElementTypeOrSelf($" # name # ")">;
502
// Predicate that holds when at least one of the C++ boolean expressions in
// `values` holds. It expands to "(v0) || (v1) || ...", and an empty list
// yields "false".
class AnyPred<list<string> values> :
  CPred<!if(!empty(values), "false",
            !foldl("(" # !head(values) # ")", !tail(values), prefix, term,
                   prefix # " || (" # term # ")"))>;

// Predicate that holds when all C++ expressions in `values` are equal. It
// expands to "(v0) == (v1) && (v1) == (v2) && ... && (vN) == (v0)", and a list
// of fewer than two elements yields "true".
class AllMatchPred<list<string> values> :
  CPred<!if(!le(!size(values), 1), "true",
            !foldl("(" # !head(values) # ")", !tail(values), chain, next,
                   chain # " == (" # next # ") && (" # next # ")")
              # " == (" # !head(values) # ")")>;
515
// Trait that holds when all C++ expressions in `values` are equal; `summary`
// describes the trait.
class AllMatch<list<string> values, string summary>
    : PredOpTrait<summary, AllMatchPred<values>>;

// TODO: Only works for non-variadic.
// Predicate that holds when `operator` gives equal results for every named
// operand/result in `names`. `operator` refers to the entity via `$_self`,
// which is replaced by `$<name>` for each name.
class AllMatchSameOperatorPred<list<string> names, string operator>
    : AllMatchPred<!foreach(argName, names,
                            !subst("$_self", "$" # argName, operator))>;

// Trait form of AllMatchSameOperatorPred. The description reads
// "all of {<names>} have same <summary>"; `values` records `names`.
class AllMatchSameOperatorTrait<list<string> names, string operator,
                                string summary>
    : PredOpTrait<"all of {" # !interleave(names, ", ") # "} have same " #
                      summary,
                  AllMatchSameOperatorPred<names, operator>> {
  list<string> values = names;
}

// Predicate that holds when `operator` holds for at least one named
// operand/result in `names`; `$_self` is substituted as above.
class AnyMatchOperatorPred<list<string> names, string operator>
    : AnyPred<!foreach(argName, names,
                       !subst("$_self", "$" # argName, operator))>;

// Trait form of AnyMatchOperatorPred. The description reads
// "any of {<names>} has <summary>"; `values` records `names`.
class AnyMatchOperatorTrait<list<string> names, string operator,
                            string summary>
    : PredOpTrait<"any of {" # !interleave(names, ", ") # "} has " # summary,
                  AnyMatchOperatorPred<names, operator>> {
  list<string> values = names;
}

// All named entities have the same number of elements. Their types are cast
// to ShapedType.
class AllElementCountsMatch<list<string> names>
    : AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
                                "element count">;

// All named entities have the same element type (or type, if not shaped).
class AllElementTypesMatch<list<string> names>
    : AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
                                "element type">;

// All named entities have the same rank. Their types are cast to ShapedType.
class AllRanksMatch<list<string> names>
    : AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;

// All named entities have the same shape. Their types are cast to ShapedType.
class AllShapesMatch<list<string> names>
    : AllMatchSameOperatorTrait<names, Shape<"_self">.result, "shape">;

// All named entities have the same type.
class AllTypesMatch<list<string> names>
    : AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
558
// A trait requiring `transform(lhs.getType()) == rhs.getType()`. Here
// `transform` is a C++ expression in which `$_self` stands for the type of
// `lhsArg`. A custom `comparator` turns the check into
// `comparator(transform(lhs.getType()), rhs.getType())`.
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
                     string transform, string comparator = "std::equal_to<>()">
    : PredOpTrait<summary,
                  CPred<comparator # "(" #
                        !subst("$_self", "$" # lhsArg # ".getType()",
                               transform) #
                        ", $" # rhsArg # ".getType())">> {
  // The trait's inputs, recorded as fields of the trait record.
  string lhs = lhsArg;
  string rhs = rhsArg;
  string transformer = transform;
}

// Like TypesMatchWith, but the check also succeeds when either `lhsArg` or
// `rhsArg` is optional and not present.
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
                             string transform,
                             string comparator = "std::equal_to<>()">
    : TypesMatchWith<summary, lhsArg, rhsArg, transform,
                     "!get" # snakeCaseToCamelCase<lhsArg>.ret #
                     "() || !get" # snakeCaseToCamelCase<rhsArg>.ret #
                     "() || " # comparator>;

// Variant of TypesMatchWith for ranged arguments; compares with `llvm::equal`.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
                           string transform>
    : TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
586
// Type Constraint: operand `idx` exists, has a shaped type, and its element
// type satisfies `type`.
class TCopVTEtIs<int idx, Type type> : And<[
   CPred<"$_op.getNumOperands() > " # idx>,
   SubstLeaves<"$_self", "$_op.getOperand(" # idx # ").getType()",
     IsShapedTypePred>,
   SubstLeaves<"$_self",
     "::mlir::getElementTypeOrSelf($_op.getOperand(" # idx # "))",
     type.predicate>]>;

// Predicate to verify that a named argument or result's type (not its element
// type) matches a given type.
class TypeIsPred<string name, Type type> :
   SubstLeaves<"$_self", "$" # name # ".getType()", type.predicate>;
// Trait form of TypeIsPred.
class TypeIs<string name, Type type> : PredOpTrait<
  "'" # name # "' is " # type.summary, TypeIsPred<name, type>>;

// Predicate to verify that a named argument or result has a shaped type whose
// element type matches a given type.
class ElementTypeIsPred<string name, Type type> : And<[
   SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>,
   SubstLeaves<"$_self", "::mlir::getElementTypeOrSelf($" # name # ")",
     type.predicate>]>;
// Trait form of ElementTypeIsPred.
class ElementTypeIs<string name, Type type> : PredOpTrait<
  "'" # name # "' is " # type.summary, ElementTypeIsPred<name, type>>;

// Predicate to verify that the i'th operand and the j'th operand exist, have
// shaped types, and have the same element type.
// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element
// type.
class TCopVTEtIsSameAs<int i, int j> : And<[
    CPred<"$_op.getNumOperands() > " # !if(!gt(i,j),i,j)>,
    SubstLeaves<"$_self", "$_op.getOperand(" # i # ").getType()",
      IsShapedTypePred>,
    SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
      IsShapedTypePred>,
    CPred<"::mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == "
          "::mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>;
623
// Predicate to verify that the i'th result and the j'th operand exist and have
// shaped types. Note the argument order: result index first, operand second.
class TCOpResIsShapedTypePred<int i, int j> : And<[
    CPred<"$_op.getNumResults() > " # i>,
    CPred<"$_op.getNumOperands() > " # j>,
    SubstLeaves<"$_self", "$_op.getResult(" # i # ").getType()",
      IsShapedTypePred>,
    SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
      IsShapedTypePred>]>;

// Predicate to verify that the i'th result and the j'th operand have the same
// type.
class TCresIsSameAsOpBase<int i, int j> :
    CPred<"$_op.getResult(" # i # ").getType() == "
          "$_op.getOperand(" # j # ").getType()">;

// Basic predicate to verify that the i'th result and the j'th operand have the
// same element type. It does not check that the result and operand exist.
class TCresVTEtIsSameAsOpBase<int i, int j> :
    CPred<"::mlir::getElementTypeOrSelf($_op.getResult(" # i # ")) == "
          "::mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">;

// Predicate to verify that the i'th result and the j'th operand exist, have
// shaped types, and have the same element type.
// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element
// type.
class TCresVTEtIsSameAsOp<int i, int j> : And<[
    TCOpResIsShapedTypePred<i, j>,
    TCresVTEtIsSameAsOpBase<i, j>]>;
653
// Predicate to verify that the opId'th operand can be broadcast to the type of
// the resId'th result.
// It first checks that the resId'th result and the opId'th operand exist and
// have shaped types. TCOpResIsShapedTypePred takes the *result* index first, so
// the indices are passed as <resId, opId>. The broadcast check then uses the
// value returned by `OpTrait::util::getBroadcastedType` as its truth value.
class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<resId, opId>,
    CPred<"::mlir::OpTrait::util::getBroadcastedType("
                  "$_op.getOperand(" # opId # ").getType(), "
                  "$_op.getResult(" # resId # ").getType())">]>;
661
// Predicate to verify that all the operands at the given `indices` have the
// same element type.
// Type Constraint operands' Element type are all Same At the given `indices`.
// The operands' element types are mapped over the index list and compared
// with `llvm::all_equal`. The operands are accessed through `$_op`, like the
// other TC* predicates, rather than through `this`. An empty `indices` list
// trivially holds; it would otherwise produce an ambiguous
// `ArrayRef<unsigned>({})` construction.
// Precondition:
// 1) all operands involved are of shaped type and
// 2) the indices are not out of range (no bounds check is emitted).
class TCopVTEtAreSameAt<list<int> indices> : CPred<
  !if(!empty(indices), "true",
      "::llvm::all_equal(::llvm::map_range("
          "::mlir::ArrayRef<unsigned>({" # !interleave(indices, ", ") # "}), "
          "[&](unsigned i) { return ::mlir::getElementTypeOrSelf("
          "$_op.getOperand(i)); }))")>;
674
675#endif // OP_BASE
676