//===- TensorOps.td - Tensor op definitions ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef TENSOR_OPS
#define TENSOR_OPS

include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"

class Tensor_Op<string mnemonic, list<Trait> traits = []>
    : Op<Tensor_Dialect, mnemonic, traits>;

// Base class for ops with static/dynamic offsets, sizes and strides
// attributes/arguments.
class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
                                         list<Trait> traits = []>
    : Tensor_Op<mnemonic, traits> {
  code extraBaseClassDeclaration = [{
    /// Return the type of the base tensor operand.
    ::mlir::RankedTensorType getSourceType() {
      return ::llvm::cast<RankedTensorType>(getSource().getType());
    }

    /// Return the type of the result tensor.
    ::mlir::RankedTensorType getResultType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }

    /// Return the dynamic sizes for this operation if specified.
    ::mlir::Operation::operand_range getDynamicSizes() { return getSizes(); }

    /// Return the list of Range (i.e. offset, size, stride). Each
    /// Range entry contains either the dynamic value or a ConstantIndexOp
    /// constructed with `b` at location `loc`.
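    /// For example (illustrative), a slice taken with offsets [2, %o],
    /// sizes [4, 8] and strides [1, 1] yields the ranges {2, 4, 1} and
    /// {%o, 8, 1}, with static entries materialized as constant index ops.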
    ::mlir::SmallVector<::mlir::Range, 8> getOrCreateRanges(
        ::mlir::OpBuilder &b, ::mlir::Location loc) {
      return ::mlir::getOrCreateRanges(*this, b, loc);
    }
  }];
}

//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//

def Tensor_BitcastOp : Tensor_Op<"bitcast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    Pure
  ]> {
  let summary = "tensor bitcast operation";
  let description = [{
    Bitcast a tensor from one type to another type of equivalent element width.
    If both are ranked, then the rank should be the same and static dimensions
    should match.

    Example:

    ```mlir
    // Bitcast from unsigned to signed or signless integer.
    %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xi32>
    ```
  }];

  let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
                                 AnySignedInteger, AnyFloat]>:$source);
  let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
                                AnySignedInteger, AnyFloat]>:$dest);
  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";

  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

def Tensor_CastOp : Tensor_Op<"cast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure
  ]> {
  let summary = "tensor cast operation";
  let description = [{
    Convert a tensor from one type to an equivalent type without changing any
    data elements. The source and destination types must both be tensor types
    with the same element type. If both are ranked, then the rank should be the
    same and static dimensions should match. The operation is invalid if
    converting to a mismatching constant dimension.

    Example:

    ```mlir
    // Convert from unknown rank to rank 2 with unknown dimension sizes.
    %2 = tensor.cast %1 : tensor<*xf32> to tensor<?x?xf32>

    // Convert to a type with more known dimensions.
    %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>

    // Discard static dimension and rank information.
    %4 = tensor.cast %3 : tensor<4x?xf32> to tensor<?x?xf32>
    %5 = tensor.cast %4 : tensor<?x?xf32> to tensor<*xf32>
    ```
  }];

  let arguments = (ins AnyTensor:$source);
  let results = (outs AnyTensor:$dest);
  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";

  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

def Tensor_ConcatOp : Tensor_Op<"concat",
    [Pure,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
  let summary = "tensor concatenation operation";
  let description = [{
    The "concat" operation constructs a tensor out of a variadic list of input
    tensors, concatenated along a static dimension number. All inputs and the
    result type must share the same rank.

    `dim` specifies the dimension along which to concatenate. The size of the
    concatenated dimension in the result must be equal to the sum of the sizes
    of the inputs along that dimension. All other dimensions in both the inputs
    and result must be the same size.

    Example:

    ```mlir
    %r0 = tensor.concat dim(0) %0, %1, %2 :
        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>

    // Dynamic + dynamic -> static
    %r1 = tensor.concat dim(1) %0, %1, %2 :
        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
    ```
  }];
  let arguments = (ins I64Attr:$dim,
                       Variadic<AnyRankedTensor>:$inputs);
  let results = (outs AnyRankedTensor:$result);
  let assemblyFormat = [{
    `dim` `(` $dim `)` $inputs attr-dict
    `:` functional-type(operands, results)
  }];

  let builders = [
    // Builder with an inferred result type.
    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
  ];

  let extraClassDeclaration = [{
    // Helper to infer the concatenated result type for the given list of input
    // types, being concatenated along `dim`. Because concatenation can specify
    // more static information than can automatically be inferred,
    // InferTypeOpInterface is not used.
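    // For example (illustrative), concatenating tensor<3x?xf32> and
    // tensor<3x?xf32> along dim 1 would infer tensor<3x?xf32>, but the user
    // may supply a more static result type such as tensor<3x10xf32> when the
    // summed size is known from context.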
    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);

    RankedTensorType getResultType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }

    int64_t getRank() {
      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
    }

    // Method to decompose the operation into a sequence of insert_slices.
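    // E.g. (illustrative) a concat of two inputs may become a destination
    // tensor plus one insert_slice per input, offset along `dim`.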
    FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

def Tensor_DimOp : Tensor_Op<"dim", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    ConditionallySpeculatable, NoMemoryEffect,
    ShapedDimOpInterface]> {
  let summary = "dimension index operation";
  let description = [{
    The `tensor.dim` operation takes a tensor and a dimension operand of type
    `index`. It returns the size of the requested dimension of the given
    tensor. If the dimension index is out of bounds, the behavior is undefined.

    The specified tensor type is that of the first operand.

    Example:

    ```mlir
    // Always returns 4, can be constant folded:
    %c0 = arith.constant 0 : index
    %x = tensor.dim %A, %c0 : tensor<4x?xf32>

    // Return the dynamic dimension of %A.
    %c1 = arith.constant 1 : index
    %y = tensor.dim %A, %c1 : tensor<4x?xf32>

    // Equivalent generic form:
    %x = "tensor.dim"(%A, %c0) : (tensor<4x?xf32>, index) -> index
    %y = "tensor.dim"(%A, %c1) : (tensor<4x?xf32>, index) -> index
    ```
  }];

  let arguments = (ins AnyNon0RankedOrUnrankedTensor:$source,
                       Index:$index);
  let results = (outs Index:$result);

  let assemblyFormat = [{
    attr-dict $source `,` $index `:` type($source)
  }];

  let builders = [
    OpBuilder<(ins "Value":$source, "int64_t":$index)>
  ];

  let extraClassDeclaration = [{
    /// Helper function to get the index as a simple integer if it is constant.
    std::optional<int64_t> getConstantIndex();

    /// Interface method of ShapedDimOpInterface: Return the source tensor.
    Value getShapedValue() { return getSource(); }

    /// Interface method of ShapedDimOpInterface: Return the dimension.
    OpFoldResult getDimension() { return getIndex(); }

    /// Interface method for ConditionallySpeculatable.
    Speculation::Speculatability getSpeculatability();
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//

def Tensor_EmptyOp : Tensor_Op<"empty",
    [Pure,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
  let summary = "empty tensor operation";

  let description = [{
    `tensor.empty` is an operation that defines a tensor of a particular shape.
    The shape could be dynamic or static. The contents of the tensor are
    unspecified and the only purpose of the op result is to materialize the
    specified shape in IR and make it available to other transformations.

    `tensor.empty` is useful in transformations that expect destination-style
    ops, i.e., ops that implement `DestinationStyleOpInterface`. Ops that are
    not in destination style can be made compatible with such transformations
    with a `tensor.empty` destination.

    Note: This op can be lowered to a `bufferization.alloc_tensor`, at which
    point it turns into an explicit buffer allocation.
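
    Example (illustrative; `%d0` is an assumed `index` value):

    ```mlir
    // A fully static shape and a shape with one dynamic extent.
    %0 = tensor.empty() : tensor<4x3xf32>
    %1 = tensor.empty(%d0) : tensor<?x3xf32>
    ```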
  }];

  let arguments = (ins Variadic<Index>:$dynamicSizes);

  let results = (outs AnyRankedTensor:$result);

  let assemblyFormat = "`(`$dynamicSizes`)` attr-dict `:` type($result)";

  let extraClassDeclaration = [{
    RankedTensorType getType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }

    // Return both static and dynamic sizes as a list of `OpFoldResult`.
    SmallVector<OpFoldResult> getMixedSizes();

    // Return the Value of the dynamic size of the tensor at dimension `idx`.
    // Asserts that the shape is dynamic at that `idx`.
    Value getDynamicSize(unsigned idx);
  }];

  let builders = [
    // Build with fully static sizes.
    OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType,
                   CArg<"Attribute", "{}">:$encoding)>,

    // Build with mixed static/dynamic sizes.
    OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType,
                   "ValueRange":$dynamicSizes,
                   CArg<"Attribute", "{}">:$encoding)>,

    // Build with mixed static/dynamic sizes packed as OpFoldResults.
    OpBuilder<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType,
                   CArg<"Attribute", "{}">:$encoding)>
  ];

  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

def Tensor_ExtractOp : Tensor_Op<"extract", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure,
    TypesMatchWith<"result type matches element type of tensor",
                   "tensor", "result",
                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
  let summary = "element extraction operation";
  let description = [{
    The `tensor.extract` op reads a ranked tensor and returns one element as
    specified by the given indices. The result of the op is a value with the
    same type as the elements of the tensor. The arity of indices must match
    the rank of the accessed value. All indices must be of `index` type.

    Example:

    ```mlir
    %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
    %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
    ```
  }];

  let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
  let results = (outs AnyType:$result);
  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//

def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
    AttrSizedOperandSegments,
    Pure,
    OffsetSizeAndStrideOpInterface
  ]> {
  let summary = "extract slice operation";
  let description = [{
    The "extract_slice" operation extracts a tensor from another tensor as
    specified by the operation's offsets, sizes and strides arguments.

    The extract_slice operation supports the following arguments:

    * source: the "base" tensor from which to extract a slice.
    * offsets: tensor-rank number of offsets into the "base" tensor from which
               to extract the slice.
    * sizes: tensor-rank number of sizes which specify the sizes of the result
             tensor type.
    * strides: tensor-rank number of strides specifying subsampling in each
               dimension.

    The representation based on offsets, sizes and strides supports a
    partially-static specification via attributes specified through the
    `static_offsets`, `static_sizes` and `static_strides` arguments. A special
    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
    a dynamic value.

    After buffer allocation, the "extract_slice" op is expected to lower into a
    memref.subview op.

    An extract_slice operation may additionally reduce the rank of the resulting
    tensor by removing dimensions that are statically known to be of size 1.
    This rank-reduction behavior is not required by the op semantics: this
    flexibility allows progressively dropping unit dimensions while lowering
    between different flavors of ops that operate on tensors.

    #### Verification vs Inference in the rank-reduced case

    Note that there may be multiple ways to infer a resulting rank-reduced type:
      e.g. 1x6x1 could potentially rank-reduce to either 1x6 or 6x1 2-D shapes.

    To disambiguate, the inference helper `inferCanonicalRankReducedResultType`
    only drops the first unit dimensions, in order:
      e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.

    Verification however has access to the result type and does not need to
    infer. The verifier calls `isRankReducedType(getSource(), getResult())` to
    determine whether the result type is rank-reduced from the source type.
    This computes a so-called rank-reduction mask, consisting of dropped unit
    dims, to map the rank-reduced type to the source type by dropping ones:
      e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
           6x1 is a rank-reduced version of 1x6x1 by mask {0}
           1x2x1x4 is a rank-reduced version of 1x1x2x1x1x4x1 by mask {1, 4, 6}
             (remaining common 1 dimensions are matched eagerly)

    Example:

    ```mlir
    // Rank-reducing extract_slice.
    %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
      tensor<8x16x4xf32> to tensor<16x4xf32>
    %3 = tensor.extract_slice %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
      tensor<8x16x4xf32> to tensor<1x?xf32>
    ```
  }];

  let arguments = (ins
    AnyRankedTensor:$source,
    Variadic<Index>:$offsets,
    Variadic<Index>:$sizes,
    Variadic<Index>:$strides,
    DenseI64ArrayAttr:$static_offsets,
    DenseI64ArrayAttr:$static_sizes,
    DenseI64ArrayAttr:$static_strides
  );
  let results = (outs AnyRankedTensor:$result);

  let assemblyFormat = [{
    $source ``
    custom<DynamicIndexList>($offsets, $static_offsets)
    custom<DynamicIndexList>($sizes, $static_sizes)
    custom<DynamicIndexList>($strides, $static_strides)
    attr-dict `:` type($source) `to` type($result)
  }];

  let builders = [
    // Build an ExtractSliceOp with mixed static and dynamic entries and
    // inferred result type.
    OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an ExtractSliceOp with mixed static and dynamic entries and custom
    // result type. If the type passed is nullptr, it is inferred.
    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
      "ArrayRef<OpFoldResult>":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an ExtractSliceOp with dynamic entries and inferred result type.
    OpBuilder<(ins "Value":$source, "ValueRange":$offsets,
      "ValueRange":$sizes, "ValueRange":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an ExtractSliceOp with dynamic entries and custom result type. If
    // the type passed is nullptr, it is inferred.
    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an ExtractSliceOp with mixed static and dynamic entries packed in
    // a Range vector.
    OpBuilder<(ins "Value":$source, "ArrayRef<Range>":$ranges,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
  ];

  let extraClassDeclaration = extraBaseClassDeclaration # [{
    /// The result of an extract_slice is always a tensor.
    // TODO: deprecate
    RankedTensorType getType() {
      return getResultType();
    }

    /// Compute the rank-reduction mask that can be applied to map the source
    /// tensor type to the result tensor type by dropping unit dims.
    std::optional<llvm::SmallDenseSet<unsigned>>
    computeRankReductionMask() {
      return ::mlir::computeRankReductionMask(getSourceType().getShape(),
                                              getType().getShape());
    };

    /// An extract_slice result type can be inferred, when it is not
    /// rank-reduced, from the source type and the static representation of
    /// offsets, sizes and strides. Special sentinels encode the dynamic case.
    static RankedTensorType inferResultType(
      RankedTensorType sourceTensorType,
      ArrayRef<int64_t> staticOffsets,
      ArrayRef<int64_t> staticSizes,
      ArrayRef<int64_t> staticStrides);
    static RankedTensorType inferResultType(
      RankedTensorType sourceTensorType,
      ArrayRef<OpFoldResult> staticOffsets,
      ArrayRef<OpFoldResult> staticSizes,
      ArrayRef<OpFoldResult> staticStrides);

    /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
    /// number of sizes), drop as many size-1 dimensions as needed to produce
    /// an inferred type with the desired rank.
    ///
    /// Note that there may be multiple ways to compute this rank-reduced type:
    ///   e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
    ///
    /// To disambiguate, this function always drops the first occurrences of
    /// size-1 dimensions.
    static RankedTensorType inferCanonicalRankReducedResultType(
      unsigned resultRank,
      RankedTensorType sourceRankedTensorType,
      ArrayRef<int64_t> staticOffsets,
      ArrayRef<int64_t> staticSizes,
      ArrayRef<int64_t> staticStrides);
    static RankedTensorType inferCanonicalRankReducedResultType(
      unsigned resultRank,
      RankedTensorType sourceRankedTensorType,
      ArrayRef<OpFoldResult> staticOffsets,
      ArrayRef<OpFoldResult> staticSizes,
      ArrayRef<OpFoldResult> staticStrides);

    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
    /// and `static_strides` attributes.
    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getSourceType().getRank();
      return {rank, rank, rank};
    }

    /// Return the number of leading operands before the `offsets`, `sizes`,
    /// and `strides` operands.
    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }

    /// Return the dimensions of the source that are dropped in the
    /// result when the result is rank-reduced.
    llvm::SmallBitVector getDroppedDims();

    /// Given a `value`, asserted to be of RankedTensorType, build an
    /// ExtractSliceOp that results in a rank-reducing extract to the desired
    /// tensor shape and return the new value created.
    /// If the shape of `value` is already the `desiredShape`, just return
    /// `value`.
    /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
    static FailureOr<Value> rankReduceIfNeeded(
      OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure,
    TypesMatchWith<"operand types match result element type",
                   "result", "elements", "SmallVector<Type, 2>("
                   "::llvm::cast<RankedTensorType>($_self).getNumElements(), "
                   "::llvm::cast<RankedTensorType>($_self).getElementType())">
  ]> {
  let summary = "tensor from elements operation";
  let description = [{
    Create an N-D tensor from a range of same-type arguments. The number of
    provided `elements` must equal the number of elements in the result type.
    The `elements` correspond to a flattened tensor.

    Example:

    ```mlir
    tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xindex>
    ```

    will result in a tensor

    [[%a, %b, %c]
     [%d, %e, %f]]
  }];

  let arguments = (ins Variadic<AnyType>:$elements);
  let results = (outs AnyStaticShapeTensor:$result);

  let assemblyFormat = "$elements attr-dict `:` type($result)";

  let builders = [
    // Special case builder for when `elements` has size >= 1.
    OpBuilder<(ins "ValueRange":$elements)>
  ];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//

def Tensor_GatherOp : Tensor_Op<"gather", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure
  ]> {
  let summary = "gather a subset of a tensor at specified indices";
  let description = [{
    The `gather` operation extracts a subset of the elements from a `source`
    tensor at the given indices.

    In its most general form, the tensor of indices specifies all the coordinates
    of every element to extract (i.e. COO format, without the payload).
    The indices are expected to be confined to coordinate values that fit the
    range of the `source` tensor, otherwise the behavior is undefined.

    The leading dimensions of the index tensor give the result tensor its leading
    dimensions. The trailing dimensions of the result tensor are obtained from
    the source tensor by omitting the dimensions specified in `gather_dims`
    (rank-reducing semantics) or setting them to `1` (rank-preserving semantics)
    (see examples).
    The trailing dimension of the index tensor contains the coordinates and is
    expected to have its size equal to the number of dimensions being gathered.
    This convention allows an idiomatic specification and lowering of "gathering
    multiple N-D slices from the source tensor".

    Note: in the examples below, we separate out the indexing part of the tensor
    type by a whitespace for readability purposes.

    Example:

    ```mlir
        // For each of the 1x2 coordinate triples in %indices, extract the
        // element (i.e. 0-D subset) at that coordinate triple in %source.
        //
        %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
          (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>

        // Note: result type may be further rank-reduced to tensor<1x2x f32>.
    ```

    A slice variant is provided to allow specifying whole slices of the source
    tensor.

    Example:

    ```mlir
        // For each of the 6x7 coordinate singletons in %indices, extract the
        // 2-D slice %source[*, %indices[...]:%indices[...] + 1, *] with the
        // indices corresponding to the `gather_dims` attribute specified by
        // %indices.
        //
        %out = tensor.gather %source[%indices] gather_dims([1]) :
          (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>

        // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
    ```

    The dimensions specified in the gather_dims attribute are ones for which the
    result tensor has size `1`.
    I.e. if the source type is `axbxcxd` and the gather_dims are [1, 3], then
    the shape suffix is `ax1xcx1`.
    Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be
    further simplified to `axc`.

    The element type of the indices tensor can be any integer type.
    In the absence of target-specific or problem-specific information, the
    default type one should use is `index`.

    This operation does not support unranked tensors.

    An optional `unique` unit attribute may be specified to indicate that the
    coordinates in `indices` are statically guaranteed to be unique at runtime.
    Incorrectly setting the `unique` attribute when the coordinates are not truly
    unique is undefined behavior.

    Only full slices are meant to be supported by this op; if one desires
    partial slices (e.g. strided windows), one should compose this op with other
    tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of
    complexity that would make the op unusable in practice.

    At the tensor level, the index tensor is specified in an AoS form (i.e.
    the coordinate tuple is the most minor dimension). It is the responsibility
    of further lowerings and bufferization to implement various concrete
    layouts.

    Note: As currently specified, the operation must lower to an abstraction that
    performs copies to the output tensor. This is because the buffer type system
    is currently not rich enough to allow multiple non-contiguous views in the
    same type. This is visible more clearly in a notional buffer version of the
    op:

    ```mlir
        // memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
        // gather from random source slices must copy to the contiguous output.
        %out = memref.gather %source[%indices] gather_dims([1]) :
          (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>

        // Nested buffer support would allow gather to directly index into the
        // source buffer (i.e. represent a jagged view into the source).
        %out = memref.gather %source[%indices] gather_dims([1]) :
          (memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
    ```
  }];

  let arguments = (ins AnyRankedTensor:$source,
                       RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
                       DenseI64ArrayAttr:$gather_dims,
                       UnitAttr:$unique);
  let results = (outs AnyRankedTensor:$result);

  let assemblyFormat = [{
    $source `[` $indices `]`
      `gather_dims` `(` $gather_dims `)`
      (`unique` $unique^)?
      attr-dict
    `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    // TODO: InferTypeOpInterface once enough confidence is built with
    // tensor<tensor> and its lowering to memref<memref>.
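    // For example (following the slice example above), with source type
    // tensor<3x4x5xf32>, indices type tensor<6x7x1xindex> and
    // gatherDims = [1], the inferred result type is tensor<6x7x3x1x5xf32>,
    // or tensor<6x7x3x5xf32> when `rankReduced` is true.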
    static RankedTensorType inferResultType(RankedTensorType sourceType,
                                            RankedTensorType indicesType,
                                            ArrayRef<int64_t> gatherDims,
                                            bool rankReduced);
    RankedTensorType getIndicesType() {
      return ::llvm::cast<RankedTensorType>(getIndices().getType());
    }
    RankedTensorType getSourceType() {
      return ::llvm::cast<RankedTensorType>(getSource().getType());
    }
    RankedTensorType getResultType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }
  }];
  let hasVerifier = 1;
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

def Tensor_GenerateOp : Tensor_Op<"generate", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    RecursiveMemoryEffects,
    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
    SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
  let summary = "Creates a dynamically sized tensor from elements";
  let description = [{
    This operation creates a dynamically sized tensor with elements of any type.
    It expects one index operand per dynamic extent of the result tensor.

    The body region defines the tensor's elements. It takes index operands as
    its region arguments that span the index space. The element at the given
    position is yielded with the `yield` operation (see `YieldOp`). There is
    no defined ordering to the invocations of the body. It is conceptually
    a "parallel map" operation.

    Example:

    ```mlir
      %tnsr = tensor.generate %m, %n {
      ^bb0(%i : index, %j : index, %k : index):
        ...
        tensor.yield %elem : f32
      } : tensor<?x3x?xf32>
    ```
  }];

  let arguments = (ins Variadic<Index>:$dynamicExtents);
  let results = (outs AnyRankedTensor:$result);
  let regions = (region SizedRegion<1>:$body);
  let assemblyFormat = "$dynamicExtents $body attr-dict `:` type($result)";

  let builders = [
    // Build op and populate its body per callback function.
    OpBuilder<(ins "Type":$resultTy, "ValueRange":$dynamicExtents,
      "function_ref<void(OpBuilder &, Location, ValueRange)>")>,
  ];

  let hasCanonicalizer = 1;
  let hasVerifier = 1;
  let hasRegionVerifier = 1;
}

//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//

def Tensor_InsertOp : Tensor_Op<"insert", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    DestinationStyleOpInterface,
    Pure,
    TypesMatchWith<"result type matches type of dest",
                   "dest", "result",
                   "$_self">,
    TypesMatchWith<"scalar type matches element type of dest",
                   "dest", "scalar",
                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
  let summary = "element insertion operation";
  let description = [{
    The `tensor.insert` op inserts a scalar into a ranked tensor `dest` as
    specified by the operation's indices.

    It returns a copy of `dest` with the indexed position updated to the value
    of `scalar`.

    The arity of `indices` must match the rank of the tensor `dest`. All
    indices should be of `index` type.

    Example:

    ```mlir
    %4 = tensor.insert %t into %dest[%1, %2] : tensor<4x4xi32>
    %5 = tensor.insert %rt into %dest[%1, %2] : tensor<?x?xi32>
    ```
  }];

  let arguments = (ins AnyType:$scalar,
                       AnyRankedTensor:$dest,
                       Variadic<Index>:$indices);
  let results = (outs AnyRankedTensor:$result);
  let assemblyFormat = [{
    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
  }];

  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//

def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
    AttrSizedOperandSegments,
    DestinationStyleOpInterface,
    Pure,
    OffsetSizeAndStrideOpInterface,
    TypesMatchWith<"expected result type to match dest type",
                   "dest", "result", "$_self">
  ]> {
  let summary = "insert_slice operation";
  let description = [{
    The "insert_slice" operation inserts a tensor `source` into another
    tensor `dest` as specified by the operation's offsets, sizes and strides
    arguments.

    It returns a copy of `dest` with the proper slice updated with the value
    of `source`.

    The insert_slice operation supports the following arguments:

    * source: the tensor that is inserted.
    * dest: the tensor into which the source tensor is inserted.
    * offsets: tensor-rank number of offsets into the `dest` tensor into which
               the slice is inserted.
    * sizes: tensor-rank number of sizes which specify the sizes of the source
             tensor type.
    * strides: tensor-rank number of strides that specify subsampling in each
               dimension.

    The representation based on offsets, sizes and strides supports a
    partially-static specification via attributes specified through the
    `static_offsets`, `static_sizes` and `static_strides` arguments. A special
    sentinel value ShapedType::kDynamic encodes that the corresponding entry has
    a dynamic value.

    After buffer allocation, the "insert_slice" op is expected to lower into a
    memref.subview op.

    An insert_slice operation may additionally specify insertion into a tensor
    of higher rank than the source tensor, along dimensions that are statically
    known to be of size 1.
    This rank-altering behavior is not required by the op semantics: this
    flexibility allows progressively dropping unit dimensions while lowering
    between different flavors of ops that operate on tensors.
    The rank-altering behavior of tensor.insert_slice matches the rank-reducing
    behavior of tensor.extract_slice.

    #### Verification in the rank-reduced case

    The same verification discussion and mechanisms apply as for ExtractSliceOp.
    Unlike ExtractSliceOp however, there is no need for a specific inference.

    Example:

    ```mlir
    // Rank-altering insert_slice.
    %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
      tensor<16x4xf32> into tensor<8x16x4xf32>
    %3 = tensor.insert_slice %tt into %2[%o0, 4, %o2][1, %sz1, 1][1, %st1, 1] :
      tensor<1x?xf32> into tensor<8x16x4xf32>
    ```
  }];

  let arguments = (ins
    AnyRankedTensor:$source,
    AnyRankedTensor:$dest,
    Variadic<Index>:$offsets,
    Variadic<Index>:$sizes,
    Variadic<Index>:$strides,
    DenseI64ArrayAttr:$static_offsets,
    DenseI64ArrayAttr:$static_sizes,
    DenseI64ArrayAttr:$static_strides
  );
  let results = (outs AnyRankedTensor:$result);

  let assemblyFormat = [{
    $source `into` $dest ``
    custom<DynamicIndexList>($offsets, $static_offsets)
    custom<DynamicIndexList>($sizes, $static_sizes)
    custom<DynamicIndexList>($strides, $static_strides)
    attr-dict `:` type($source) `into` type($dest)
  }];

  let builders = [
    // Build an InsertSliceOp with mixed static and dynamic entries and
    // inferred result type.
    OpBuilder<(ins "Value":$source, "Value":$dest,
      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
      "ArrayRef<OpFoldResult>":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an InsertSliceOp with dynamic entries and inferred result type.
    OpBuilder<(ins "Value":$source, "Value":$dest,
      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    // Build an InsertSliceOp with mixed static and dynamic entries packed in
    // a Range vector and inferred result type.
    OpBuilder<(ins "Value":$source, "Value":$dest,
      "ArrayRef<Range>":$ranges,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
  ];

  let extraClassDeclaration = extraBaseClassDeclaration # [{
    /// The result of an insert_slice is always a tensor.
    // TODO: Deprecate this method.
    RankedTensorType getType() {
      return getResultType();
    }

    /// The `dest` type is the same as the result type.
    RankedTensorType getDestType() {
      return getResultType();
    }

    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
    /// and `static_strides` attributes.
    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getResultType().getRank();
      return {rank, rank, rank};
    }

    /// Return the dimensions of the dest that are omitted to insert a source
    /// when the result is rank-extended.
    llvm::SmallBitVector getDroppedDims();

    /// Return the number of leading operands before the `offsets`, `sizes`,
    /// and `strides` operands.
    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }

    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
  }];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

def Tensor_RankOp : Tensor_Op<"rank", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure]> {
  let summary = "rank operation";
  let description = [{
    The `tensor.rank` operation takes a tensor operand and returns its rank.

    Example:

    ```mlir
    %0 = tensor.rank %arg0 : tensor<*xf32>
    %1 = tensor.rank %arg1 : tensor<?x?xf32>
    ```
  }];

  let arguments = (ins AnyTensor:$tensor);
  let results = (outs Index);

  let hasFolder = 1;
  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

def Tensor_ReshapeOp : Tensor_Op<"reshape", [
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
    Pure]> {
  let summary = "tensor reshape operation";
  let description = [{
    The `reshape` operation converts a tensor from one type to an equivalent
    type with a provided shape. The source and destination types are compatible
    if both have the same element type and the same number of elements. The
    following combinations are possible:

    a. Source type is ranked or unranked. Shape argument has static size.
    Result type is ranked.

    ```mlir
    // Reshape statically-shaped tensor.
    %dst = tensor.reshape %src(%shape)
             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
    %dst0 = tensor.reshape %src(%shape0)
             : (tensor<4x1xf32>, tensor<2xi32>) -> tensor<2x2xf32>
    // Flatten unranked tensor.
    %dst = tensor.reshape %src(%shape)
             : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
    ```

    b. Source type is ranked or unranked. Shape argument has dynamic size.
    Result type is unranked.

    ```mlir
    // Reshape dynamically-shaped 1D tensor.
    %dst = tensor.reshape %src(%shape)
             : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
    // Reshape unranked tensor.
    %dst = tensor.reshape %src(%shape)
             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
    ```
  }];

  let arguments = (ins
    AnyTensor:$source,
    TensorRankOf<[AnySignlessInteger, Index], [1]>:$shape
  );
  let results = (outs AnyTensor:$result);

  let builders = [OpBuilder<
     (ins "TensorType":$resultType, "Value":$operand, "Value":$shape), [{
       $_state.addOperands(operand);
       $_state.addOperands(shape);
       $_state.addTypes(resultType);
     }]>];

  let extraClassDeclaration = [{
    TensorType getResultType() {
      return ::llvm::cast<TensorType>(getResult().getType());
    }
  }];

  let assemblyFormat = [{
    $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
  }];
  let hasVerifier = 1;
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// ExpandShapeOp / CollapseShapeOp
//===----------------------------------------------------------------------===//

class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
    Tensor_Op<mnemonic, !listconcat(traits, [
      DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
      Pure])>,
    Results<(outs AnyTensor:$result)> {

  code commonExtraClassDeclaration = [{
    static StringRef getReassociationAttrStrName() { return "reassociation"; }
    SmallVector<AffineMap, 4> getReassociationMaps();
    SmallVector<ReassociationExprs, 4> getReassociationExprs();
    SmallVector<ReassociationIndices, 4> getReassociationIndices() {
      SmallVector<ReassociationIndices, 4> reassociationIndices;
      for (auto attr : getReassociation())
        reassociationIndices.push_back(llvm::to_vector<2>(
            llvm::map_range(::llvm::cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
              return ::llvm::cast<IntegerAttr>(indexAttr).getInt();
            })));
      return reassociationIndices;
    }
    RankedTensorType getSrcType() {
      return ::llvm::cast<RankedTensorType>(getSrc().getType());
    }
    RankedTensorType getResultType() {
      return ::llvm::cast<RankedTensorType>(getResult().getType());
    }
  }];

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
  let summary = "operation to produce a tensor with a higher rank";
  let description = [{
    The `tensor.expand_shape` op produces a tensor of higher (or equal)
    rank than the operand `src` whose dimension sizes are a reassociation of
    `src`.

    A reassociation is defined as a continuous grouping of dimensions and is
    represented with an array of DenseI64ArrayAttr attributes. The
    reassociation maps applied to the result tensor with the higher rank must
    result in the operand tensor with the smaller rank.

    The representation for the output shape supports a partially-static
    specification via attributes specified through the `static_output_shape`
    argument. A special sentinel value `ShapedType::kDynamic` encodes that the
    corresponding entry has a dynamic value. There must be exactly as many SSA
    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
    `static_output_shape`.

    Example:

    ```mlir
    // Dimension expansion i -> (i', j') and (k) -> (k')
    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
        : tensor<?x32xf32> into tensor<?x?x32xf32>
    ```
  }];

  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
                       Variadic<Index>:$output_shape,
                       DenseI64ArrayAttr:$static_output_shape);

  let assemblyFormat = [{
    $src $reassociation `output_shape`
    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
    type($src) `into` type($result)
  }];

  let builders = [
    // Builder using ReassociationIndices.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation,
      "ArrayRef<OpFoldResult>":$outputShape)>,

    // Builder that infers the output shape using the inferOutputShape() method.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation)>,

    // Builder using ReassociationExprs.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation),
    [{
      auto reassociationIndices =
          convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, resultType, src, reassociationIndices);
    }]>,
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation,
      "ArrayRef<OpFoldResult>":$outputShape),
    [{
      auto reassociationIndices =
          convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, resultType, src, reassociationIndices,
            outputShape);
    }]>
  ];

  let extraClassDeclaration = commonExtraClassDeclaration # [{
    int64_t getCorrespondingSourceDim(int64_t resultDim);

    // Return the output shape as mixed static/dynamic values.
    SmallVector<OpFoldResult> getMixedOutputShape();

    // Infer the output shape for a tensor.expand_shape when it is possible
    // to do so.
    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
        OpBuilder &b, Location loc, RankedTensorType expandedType,
        ArrayRef<ReassociationIndices> reassociation,
        ArrayRef<OpFoldResult> inputShape);
  }];

  let hasVerifier = 1;
}

def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
  let summary = "operation to produce a tensor with a smaller rank";
  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
  let description = [{
    The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
    rank whose dimension sizes are a reassociation of the original `src`
    dimensions.

    A reassociation is defined as a continuous grouping of dimensions and is
    represented by an array of DenseI64ArrayAttr attributes. The reassociation
    maps are applied to the operand shape to obtain the result shape.

    Example:

    ```mlir
    // Dimension collapse (i, j) -> i' and k -> k'
    %b = tensor.collapse_shape %a [[0, 1], [2]]
        : tensor<?x?x?xf32> into tensor<?x?xf32>
    ```
  }];

  let assemblyFormat = [{
    $src $reassociation attr-dict `:` type($src) `into` type($result)
  }];

  let builders = [
    // Builders for a contracting reshape whose result type is computed from
    // `src` and `reassociation`.
    OpBuilder<(ins "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
    OpBuilder<(ins "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
          convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, src, reassociationMaps, attrs);
    }]>,

    // Builders for a reshape whose result type is passed explicitly.
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationIndices>":$reassociation,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      $_state.addAttribute("reassociation",
          getReassociationIndicesAttribute($_builder, reassociation));
      build($_builder, $_state, resultType, src, attrs);
    }]>,
    OpBuilder<(ins "Type":$resultType, "Value":$src,
      "ArrayRef<ReassociationExprs>":$reassociation,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
    [{
      auto reassociationMaps =
          convertReassociationMapsToIndices(reassociation);
      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
    }]>
  ];

  let extraClassDeclaration = commonExtraClassDeclaration # [{
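    // Infer the collapsed result type implied by `reassociation`; e.g.
    // (illustrative) collapsing tensor<?x4x5xf32> with [[0, 1], [2]]
    // yields tensor<?x5xf32>.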
1242    static RankedTensorType
1243    inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
1244    static RankedTensorType
1245    inferCollapsedType(RankedTensorType type,
1246                       SmallVector<ReassociationIndices> reassociation);
1247  }];
1248  let hasVerifier = 1;
1249}
1250
1251//===----------------------------------------------------------------------===//
1252// PadOp
1253//===----------------------------------------------------------------------===//
1254
1255def Tensor_PadOp : Tensor_Op<"pad", [
1256    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1257    AttrSizedOperandSegments,
1258    Pure,
1259    SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
1260  let summary = "tensor pad operation";
1261  let description = [{
1262    `tensor.pad` is an operation that pads the `source` tensor
1263    with given `low` and `high` padding config.
1264
1265    The PadOp operation supports the following arguments:
1266
1267    * source: the "base" tensor on which to pad.
1268    * low: A list contains the padding along the start of each
1269           dimension, i.e., how many padded values are prepended
1270           to the beginning of the tensor in each dimension.
1271    * high: A list contains the padding along the end of each
1272            dimension, i.e., how many padded values are appended
1273            to the end of the tensor in each dimension.
1274    * nofold: indicates that the operation should not be folded when source and
1275              result types are equal.
1276
1277    The result tensor dimensions are `low[i]` + `dim[i]` + `high[i]` for each
1278    dimension `i`. The number of elements of `low` and `high` must match the
1279    rank of the input tensor. They can be either a constant or a dynamic value.
1280
1281    The region of the `tensor.pad` operation returns the value to use
1282    for the padding. The arguments of the region represent the index
1283    of the source being accessed. There should be as many arguments as
1284    the rank of the `source` tensor. The value `yield`-ed by the
1285    region is used as the value of the view at the given position.
1286
1287    If `nofold` is set, the padding operation will not be folded away even
1288    if the source type and the padded type have the same static shape. This can
1289    be used, e.g., for packing or promotion to faster memory.
1290
1291    Example 1: add 3 zeros to the beginning and 5 zeros to the end of a 1D
1292    tensor.
1293
1294    ```mlir
1295      %arg0 = ... : tensor<10xi32>
1296      %c0_i32 = arith.constant 0 : i32
1297      %padded = tensor.pad %arg0 low[3] high[5] {
1298      ^bb0(%arg1: index):
1299        tensor.yield %c0_i32 : i32
1300      } : tensor<10xi32> to tensor<18xi32>
1301    ```
1302
1303    Example 2: add 1 value to the beginning of dimension 0, 2 values to the end
1304    of dimension 0, 2 values to the start of dimension 1, and 3 values to the
1305    end of dimension 1.
1306
1307    ```mlir
1308      %pad_value = ... : f32
1309      %0 = tensor.pad %0 low[1, 2] high[2, 3] {
1310      ^bb0(%arg0 : index, %arg1 : index):
1311        tensor.yield %pad_value : f32
1312      } : tensor<?x?xf32> to tensor<?x?xf32>
1313    ```
1314
1315    Example 3:
1316
1317    ```mlir
1318      %pad_value = ... : f32
1319      %0 = tensor.pad %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
1320      ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
1321          tensor.yield %pad_value : f32
1322      } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
1323    ```
1324
1325    Example 4:
1326
1327    ```mlir
1328      %pad_value = ... : f32
1329      %0 = tensor.pad %arg0 low[0, 0] high[%ub0, %ub1] {
1330      ^bb0(%arg1: index, %arg2: index):
1331        tensor.yield %pad_value : f32
1332      } : tensor<2x3xf32> to tensor<?x?xf32>
1333    ```
1334
1335    Example 5: Force a padded value to be always exist with `nofold`, even
1336    though the padding config specifies that no new elements will be added to
1337    the tensor.
1338
1339    ```mlir
1340      %pad_value = ... : f32
1341      %0 = tensor.pad %arg0 nofold low[0, 0] high[0, 0] {
1342      ^bb0(%arg1: index, %arg2: index):
1343        tensor.yield %pad_value : f32
1344      } : tensor<2x3xf32> to tensor<2x3xf32>
1345    ```
1346  }];
1347
1348  let arguments = (ins
1349    AnyRankedTensor:$source,
1350    Variadic<Index>:$low,
1351    Variadic<Index>:$high,
1352    DenseI64ArrayAttr:$static_low,
1353    DenseI64ArrayAttr:$static_high,
1354    UnitAttr:$nofold);
1355
1356  let regions = (region SizedRegion<1>:$region);
1357
1358  let results = (outs AnyRankedTensor:$result);
1359
1360  // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
1361  let assemblyFormat = [{
1362    $source
1363    (`nofold` $nofold^)?
1364    `low` `` custom<DynamicIndexList>($low, $static_low)
1365    `high` `` custom<DynamicIndexList>($high, $static_high)
1366    $region attr-dict `:` type($source) `to` type($result)
1367  }];
1368
1369  let extraClassDeclaration = [{
1370    static StringRef getStaticLowAttrStrName() {
1371      return "static_low";
1372    }
1373
1374    static StringRef getStaticHighAttrStrName() {
1375      return "static_high";
1376    }
1377
1378    RankedTensorType getSourceType() {
1379      return ::llvm::cast<RankedTensorType>(getSource().getType());
1380    }
1381    RankedTensorType getResultType() {
1382      return ::llvm::cast<RankedTensorType>(getResult().getType());
1383    }
1384
1385    // Infer the shape of the result tensor given the type of the source tensor
1386    // and paddings. Known result dimensions that cannot necessarily be inferred
1387    // from low/high padding sizes can be optionally specified. Those will be
1388    // considered when computing the result type.
1389    static RankedTensorType inferResultType(
1390                                RankedTensorType sourceType,
1391                                ArrayRef<int64_t> staticLow,
1392                                ArrayRef<int64_t> staticHigh,
1393                                ArrayRef<int64_t> resultShape = {});
1394
1395    // Return the pad value if it is a constant. Return null value otherwise.
1396    Value getConstantPaddingValue();
1397
1398    // Return a vector of all the static or dynamic values (low/high padding) of
1399    // the op.
1400    inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayRef<int64_t> staticAttrs,
1401                                                     ValueRange values) {
1402      Builder builder(*this);
1403      SmallVector<OpFoldResult> res;
1404      unsigned numDynamic = 0;
1405      unsigned count = staticAttrs.size();
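      // Static entries equal to ShapedType::kDynamic are placeholders: each
      // is replaced, in order, by the next dynamic padding operand value.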
1406      for (unsigned idx = 0; idx < count; ++idx) {
1407        if (ShapedType::isDynamic(staticAttrs[idx]))
1408          res.push_back(getAsOpFoldResult(values[numDynamic++]));
1409        else
1410          res.push_back(builder.getI64IntegerAttr(staticAttrs[idx]));
1411      }
1412      return res;
1413    }
1414    SmallVector<OpFoldResult> getMixedLowPad() {
1415      return getMixedPadImpl(getStaticLow(), getLow());
1416    }
1417    SmallVector<OpFoldResult> getMixedHighPad() {
1418      return getMixedPadImpl(getStaticHigh(), getHigh());
1419    }
1420    // Return true if low padding is guaranteed to be 0.
1421    bool hasZeroLowPad() {
1422      return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) {
1423        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
1424      });
1425    }
1426    // Return true if high padding is guaranteed to be 0.
1427    bool hasZeroHighPad() {
1428      return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) {
1429        return getConstantIntValue(ofr) == static_cast<int64_t>(0);
1430      });
1431    }
1432    /// Return the dimensions with a non-zero low or high padding.
1433    llvm::SmallBitVector getPaddedDims();
1434  }];
1435
1436  let builders = [
1437    // Build a PadOp with mixed static and dynamic entries.
1438    OpBuilder<(ins "Type":$resultType, "Value":$source,
1439      "ArrayRef<int64_t>":$staticLow, "ArrayRef<int64_t>":$staticHigh,
1440      "ValueRange":$low, "ValueRange":$high, CArg<"bool", "false">:$nofold,
1441      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1442    // Build a PadOp with all dynamic entries.
1443    OpBuilder<(ins "Type":$resultType, "Value":$source, "ValueRange":$low,
1444      "ValueRange":$high, CArg<"bool", "false">:$nofold,
1445      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1446    // Build a PadOp with mixed static and dynamic entries and custom
1447    // result type. If the type passed is nullptr, it is inferred.
1448    OpBuilder<(ins "Type":$resultType, "Value":$source,
1449      "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
1450      CArg<"bool", "false">:$nofold,
1451      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1452    // Build a PadOp with constant padding, mixed static and dynamic entries
1453    // and custom result type. If the type passed is nullptr, it is inferred.
1454    OpBuilder<(ins "Type":$resultType, "Value":$source,
1455      "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
1456      "Value":$constantPadValue, CArg<"bool", "false">:$nofold,
1457      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
1458  ];
1459
1460  let hasCanonicalizer = 1;
1461  let hasFolder = 1;
1462  let hasVerifier = 1;
1463  let hasRegionVerifier = 1;
1464}
1465
1466//===----------------------------------------------------------------------===//
1467// ParallelInsertSliceOp
1468//===----------------------------------------------------------------------===//
1469
1470// TODO: Implement InParallelOpInterface.
1471def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
1472       AttrSizedOperandSegments,
1473       OffsetSizeAndStrideOpInterface,
1474       // TODO: Cannot use an interface here atm, verify this manually for now.
1475       // HasParent<"ParallelCombiningOpInterface">
1476  ]> {
1477  let summary = [{
1478    Specify the tensor slice update of a single thread of a parent
1479    ParallelCombiningOpInterface op.
1480  }];
1481  let description = [{
1482    The `parallel_insert_slice` yields a subset tensor value to its parent
1483    ParallelCombiningOpInterface. These subset tensor values are aggregated
1484    in some unspecified order into a full tensor value returned by the parent
1485    parallel iterating op.
1486    The `parallel_insert_slice` is one such op allowed in the
1487    ParallelCombiningOpInterface op.
1488
1489    Conflicting writes result in undefined semantics, in that the indices written
1490    to by multiple parallel updates might contain data from any of the updates,
1491    or even a malformed bit pattern.
1492
1493    If an index is updated exactly once, the value contained at that index
1494    in the resulting tensor will be equal to the value at a corresponding index
1495    of a slice that was used for the update. If an index is not updated at all,
1496    its value will be equal to the one in the original tensor.
1497
1498    This op does not create a new value, which allows maintaining a clean
1499    separation between the subset and full tensor.
1500
1501    Note that we cannot mark this operation as pure (`Pure`), even
1502    though it has no side effects, because it will get DCEd during
1503    canonicalization.
1504
1505    The parallel_insert_slice operation supports the following arguments:
1506
1507    * source: the tensor that is inserted.
1508    * dest: the tensor into which the source tensor is inserted.
1509    * offsets: tensor-rank number of offsets into the `dest` tensor into which
1510               the slice is inserted.
1511    * sizes: tensor-rank number of sizes which specify the sizes of the source
1512             tensor type.
1513    * strides: tensor-rank number of strides that specify subsampling in each
1514               dimension.
1515
1516    The representation based on offsets, sizes and strides supports a
1517    partially-static specification via attributes specified through the
1518    `static_offsets`, `static_sizes` and `static_strides` arguments. A special
1519    sentinel value `ShapedType::kDynamic` encodes that the corresponding entry has
1520    a dynamic value.
1521
1522    After buffer allocation, the "parallel_insert_slice" op is expected to lower
1523    into a memref.subview op.
1524
1525    A parallel_insert_slice operation may additionally specify insertion into a
1526    tensor of higher rank than the source tensor, along dimensions that are
1527    statically known to be of size 1.
1528    This rank-altering behavior is not required by the op semantics: this
1529    flexibility allows progressively dropping unit dimensions while lowering
1530    between different flavors of ops that operate on tensors.
1531    The rank-altering behavior of tensor.parallel_insert_slice matches the
1532    rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
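
    Example: a minimal sketch of the intended usage, assuming the enclosing
    `scf.forall` op from the SCF dialect (whose terminator region implements
    ParallelCombiningOpInterface). Operand names are placeholders; each thread
    updates a disjoint slice of the shared output:

    ```mlir
    %result = scf.forall (%thread) in (4) shared_outs(%out = %dest)
        -> (tensor<128xf32>) {
      // Each thread owns a disjoint 32-element chunk; overlapping writes
      // would have undefined contents.
      %offset = affine.apply affine_map<(d0) -> (d0 * 32)>(%thread)
      %slice = tensor.extract_slice %source[%offset] [32] [1]
          : tensor<128xf32> to tensor<32xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %slice into %out[%offset] [32] [1]
            : tensor<32xf32> into tensor<128xf32>
      }
    }
    ```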
1533
1534    #### Verification in the rank-reduced case
1535
1536    The same verification discussion and mechanisms apply as for ExtractSliceOp.
1537    Unlike ExtractSliceOp however, there is no need for a specific inference.
1538  }];
1539
1540  let arguments = (ins
1541    AnyRankedTensor:$source,
1542    AnyRankedTensor:$dest,
1543    Variadic<Index>:$offsets,
1544    Variadic<Index>:$sizes,
1545    Variadic<Index>:$strides,
1546    DenseI64ArrayAttr:$static_offsets,
1547    DenseI64ArrayAttr:$static_sizes,
1548    DenseI64ArrayAttr:$static_strides
1549  );
1550  let assemblyFormat = [{
1551    $source `into` $dest ``
1552    custom<DynamicIndexList>($offsets, $static_offsets)
1553    custom<DynamicIndexList>($sizes, $static_sizes)
1554    custom<DynamicIndexList>($strides, $static_strides)
1555    attr-dict `:` type($source) `into` type($dest)
1556  }];
1557
1558  let extraClassDeclaration = [{
1559    Type yieldedType() { return getDest().getType(); }
1560
1561    RankedTensorType getSourceType() {
1562      return ::llvm::cast<RankedTensorType>(getSource().getType());
1563    }
1564
1565    RankedTensorType getDestType() {
1566      return ::llvm::cast<RankedTensorType>(getDest().getType());
1567    }
1568
1569    ParallelCombiningOpInterface getParallelCombiningParent() {
1570      return dyn_cast<ParallelCombiningOpInterface>(
1571        getOperation()->getParentOp());
1572    }
1573
1574    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
1575    /// and `static_strides` attributes.
1576    std::array<unsigned, 3> getArrayAttrMaxRanks() {
1577      unsigned rank = getDestType().getRank();
1578      return {rank, rank, rank};
1579    }
1580
1581    /// Return the number of leading operands before `offsets`, `sizes` and
1582    /// `strides` operands.
1583    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
1584
1585    /// Return the OpResult of the enclosing ForallOp that corresponds
1586    /// to this ParallelInsertSliceOp.
1587    OpResult getTiedOpResult();
1588
1589    /// Return the dimensions of the dest that are absent from the source
1590    /// when the insertion is rank-extending.
1591    llvm::SmallBitVector getDroppedDims();
1592  }];
1593
1594  let builders = [
1595    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
1596    OpBuilder<(ins "Value":$source, "Value":$dest,
1597      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
1598      "ArrayRef<OpFoldResult>":$strides,
1599      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1600    // Build a ParallelInsertSliceOp with mixed static and dynamic entries
1601    // packed into a Range vector.
1602    OpBuilder<(ins "Value":$source, "Value":$dest,
1603      "ArrayRef<Range>":$ranges,
1604      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
1605    // Build a ParallelInsertSliceOp with dynamic entries.
1606    OpBuilder<(ins "Value":$source, "Value":$dest,
1607      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
1608      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
1609  ];
1610
1611  let hasCanonicalizer = 1;
1612  let hasVerifier = 1;
1613}
1614
1616//===----------------------------------------------------------------------===//
1617// ScatterOp
1618//===----------------------------------------------------------------------===//
1619
1620def Tensor_ScatterOp : Tensor_Op<"scatter", [
1621    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1622    Pure
1623  ]> {
1624  let summary =
1625    "scatter a tensor into a destination tensor at specified indices";
1626  let description = [{
1627    The `scatter` operation inserts a `source` tensor into a `dest` tensor at
1628    the given indices.
1629
1630    In its most general form, the tensor of indices specifies all the coordinates
1631    of every element to insert (i.e. COO format, without the payload).
1632    The indices are expected to be confined to coordinate values that fit the
1633    range of the `dest` tensor; otherwise the behavior is undefined.
1634
1635    The leading dimensions of the index tensor must match those of the source
1636    tensor. The trailing dimensions of the source tensor must match those of the
1637    dest tensor, with the dimensions specified in scatter_dims either omitted
1638    (rank-reducing semantics) or set to `1` (rank-preserving semantics)
1639    (see examples).
1640    This convention allows an idiomatic specification and lowering of
1641    "scattering multiple N-D slices into the dest tensor".
1642    The result type must match the type of the dest tensor.
1643
1644    Note: in the examples below, we separate out the indexing part of the tensor
1645    type with whitespace for readability purposes.
1646
1647    Example:
1648
1649    ```mlir
1650        // For each of the 1x2 coordinate triples in %indices, insert the
1651        // element (i.e. the 0-D subset) at that coordinate triple in %dest.
1652        //
1653        %out = tensor.scatter %source into %dest[%indices]
1654            scatter_dims([0, 1, 2]) unique :
1655          (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
1656            -> tensor<4x4x4xf32>
1657
1658        // Note: source type may be further rank-reduced to tensor<1x2x f32>.
1659    ```
1660
1661    A slice variant is provided to allow specifying insertion of whole tensor
1662    slices into the `dest` tensor.
1663
1664    Example:
1665
1666    ```mlir
1667        // For each of the 3 coordinate singletons in %indices, insert the
1668        // 2-D slice into %dest[*, %indices[...]:%indices[...] + 1, *],
1669        // where the indexed dimension is the one specified by the
1670        // scatter_dims attribute.
1671        //
1672        %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
1673          (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
1674            -> tensor<4x5x6xf32>
1675    ```
1676
1677    The dimensions specified in the scatter_dims attribute are ones for which the
1678    source tensor has size `1`.
1679    I.e. if the dest type is `axbxcxd` and `scatter_dims = [1, 3]`, then
1680    the source type suffix is `ax1xcx1`.
1681    Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can be
1682    further simplified to `axc`.
1683
1684    The element type of the indices tensor can be any integer type.
1685    In the absence of target-specific or problem-specific information, the
1686    default type one should use is `index`.
1687
1688    This operation does not support unranked tensors.
1689
1690    A `unique` unit attribute must be specified to indicate that the
1691    coordinates are statically guaranteed to be unique at runtime. If coordinates
1692    are not truly unique at runtime, the behavior is undefined.
1693
1694    Only full slices are meant to be supported by this op; if one desires
1695    partial slices (e.g. strided windows), one should compose this op with other
1696    tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of
1697    complexity that would make the op unusable in practice.
1698
1699    At the tensor level, the index tensor is specified in an AoS form (i.e.
1700    the coordinate tuple is the most minor dimension). It is the responsibility
1701    of further lowerings and bufferization to implement various concrete layouts.
1702
1703    Note: As currently specified, the operation must lower to an abstraction that
1704    performs copies to the output tensor. This is because the buffer type system
1705    is currently not rich enough to allow multiple non-contiguous views in the
1706    same type. This is visible more clearly in a notional buffer version of the
1707    op:
1708
1709    ```mlir
1710        // memref<?x 4xf32> is a contiguous buffer of ?x4 elements; a scatter
1711        // into random dest slices must copy to the contiguous dest.
1712        //
1713        some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
1714        memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
1715          (memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
1716
1717        // Nested buffer support in the producing op would allow writing directly
1718        // into the dest buffer.
1719        %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
1720          memref<? x memref<4xf32>>
1721        some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
1722    ```
1723  }];
1724
1725  let arguments = (ins AnyRankedTensor:$source,
1726                       AnyRankedTensor:$dest,
1727                       RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
1728                       DenseI64ArrayAttr:$scatter_dims,
1729                       UnitAttr:$unique);
1730  let results = (outs AnyRankedTensor:$result);
1731
1732  let assemblyFormat = [{
1733    $source `into` $dest `[` $indices `]`
1734      `scatter_dims` `(` $scatter_dims `)`
1735      (`unique` $unique^)?
1736      attr-dict
1737    `:` functional-type(operands, results)
1738  }];
1739
1740  let extraClassDeclaration = [{
1741    RankedTensorType getDestType() {
1742      return ::llvm::cast<RankedTensorType>(getDest().getType());
1743    }
1744    RankedTensorType getIndicesType() {
1745      return ::llvm::cast<RankedTensorType>(getIndices().getType());
1746    }
1747    RankedTensorType getSourceType() {
1748      return ::llvm::cast<RankedTensorType>(getSource().getType());
1749    }
1750    RankedTensorType getResultType() {
1751      return ::llvm::cast<RankedTensorType>(getResult().getType());
1752    }
1753  }];
1754  let hasVerifier = 1;
1755}
1756
1757//===----------------------------------------------------------------------===//
1758// SplatOp
1759//===----------------------------------------------------------------------===//
1760
1761def Tensor_SplatOp : Tensor_Op<"splat", [
1762    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1763    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1764    Pure,
1765    TypesMatchWith<"operand type matches element type of result",
1766                   "aggregate", "input",
1767                   "::llvm::cast<TensorType>($_self).getElementType()">
1768  ]> {
1769  let summary = "tensor splat or broadcast operation";
1770  let description = [{
1771    Broadcast the operand to all elements of the result tensor. The operand is
1772    required to be of integer/index/float type.
1773
1774    An additional argument of type `index` must be provided for each dynamic
1775    dimension present in the result type.
1776
1777    Example for a statically shaped tensor:
1778
1779    ```mlir
1780    %s = arith.constant 1.0 : f32
1781    %t = tensor.splat %s : tensor<8x16xf32>
1782    ```
1783
1784    Example for a tensor containing dynamic dimensions:
1785
1786    ```mlir
1787    // Broadcasts %s to a 3D dynamically shaped tensor, with %m and %n binding
1788    // to dimensions 0 and 2 of the resulting tensor, respectively.
1789    %m = arith.constant 10 : index
1790    %n = arith.constant 30 : index
1791    %t = tensor.splat %s[%m, %n] : tensor<?x20x?xf32>
1792    ```
1793  }];
1794
1795  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
1796                                 "integer/index/float type">:$input,
1797                       Variadic<Index>:$dynamicSizes);
1798  let results = (outs AnyRankedTensor:$aggregate);
1799
1800  let builders = [
1801    // Build with an explicit result type and a list of values corresponding
1802    // to the dynamic sizes present in the result type.
1803    OpBuilder<(ins "Value":$element,
1804                   "Type":$aggregateType,
1805                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
1806
1807    // Build with a result tensor shape and a list of values corresponding to
1808    // the elements in the result tensor shape set to ShapedType::kDynamic.
1809    OpBuilder<(ins "Value":$element,
1810                   "ArrayRef<int64_t>":$staticShape,
1811                   CArg<"ValueRange", "{}">:$dynamicSizes)>,
1812
1813    // Build with mixed static/dynamic sizes, where an attribute represents
1814    // a static dimension and a value represents a dynamic dimension.
1815    OpBuilder<(ins "Value":$element, "ArrayRef<OpFoldResult>":$sizes)>
1816  ];
1817
1818  let assemblyFormat = "$input (`[` $dynamicSizes^ `]`)? attr-dict `:` type($aggregate)";
1819
1820  let hasFolder = 1;
1821  let hasVerifier = 1;
1822}
1823
1824//===----------------------------------------------------------------------===//
1825// RelayoutOp
1826//===----------------------------------------------------------------------===//
1827
1828class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
1829      Tensor_Op<mnemonic, !listconcat(traits, [
1830        DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1831        DestinationStyleOpInterface,
1832        ConditionallySpeculatable, NoMemoryEffect,
1833        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1834        TypesMatchWith<"result type matches type of dest",
1835                   "dest", "result",
1836                   "$_self">])> {
1837
1838  code commonExtraClassDeclaration = [{
1839    size_t getSourceRank() { return getSourceType().getRank(); };
1840    size_t getDestRank() { return getDestType().getRank(); };
1841    RankedTensorType getSourceType() {
1842      return ::llvm::cast<RankedTensorType>(getSource().getType()); };
1843    RankedTensorType getDestType() {
1844      return ::llvm::cast<RankedTensorType>(getDest().getType()); };
1845
1846    MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
1847
1848    /// Interface method for ConditionallySpeculatable.
1849    Speculation::Speculatability getSpeculatability();
1850
1851    /// Return a mapping from positions `inner_dims_pos` to their
1852    /// tile factors.
1853    DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
1854
1855    /// Return the tile sizes as OpFoldResult.
1856    SmallVector<OpFoldResult> getMixedTiles();
1857
1858    /// Return the tile sizes as `int64_t`. If a tile size is dynamic
1859    /// a sentinel `kDynamic` is introduced at that position in
1860    /// the returned vector.
1861    SmallVector<int64_t> getStaticTiles();
1862
1863    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
1864    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
1865    /// that this will include both tiled and non-tiled dimensions. The order
1866    /// of the output dimensions is consistent with the shape of the packed
1867    /// tensor.
1868    ArrayRef<int64_t> getAllOuterDims();
1869
1870    /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
1871    /// have been tiled. Also, the order of the output dimensions is consistent
1872    /// with `inner_dims_pos` rather than the packed tensor.
1873    SmallVector<int64_t> getTiledOuterDims();
1874  }];
1875
1876  let hasVerifier = 1;
1877}
1878
1879//===----------------------------------------------------------------------===//
1880// PackOp
1881//===----------------------------------------------------------------------===//
1882
1883def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
1884    AttrSizedOperandSegments]> {
1885  let summary = "tensor pack operation";
1886  let description = [{
1887    The "pack" operation converts a source tensor of rank `n` into a result
1888    tensor of rank `n + k` with a tiled and packed layout (possibly with padding)
1889    and optionally transposes the tiled source tensor dimensions.
1890
1891    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
1892    being tiled, where `0 < k <= n`. The order of the dimensions matters:
1893     - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
1894    tensor in the order in which they appear in `inner_dims_pos`.
1895     - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
1896    `inner_tiles[i]`.
1897
1898    `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
1899    correspond to the least significant ("inner") result tensor dimension sizes,
1900    in the same order. Tile sizes can be static or dynamic.
1901
1902    Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
1903    `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
1904    by 16 and the 1st source dimension is tiled by 32. Other source dimensions
1905    (if any) are not tiled. If `inner_dims_pos = [1, 0]`, the 1st dimension is
1906    tiled by 16 and the 0th dimension is tiled by 32.
1907
1908    Example:
1909    ```mlir
1910    // NC to NCnc
1911    %0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
1912        into %dest : tensor<128x256xf32> -> tensor<16x8 x 8x32 xf32>
1913    //                                             \  /   \  /
1914    //                                       outer dims  inner dims
1915    ```
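
    A sketch of the `inner_dims_pos = [1, 0]` variant described above (the
    result shape is derived from the stated semantics; operand names are
    placeholders):

    ```mlir
    // Dim 1 is tiled by 16 and dim 0 by 32; the inner tiles appear at the
    // end of the result in inner_dims_pos order.
    %0 = tensor.pack %source inner_dims_pos = [1, 0] inner_tiles = [16, 32]
        into %dest : tensor<128x256xf32> -> tensor<4x16 x 16x32 xf32>
    ```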
1916
1917    `outer_dims_perm` (optional) specifies a permutation for the outer
1918    dimensions. If specified, it must have `n` elements.
1919
1920    Example:
1921    ```mlir
1922    // CK to KCck
1923    %0 = tensor.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
1924        inner_tiles = [8, 32] into %dest
1925        : tensor<128x256xf32> -> tensor<8x16 x 8x32 xf32>
1926    //                                  \  /
1927    //            compare with "NC to NCnc": outer dims are transposed
1928    ```
1929
1930    `padding_value` specifies a padding value at the boundary of non-perfectly
1931    divisible dimensions. Padding is optional:
1932    - If absent, it is UB if the tile does not perfectly divide the dimension.
1933    - If present, it will pad along high dimensions (high-padding) to make the
1934      tile complete.
1935
1936    Example:
1937    ```mlir
1938    %0 = tensor.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
1939        inner_dims_pos = [1] inner_tiles = [2] into %arg1
1940        : tensor<200x127x256xf32> -> tensor<256x64x200x2xf32>
1941    //                 \
1942    //                padded and tiled dim
1943    //
1944    // Source dimension 1 (size 127) is tiled by 2. Since 2 does not divide
1945    // 127 evenly, 1 padded element is added at the end to complete 64 tiles.
1946    //
1947    // Note: Only tiled dimensions can be padded.
1948    ```
1949  }];
1950  let arguments = (ins AnyRankedTensor:$source,
1951                       AnyRankedTensor:$dest,
1952                       Optional<AnyType>:$padding_value,
1953                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
1954                       DenseI64ArrayAttr:$inner_dims_pos,
1955                       Variadic<Index>:$inner_tiles,
1956                       DenseI64ArrayAttr:$static_inner_tiles);
1957  let results = (outs AnyRankedTensor:$result);
1958  let assemblyFormat = [{
1959    $source
1960    (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
1961    (`outer_dims_perm` `=` $outer_dims_perm^)?
1962    `inner_dims_pos` `=` $inner_dims_pos
1963    `inner_tiles` `=`
1964    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
1965    `into` $dest attr-dict `:` type($source) `->` type($dest)
1966  }];
1967
1968  let builders = [
1969    OpBuilder<(ins "Value":$source, "Value":$dest,
1970      "ArrayRef<int64_t>":$innerDimsPos,
1971      "ArrayRef<OpFoldResult>":$innerTiles,
1972      CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
1973      CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
1974  ];
1975
1976  let extraClassDeclaration = commonExtraClassDeclaration # [{
1977    // Method to get the shape of the result as `SmallVector<OpFoldResult>`.
1978    // This is a static method to allow computing the expected destination
1979    // shape while creating a `pack` op.
1980    static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
1981        Location loc, ArrayRef<OpFoldResult> sourceDims,
1982        ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
1983        ArrayRef<int64_t> outerDimsPerm = {});
1984
1985    // Method to get the `RankedTensorType` of the result based on the inner
1986    // tiles, position of the inner tiles (innerDimsPos) and interchange vector
1987    // of outer loops (outerDimsPerm).
1988    static RankedTensorType inferPackedType(RankedTensorType sourceType,
1989        ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
1990        ArrayRef<int64_t> outerDimsPerm = {});
1991
1992    // Returns true if we have enough static information to catch undefined
1993    // behavior when the tile size does not perfectly divide the dimension of
1994    // the input tensor. Detecting UB requires that the input size and either
1995    // corresponding tile or output size are static.
1996    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
1997                                    ArrayRef<int64_t> innerDimsPos,
1998                                    ArrayRef<int64_t> outputShape,
1999                                    ArrayRef<int64_t> outerDimsPerm,
2000                                    ArrayRef<OpFoldResult> innerTiles);
2001
2002    static Value createDestinationTensor(OpBuilder &b, Location loc,
2003        Value source, ArrayRef<OpFoldResult> innerTileSizes,
2004        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
2005
2006    /// Build and return a new PackOp that is a clone of the current PackOp with
2007    /// (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
2008    /// innerPermutation (resp. outerPermutation).
2009    /// A new `tensor.empty` of the proper shape is built in the process.
2010    /// Asserts that:
2011    ///   - At least one of innerPermutation or outerPermutation is non-empty.
2012    ///   - If not empty, innerPermutation is a valid permutation of size
2013    ///     matching innerDimsPos.
2014    ///   - If not empty, outerPermutation is a valid permutation of size
2015    ///     matching outerDimsPerm.
2016    PackOp createTransposedClone(OpBuilder &b,
2017                                 Location loc,
2018                                 ArrayRef<int64_t> innerPermutation,
2019                                 ArrayRef<int64_t> outerPermutation);
2020
2021    /// Check if this PackOp is like a simple pad operation.
2022    /// In other words, this operation:
2023    /// 1. adds useless dimensions (dimension of size 1),
2024    /// 2. pads the other ones, and
2025    /// 3. doesn't shuffle the dimensions
2026    bool isLikePad();
2027  }];
2028
2029  let hasCanonicalizeMethod = 1;
2030
2031  let hasFolder = 1;
2032}
2033
2034//===----------------------------------------------------------------------===//
2035// UnPackOp
2036//===----------------------------------------------------------------------===//
2037
2038def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
2039  let summary = "tensor unpack operation";
2040  let description = [{
2041    The "unpack" operation converts a source tensor of rank `n` with a tiled and
2042    packed layout to a result tensor of rank `n - k`.
2043
2044    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions with
2045    which the last `k` source tensor dimensions are combined, where
2046    `0 < k <= n/2`. Each `inner_dims_pos` element must be `>= 0` and `< n - k`.
2047    The order of the dimensions in `inner_dims_pos` matters: dimension
2048    `inner_dims_pos[i]` is combined with dimension `n - k + i` (assuming that
2049    `outer_dims_perm` is not specified).
2050
2051    `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
2052    correspond to the least significant ("inner") source tensor dimension sizes.
2053    The behavior of this op is undefined if:
2054    - `inner_tiles` do not exactly match the corresponding source tensor
2055      dimension sizes.
2056    - Or, `inner_tiles[i]` does not divide the size of dimension
2057      `inner_dims_pos[i]` (assuming that `outer_dims_perm` is not specified)
2058      evenly.
2059
2060    `outer_dims_perm` (optional) specifies a permutation for the outer
2061    dimensions. If specified, it must have `n - k` elements. If specified, this
2062    permutation is applied before combining any dimensions.
2063
2064    Example:
2065
2066    ```mlir
2067    // NCnc to NC:
2068    %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
2069        into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
2070
2071    // KCck to CK:
2072    %0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
2073        inner_tiles = [8, 32] into %dest
2074        : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
2075    ```
2076  }];
2077  let arguments = (ins AnyRankedTensor:$source,
2078                       AnyRankedTensor:$dest,
2079                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
2080                       DenseI64ArrayAttr:$inner_dims_pos,
2081                       Variadic<Index>:$inner_tiles,
2082                       DenseI64ArrayAttr:$static_inner_tiles);
2083  let results = (outs AnyRankedTensor:$result);
2084  let assemblyFormat = [{
2085    $source
2086    (`outer_dims_perm` `=` $outer_dims_perm^)?
2087    `inner_dims_pos` `=` $inner_dims_pos
2088    `inner_tiles` `=`
2089    custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
2090    `into` $dest attr-dict `:` type($source) `->` type($dest)
2091  }];
2092
2093  let builders = [
2094    OpBuilder<(ins "Value":$source, "Value":$dest,
2095    "ArrayRef<int64_t>":$innerDimsPos,
2096    "ArrayRef<OpFoldResult>":$innerTiles,
2097    CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
2098  ];
2099
2100  let extraClassDeclaration = commonExtraClassDeclaration # [{
2101    static Value createDestinationTensor(OpBuilder &b, Location loc,
2102        Value source, ArrayRef<OpFoldResult> innerTileSizes,
2103        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
2104
2105    /// Build and return a new UnPackOp that is a clone of the current UnPackOp
2106    /// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) permuted by
2107    /// innerPermutation (resp. outerPermutation).
2108    /// Asserts that:
2109    ///   - At least one of innerPermutation or outerPermutation is non-empty.
2110    ///   - If not empty, innerPermutation is a valid permutation of size
2111    ///     matching innerDimsPos.
2112    ///   - If not empty, outerPermutation is a valid permutation of size
2113    ///     matching outerDimsPerm.
2114    UnPackOp createTransposedClone(OpBuilder &b,
2115                                   Location loc,
2116                                   Value transposedSource,
2117                                   ArrayRef<int64_t> innerPermutation,
2118                                   ArrayRef<int64_t> outerPermutation);
2119
2120    /// Check if this UnPackOp is like a simple unpad operation.
2121    /// In other words, this operation:
2122    /// 1. drops useless dimensions (dimension of size 1), and
2123    /// 2. reduces dimensions in place (i.e., no transpose.)
2124    bool isLikeUnPad();
2125  }];
2126
2127  let hasCanonicalizeMethod = 1;
2128
2129  let hasFolder = 1;
2130}
2131
2132//===----------------------------------------------------------------------===//
2133// YieldOp
2134//===----------------------------------------------------------------------===//
2135
2136def Tensor_YieldOp : Tensor_Op<"yield",
2137    [Pure, ReturnLike, Terminator,
2138     HasParent<"::mlir::tensor::GenerateOp, ::mlir::tensor::PadOp">]> {
2139  let summary = "Yield a value from a region";
2140  let description = [{
2141     This operation is used to yield a single value from within a region. It
2142     is used to create dynamically sized tensors
2143     (see `tensor.generate` and `tensor.pad` ops).
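
     For `tensor.pad`, see the examples on that op above. A minimal sketch of
     the `tensor.generate` case (the extents %m and %n are placeholders):

     ```mlir
     %m = arith.constant 4 : index
     %n = arith.constant 5 : index
     // Yield the element value computed for each (%i, %j) position.
     %t = tensor.generate %m, %n {
     ^bb0(%i : index, %j : index):
       %sum = arith.addi %i, %j : index
       %elem = arith.index_cast %sum : index to i64
       tensor.yield %elem : i64
     } : tensor<?x?xi64>
     ```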
2144  }];
2145
2146  let arguments = (ins AnyType:$value);
2147  let assemblyFormat = "$value attr-dict `:` type($value)";
2148
2149  // Dummy builder to appease code in templated ensureTerminator that
2150  // GenerateOp's auto-generated parser calls.
2151  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
2152}
2153
2154#endif // TENSOR_OPS
2155