xref: /llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td (revision 8487d2460e8cf80c7c3b240cf46969eeeb4ed18d)
1//===- ShapeOps.td - Shape operations definition -----------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the operation definition file for Shape dialect operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef SHAPE_OPS
14#define SHAPE_OPS
15
16include "mlir/Dialect/Shape/IR/ShapeBase.td"
17include "mlir/Interfaces/CallInterfaces.td"
18include "mlir/Interfaces/CastInterfaces.td"
19include "mlir/Interfaces/ControlFlowInterfaces.td"
20include "mlir/Interfaces/InferTypeOpInterface.td"
21include "mlir/Interfaces/SideEffectInterfaces.td"
22include "mlir/IR/OpAsmInterface.td"
23include "mlir/Interfaces/FunctionInterfaces.td"
24include "mlir/IR/SymbolInterfaces.td"
25
26//===----------------------------------------------------------------------===//
27// Shape op definitions
28//===----------------------------------------------------------------------===//
29
30// Base class for every operation in the Shape dialect. It binds the op to
// `ShapeDialect` and forwards the mnemonic and the trait list (empty by
// default) to the generic `Op` class.
31class Shape_Op<string mnemonic, list<Trait> traits = []> :
32    Op<ShapeDialect, mnemonic, traits>;
33
// `shape.add`: binary addition on `Shape_SizeOrIndexType` operands.
// Traits: `Commutative`, `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// `hasFolder` and `hasVerifier` request a fold hook and a verifier; their C++
// implementations are not part of this file.
34def Shape_AddOp : Shape_Op<"add",
35    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
36  let summary = "Addition of sizes and indices";
37  let description = [{
38    Adds two sizes or indices. If either operand is an error it will be
39    propagated to the result. The operands can be of type `size` or `index`. If
40    at least one of the operands can hold an error, i.e. if it is of type
41    `size`, the result must be of type `size`. If error propagation is not
42    possible because both operands are of type `index` then the result may be
43    of type `size` or `index`.
44  }];
45
46  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
47  let results = (outs Shape_SizeOrIndexType:$result);
48
  // Textual form spells out both operand types and the result type.
49  let assemblyFormat = [{
50    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
51  }];
52
53  let hasFolder = 1;
54  let hasVerifier = 1;
55}
56
// `shape.broadcast`: computes the broadcasted shape of a variadic list of
// shapes or extent tensors. An optional string attribute `error` can describe
// the failure case. Traits: `Commutative`, `Pure`.
// Requests a folder, canonicalization patterns and a verifier; their C++
// implementations are not part of this file.
57def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
58  let summary = "Returns the broadcasted output shape of two or more inputs";
59  let description = [{
60    Returns the broadcasted shape for input shapes or extent tensors. The rest
61    of this description is simplified for the 2 input case but can be extended
62    to more inputs. Both operands can be of type `shape.shape` or
63    `tensor<?xindex>`. The result is of type `shape.shape` and, if both
64    operands are tensors, may be of type `tensor<?xindex>`.
65
66    If the two operand shapes are of different rank the smaller one is padded
67    with 1's from the left. The resulting broadcasted shape is then defined as
68
69        result[i] = lhs[i] if lhs[i] == rhs[i]
70                  = lhs[i] if rhs[i] == 1
71                  = rhs[i] if lhs[i] == 1.
72
73    In case the resulting shape is undefined, i.e. if corresponding extents are
74    different from each other but none is 1, the result is an error shape.
75    Likewise error values are propagated if any of the operands holds an error
76    value. If the result type is an extent tensor (and can therefore not hold
77    the error value) the behavior may be undefined. The optional string
78    attribute can be used to describe the error case.
79  }];
80
81  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
82                       OptionalAttr<StrAttr>:$error);
83  let results = (outs Shape_ShapeOrExtentTensorType:$result);
84
87  let assemblyFormat = [{
88    $shapes attr-dict `:` type($shapes) `->` type($result)
89  }];
90
  // Convenience builder for the binary case: wraps `lhs` and `rhs` into an
  // array and forwards to the variadic builder together with `error`.
  // This is the only `builders` assignment in the record. An earlier
  // `let builders = [OpBuilder<(ins "Value":$shape)>]` was removed because this
  // later assignment overrode it, so it never produced a builder.
91  let builders = [OpBuilder<(ins "::mlir::Type":$result,
92                                "::mlir::Value":$lhs, "::mlir::Value":$rhs,
93                                "/*optional*/ ::mlir::StringAttr":$error), [{
94      build($_builder, $_state, result, ::llvm::ArrayRef({lhs, rhs}),
95        error);
96    }]>
97  ];
98
99  let hasFolder = 1;
100  let hasCanonicalizer = 1;
101  let hasVerifier = 1;
102}
103
// `shape.const_shape`: produces a constant shape or extent tensor from the
// extents stored in the `shape` attribute (`IndexElementsAttr`).
// Traits: `ConstantLike`, `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// The textual form is handled by a custom parser/printer
// (`hasCustomAssemblyFormat`); a folder and canonicalization patterns are
// requested as well. None of those C++ pieces are in this file.
104def Shape_ConstShapeOp : Shape_Op<"const_shape",
105    [ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
106  let summary = "Creates a constant shape or extent tensor";
107  let description = [{
108    Creates a constant shape or extent tensor. The individual extents are given
109    as the `shape` attribute. The number of these values equals the shape's
110    rank.
111
112    ```mlir
113    %0 = shape.const_shape [] : !shape.shape
114    %1 = shape.const_shape [1, 2, 3] : !shape.shape
115    %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
116    ```
117  }];
118  let arguments = (ins IndexElementsAttr:$shape);
119  let results = (outs Shape_ShapeOrExtentTensorType:$result);
120
121  let hasCustomAssemblyFormat = 1;
122  let hasFolder = 1;
123  let hasCanonicalizer = 1;
124}
125
// `shape.const_size`: produces a `shape.size` constant from the `value`
// attribute (`IndexAttr`). Traits: `ConstantLike`, `Pure`; additionally
// declares `getAsmResultNames` of `OpAsmOpInterface` for implementation in C++.
126def Shape_ConstSizeOp : Shape_Op<"const_size", [
127    ConstantLike,
128    Pure,
129    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
130  ]> {
131  let summary = "Creates a constant of type `shape.size`";
132  let description = [{
133    Creates a `shape.size` type representing the constant size given by `value`.
134
135    ```mlir
136    %x = shape.const_size 10
137    ```
138  }];
139
140  let arguments = (ins IndexAttr:$value);
141  let results = (outs Shape_SizeType:$result);
142
  // Builder taking the constant as a plain `int64_t`. No inline body is given
  // here, so the definition has to be supplied outside this file.
143  let builders = [OpBuilder<(ins "int64_t":$value)>];
144
145  let assemblyFormat = "$value attr-dict";
146  let hasFolder = 1;
147}
148
// `shape.div`: floor division on `Shape_SizeOrIndexType` operands. Unlike
// `shape.add`/`shape.mul` it is not marked `Commutative`.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder and a verifier; their C++ implementations are not part of
// this file.
149def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
150  let summary = "Division of sizes and indices";
151  let description = [{
152    Divides two sizes or indices. If either operand is an error it will be
153    propagated to the result. The operands can be of type `size` or `index`.
154    If at least one of the operands can hold an error, i.e. if it is of type
155    `size`, the result must be of type `size`. If error propagation is not
156    possible because both operands are of type `index` then the result may be
157    of type  `size` or `index`. If both operands and result are of type
158    `index`, their runtime values could be negative. The result is rounded
159    toward negative infinity, i.e. floor(lhs / rhs), such that
160
161        div(lhs, rhs) * rhs + mod(lhs, rhs) = lhs
162
163    always holds. If any of the values is of type `size`, the behavior for
164    negative value is undefined.
165  }];
166
167  let arguments = (ins Shape_SizeOrIndexType:$lhs,
168                       Shape_SizeOrIndexType:$rhs);
169  let results = (outs Shape_SizeOrIndexType:$result);
170
171  let assemblyFormat = [{
172    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
173  }];
174
175  let hasFolder = 1;
176  let hasVerifier = 1;
177}
178
// `shape.shape_eq`: compares a variadic list of shapes or extent tensors for
// equality and yields an `i1`. Traits: `Pure`, `Commutative`.
// Requests a folder whose C++ implementation is not part of this file.
179def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
180  let summary = "Returns whether the input shapes or extent tensors are equal";
181  let description = [{
182    Takes one or more shape or extent tensor operands and determines whether
183    they are equal. When extent tensors are compared to shapes they are
184    regarded as their equivalent non-error shapes. Error shapes can be tested
185    for equality like any other shape value, meaning that the error value is
186    equal to itself.
187  }];
188
189  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
190  let results = (outs I1:$result);
191
192  // Convenience builder for the binary case: packs `lhs` and `rhs` into an
  // array and forwards to the variadic builder.
193  let builders = [
194  OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
195    [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
196  ];
197
198  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
199  let hasFolder = 1;
200}
201
// `shape.from_extents`: builds a `Shape_ShapeType` value from a variadic list
// of `Shape_SizeOrIndexType` extents (zero extents give a rank-0 shape, per
// the description). Trait: `Pure`.
// Requests a folder whose C++ implementation is not part of this file.
202def Shape_FromExtentsOp : Shape_Op<"from_extents", [Pure]> {
203  let summary = "Creates a shape from extents";
204  let description = [{
205    Creates a shape from multiple SSA values representing the extents of
206    the shape.
207
208    ```mlir
209    // Rank 2 shape.
210    %s0 = shape.from_extents %a, %b
211    // Rank 0 shape.
212    %s1 = shape.from_extents
213    ```
214  }];
215  let arguments = (ins Variadic<Shape_SizeOrIndexType>:$extents);
216  let results = (outs Shape_ShapeType:$shape);
217
  // Only the operand types are printed; the result type is not part of the
  // textual form.
218  let assemblyFormat = "$extents attr-dict `:` type($extents)";
219
220  let hasFolder = 1;
221}
222
// `shape.from_extent_tensor`: converts a 1-D tensor of `index` elements into a
// `Shape_ShapeType` value. Trait: `Pure`. No folder, canonicalizer or verifier
// is requested for this op.
223def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [Pure]> {
224  let summary = "Creates a shape from a tensor of extents";
225  let description = [{
226    Creates a shape from a 1D integral tensor of extents. The rank of the
227    resulting shape equals the number of elements in the tensor, and the
228    extents match the values of the elements.
229  }];
230
231  let arguments = (ins 1DTensorOf<[Index]>:$input);
232  let results = (outs Shape_ShapeType:$result);
233
234  let assemblyFormat = "$input attr-dict `:` type($input)";
235}
236
// `shape.is_broadcastable`: yields an `i1` telling whether a variadic list of
// shapes or extent tensors can be broadcast together. Trait: `Commutative`
// (note that `Pure` is not listed on this op).
// Requests a folder and canonicalization patterns; their C++ implementations
// are not part of this file.
237def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
238  let summary = "Determines if 2+ shapes can be successfully broadcasted";
239  let description = [{
240    Given multiple input shapes or extent tensors, return a predicate
241    specifying if they are broadcastable. This broadcastable follows the same
242    logic as what shape.broadcast documents.
243
244    Concretely, shape.is_broadcastable returning true implies that
245    shape.broadcast will not give an error, and shape.cstr_broadcastable will
246    not result in an assertion failure. Similarly, false implies an error or
247    assertion failure.
248
249    Example:
250    ```mlir
251    %true = shape.is_broadcastable [2,2], [3,1,2]
252    %false = shape.is_broadcastable [2,2], [3,2]
253    ```
254  }];
255
256  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
257  let results = (outs I1:$result);
258
  // Convenience builder for the binary case: packs `lhs` and `rhs` into an
  // array and forwards to the variadic builder.
259  let builders = [
260  OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
261    [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
262  ];
263
264  let hasFolder = 1;
265  let hasCanonicalizer = 1;
266
267  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
268}
269
// `shape.rank`: returns the number of extents of a shape or extent tensor as a
// `Shape_SizeOrIndexType` value.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder, canonicalization patterns and a verifier; their C++
// implementations are not part of this file.
270def Shape_RankOp : Shape_Op<"rank",
271    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
272  let summary = "Gets the rank of a shape";
273  let description = [{
274    Returns the rank of the shape or extent tensor, i.e. the number of extents.
275  }];
276
277  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
278  let results = (outs Shape_SizeOrIndexType:$rank);
279
280  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($rank)";
281
282  let hasFolder = 1;
283  let hasCanonicalizer = 1;
284  let hasVerifier = 1;
285}
286
// `shape.to_extent_tensor`: converts a shape (or extent tensor) into an
// `IndexTensor`. Declares the `CastOpInterface` methods for implementation in
// C++ and is marked `Pure`.
// Requests a folder whose C++ implementation is not part of this file.
287def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
288    DeclareOpInterfaceMethods<CastOpInterface>, Pure
289  ]> {
290  let summary = "Creates a dimension tensor from a shape";
291  let description = [{
292    Converts a shape to a 1D integral tensor of extents. The number of elements
293    in the tensor equals the rank of the shape, and the elements equal the
294    extents of the shape.
295
296    If the shape represents an error, this op's behavior is undefined.
297  }];
298
299  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
300  let results = (outs IndexTensor:$result);
301
302  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
303
304  let hasFolder = 1;
305}
306
// `shape.dim`: returns one extent of the shape of a shaped value. Operands are
// the shaped `value` and the extent position `index`
// (`Shape_SizeOrIndexType`); the result is a `Shape_SizeOrIndexType`.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder whose C++ implementation is not part of this file.
// The description names the position operand `index`, matching the `$index`
// argument declared below.
307def Shape_DimOp : Shape_Op<"dim",
308    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
309  let summary = "Gets the specified extent from the shape of a shaped input";
310  let description = [{
311    Gets the extent indexed by `index` from the shape of the `value` operand.
312    If the index is error or out-of-bound then it returns an invalid size if
313    the return type carries error information else the behavior is undefined.
314
315    This is a convenience op that performs the equivalent of getting the extent
316    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
317  }];
318  let arguments = (ins AnyShaped:$value,
319                       Shape_SizeOrIndexType:$index);
320  let results = (outs Shape_SizeOrIndexType:$extent);
321  let assemblyFormat = "$value `,` $index attr-dict `:` type($value) `,`"
322                       "type($index) `->` type($extent)";
323
  // Extra C++ member declared on the generated op class; defined outside this
  // file.
324  let extraClassDeclaration = [{
325    /// Get the `index` value as integer if it is constant.
326    std::optional<int64_t> getConstantIndex();
327  }];
328
329  let hasFolder = 1;
330}
331
// `shape.get_extent`: returns the extent at position `dim` of a shape or
// extent tensor as a `Shape_SizeOrIndexType` value.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder and a verifier; their C++ implementations are not part of
// this file.
332def Shape_GetExtentOp : Shape_Op<"get_extent",
333    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
334  let summary = "Gets the specified extent from a shape or extent tensor";
335  let description = [{
336    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
337    an error then it returns an invalid size.
338  }];
339  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
340                       Shape_SizeOrIndexType:$dim);
341  let results = (outs Shape_SizeOrIndexType:$extent);
342  let assemblyFormat = "$shape `,` $dim attr-dict `:` type($shape) `,` "
343                       "type($dim) `->` type($extent)";
344
345  let builders = [
346    // Builder that allows passing a constant dimension as a simple integer.
    // Declared without an inline body, so it is defined outside this file.
347    OpBuilder<(ins "Value":$shape, "int64_t":$dim)>
348  ];
349
  // Extra C++ member declared on the generated op class; defined outside this
  // file.
350  let extraClassDeclaration = [{
351    /// Get the `dim` value as integer if it is constant.
352    std::optional<int64_t> getConstantDim();
353  }];
354
355  let hasFolder = 1;
356  let hasVerifier = 1;
357}
358
// `shape.index_to_size`: converts an `Index` operand into a `Shape_SizeType`
// result; the counterpart of `shape.size_to_index`. Trait: `Pure`.
// Requests a folder and canonicalization patterns; their C++ implementations
// are not part of this file.
359def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
360  let summary = "Converts a standard index to a shape size";
361  let description = [{
362    Converts a standard index to a `shape.size`. This operation and its
363    inverse, `size_to_index`, facilitate index conversion between the standard
364    and the shape dialect.
365
366    The behavior is undefined for negative indices.
367  }];
368
369  let arguments = (ins Index:$arg);
370  let results = (outs Shape_SizeType:$result);
371
  // No types are printed: the format contains only the operand and the
  // attribute dictionary.
372  let assemblyFormat = "$arg attr-dict";
373
374  let hasFolder = 1;
375  let hasCanonicalizer = 1;
376}
377
// `shape.max`: elementwise maximum of two `Shape_ShapeOrSizeType` operands.
// Traits: `Commutative`, `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder whose C++ implementation is not part of this file; no
// verifier is requested.
378def Shape_MaxOp : Shape_Op<"max",
379    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
380  let summary = "Elementwise maximum";
381  let description = [{
382    Computes the elementwise maximum of two sizes or shapes with equal ranks.
383    If either operand is an error, then an error will be propagated to the
384    result. If the input types mismatch or the ranks do not match, then the
385    result is an error.
386  }];
387
388  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
389  let results = (outs Shape_ShapeOrSizeType:$result);
390
391  let assemblyFormat = [{
392    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
393  }];
394
395  let hasFolder = 1;
396}
397
// `shape.meet`: combines two `Shape_AnyShapeOrSizeType` operands into the
// least general shape or size, with an optional string attribute `error`.
// Traits: `Commutative`, `InferTypeOpAdaptorWithIsCompatible` (`Pure` is not
// listed on this op). No folder, canonicalizer or verifier is requested.
398def Shape_MeetOp : Shape_Op<"meet",
399    [Commutative, InferTypeOpAdaptorWithIsCompatible]> {
400  let summary = "Returns the least general shape or size of its operands";
401  let description = [{
402    An operation that computes the least general shape or dim of input operands.
403    This effectively asserts that corresponding static dimensions are equal.
404    The behavior is to match each element of the shape/size and propagate the
405    most restrictive information, returning an invalid shape if there are
406    contradictory requirements. E.g., using pseudo code
407
408    ```
409    shape.meet([*], [*]) -> [*]
410    shape.meet([*], [1, ?]) -> [1, ?]
411    shape.meet([1, 2], [1, ?]) -> [1, 2]
412    shape.meet([*], [1, 2]) -> [1, 2]
413    shape.meet([], []) -> []
414    shape.meet([], [*]) -> []
415    shape.meet([], [?, ?]) -> [invalid]
416    shape.meet([1, ?], [2, ?, ?]) -> [invalid]
417    ```
418
419    `shape.meet` also allows specifying an optional error string, that may be
420    used to return an error to the user upon mismatch of dimensions.
421
422    ```mlir
423    %c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
424    ```
425  }];
426
427  let arguments = (ins
428    Shape_AnyShapeOrSizeType:$arg0,
429    Shape_AnyShapeOrSizeType:$arg1,
430    OptionalAttr<StrAttr>:$error);
431  let results = (outs Shape_AnyShapeOrSizeType:$result);
432
  // The `, error = ...` clause is an optional group anchored on `$error`
  // (the `^` marker), so it is tied to the presence of that attribute.
433  let assemblyFormat = [{
434    $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
435      type($arg0) `,` type($arg1) `->` type($result)
436  }];
437}
438
// `shape.min`: elementwise minimum of two `Shape_ShapeOrSizeType` operands;
// declared with the same traits, operands and format as `shape.max`.
// Traits: `Commutative`, `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder whose C++ implementation is not part of this file.
439def Shape_MinOp : Shape_Op<"min",
440    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
441  let summary = "Elementwise minimum";
442  let description = [{
443    Computes the elementwise minimum of two sizes or shapes with equal ranks.
444    If either operand is an error, then an error will be propagated to the
445    result. If the input types mismatch or the ranks do not match, then the
446    result is an error.
447  }];
448
449  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
450  let results = (outs Shape_ShapeOrSizeType:$result);
451
452  let assemblyFormat = [{
453    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
454  }];
455
456  let hasFolder = 1;
457}
458
// `shape.mul`: binary multiplication on `Shape_SizeOrIndexType` operands;
// declared with the same traits, operands and format as `shape.add`.
// Traits: `Commutative`, `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder and a verifier; their C++ implementations are not part of
// this file.
459def Shape_MulOp : Shape_Op<"mul",
460    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
461  let summary = "Multiplication of sizes and indices";
462  let description = [{
463    Multiplies two sizes or indices. If either operand is an error it will be
464    propagated to the result. The operands can be of type `size` or `index`. If
465    at least one of the operands can hold an error, i.e. if it is of type
466    `size`, the result must be of type `size`. If error propagation is not
467    possible because both operands are of type `index` then the result may be
468    of type `size` or `index`.
469  }];
470
471  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
472  let results = (outs Shape_SizeOrIndexType:$result);
473
474  let assemblyFormat = [{
475    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
476  }];
477
478  let hasFolder = 1;
479  let hasVerifier = 1;
480}
481
// `shape.num_elements`: returns the product of the extents of a shape or
// extent tensor as a `Shape_SizeOrIndexType` value.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests a folder and a verifier; their C++ implementations are not part of
// this file.
482def Shape_NumElementsOp : Shape_Op<"num_elements",
483    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
484  let summary = "Returns the number of elements for a given shape";
485  let description = [{
486    Returns the number of elements for a given shape which is the product of
487    its extents. If the argument is of type `shape` then the result will be of
488    type `size` and potential errors will be propagated. Otherwise, if the
489    argument is an extent tensor `tensor<?xindex>` then the result will be of
490    type `index`.
491  }];
492
493  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
494  let results = (outs Shape_SizeOrIndexType:$result);
495
496  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
497
498  let hasFolder = 1;
499  let hasVerifier = 1;
500}
501
// `shape.reduce`: folds a region over every extent of a shape or extent
// tensor, threading a variadic list of accumulator values (`initVals`) through
// the iterations. The op owns a single one-block region that is implicitly
// terminated by `shape.yield` (`SingleBlockImplicitTerminator<"YieldOp">`).
// Parsing/printing is custom and a verifier is requested; those C++ pieces are
// not part of this file.
502def Shape_ReduceOp : Shape_Op<"reduce",
503    [SingleBlockImplicitTerminator<"YieldOp">]> {
504  let summary = "Returns an expression reduced over a shape or extent tensor";
505  let description = [{
506    An operation that takes as input a shape or extent tensor, and a number of
507    initial values. This operation has a region that is applied repeatedly for
508    every extent of the input. Starting with the initial values, the individual
509    extents are then aggregated as defined by the associated region.
510
511    Conceptually this op performs the following reduction:
512
513    ```
514    res[] = init;
515    for (int i = 0; i < shape.rank(); i++) {
516      res = reduce(i, shape[i], res[0], ..., res[n]);
517    }
518    ```
519
520    Where `reduce` represents the region attached and the result of the reduce
521    op is the last computed output of the reduce region. As an example, the
522    number of elements can be computed as follows:
523
524    ```mlir
525    func.func @reduce(%shape : !shape.shape, %init : !shape.size) ->
526        !shape.size {
527      %num_elements = shape.reduce(%shape, %init) -> !shape.size  {
528        ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
529          %updated_acc = "shape.mul"(%acc, %dim) :
530            (!shape.size, !shape.size) -> !shape.size
531          shape.yield %updated_acc : !shape.size
532      }
533      return %num_elements : !shape.size
534    }
535    ```
536  }];
537
538  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
539                       Variadic<AnyType>:$initVals);
540  let results = (outs Variadic<AnyType>:$result);
541  let regions = (region SizedRegion<1>:$region);
542
  // Declared without an inline body, so the builder is defined outside this
  // file.
543  let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>];
544
545  let hasCustomAssemblyFormat = 1;
546  let hasVerifier = 1;
547}
548
// `shape.shape_of`: returns the shape (or extent tensor) of its operand, which
// is either any shaped value or a `Shape_ValueShapeType`.
// Traits: `Pure`, `InferTypeOpAdaptorWithIsCompatible`.
// Requests canonicalization patterns and a verifier (no folder); their C++
// implementations are not part of this file.
549def Shape_ShapeOfOp : Shape_Op<"shape_of",
550    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
551  let summary = "Returns shape of a value or shaped type operand";
552
553  let description = [{
554    The operation takes a value or a shaped operand as an argument and it
555    returns a shape or extent tensor.
556  }];
557
558  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
559  let results = (outs Shape_ShapeOrExtentTensorType:$result);
560
561  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
562
563  let hasCanonicalizer = 1;
564  let hasVerifier = 1;
565}
566
// `shape.value_of`: extracts the value component of a `Shape_ValueShapeType`
// operand and returns it as any shaped type. Trait: `Pure`. No folder,
// canonicalizer or verifier is requested.
567def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {
568  let summary = "Returns value of a !shape.value_shape operand";
569
570   let description = [{
571    The operation takes !shape.value_shape, a.k.a. (value, shape) tuple as an
572    argument, and returns its value. The behavior is undefined for unknown and
573    invalid arguments.
574  }];
575
576  let arguments = (ins Shape_ValueShapeType:$arg);
577  let results = (outs AnyShaped:$result);
578
  // Only the result type is printed; the operand type is not part of the
  // textual form.
579  let assemblyFormat = "$arg attr-dict `:` type($result)";
580}
581
// `shape.size_to_index`: converts a `Shape_SizeOrIndexType` operand into an
// `Index` result; the counterpart of `shape.index_to_size`. Declares the
// `CastOpInterface` methods for implementation in C++ and is marked `Pure`.
// Requests a folder and canonicalization patterns; their C++ implementations
// are not part of this file.
582def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
583    DeclareOpInterfaceMethods<CastOpInterface>, Pure
584  ]> {
585  let summary = "Casts between index types of the shape and standard dialect";
586  let description = [{
587    Converts a `shape.size` to a standard index. This operation and its
588    inverse, `index_to_size`, facilitate index conversion between the standard
589    and the shape dialect. The behavior is undefined for unknown and invalid
590    arguments.
591  }];
592
593  let arguments = (ins Shape_SizeOrIndexType:$arg);
594  let results = (outs Index:$result);
595
  // Only the operand type is printed; the result is always `Index`.
596  let assemblyFormat = "$arg attr-dict `:` type($arg)";
597
598  let hasFolder = 1;
599  let hasCanonicalizer = 1;
600}
601
// `shape.value_as_shape`: reinterprets a value as a shape. The operand is
// either a 1-D tensor of integer/index elements or a `Shape_ValueShapeType`;
// the result is a shape or extent tensor. Trait: `Pure`. No folder,
// canonicalizer or verifier is requested.
602def Shape_ValueAsShapeOp : Shape_Op<"value_as_shape", [Pure]> {
603  let summary = "Returns value as a shape";
604
605  let description = [{
606    The operation takes a ValueShape and returns a Shape corresponding to the
607    value.  If the input value cannot be a shape (e.g., not a 1D tensor of
608    integral value representing sizes) then this propagates the error shape.
609    E.g.,
610
611    ```mlir
612    // The following
613    %0 = arith.constant dense<[1,2]> : tensor<2xi32>
614    %shape = shape.value_as_shape %0 : tensor<2xi32> -> !shape.shape
615    // is equivalent to
616    %shape' = shape.const_shape [1, 2] : !shape.shape
617    ```
618
619    This operation is the complement of `shape_of` wrt ValueShape values.
620  }];
621
622  let arguments = (ins AnyTypeOf<[1DTensorOf<[AnyInteger, Index]>,
623                       Shape_ValueShapeType]>:$arg);
624  let results = (outs Shape_ShapeOrExtentTensorType:$result);
625
626  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
627}
628
// `shape.with_shape`: pairs a value (any shaped value or an existing
// `Shape_ValueShapeType`) with a shape or extent tensor and returns a new
// `Shape_ValueShapeType`. Trait: `Pure`. No folder, canonicalizer or verifier
// is requested.
// NOTE(review): the description refers to "a join op"; the only related op
// visible in this part of the file is `shape.meet` -- confirm which op is
// meant.
629def Shape_WithOp : Shape_Op<"with_shape", [Pure]> {
630  let summary = "Returns ValueShape with given shape";
631  let description = [{
632    Returns ValueShape with the shape updated to match the shape operand. That
633    is a new ValueShape tuple is created with value equal to `operand`'s
634    value and shape equal to `shape`. If the ValueShape and given `shape` are
635    non-conformant, then the returned ValueShape will represent an error of
636    this mismatch. Similarly if either inputs are in an error state, then an
637    error is propagated.
638
639    Usage:
640      %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape
641
642    This is used, for example, where one combines shape function calculations
643    and/or call one shape function from another. E.g.,
644
645    ```mlir
646    func.func @shape_foobah(%a: !shape.value_shape,
647                       %b: !shape.value_shape,
648                       %c: !shape.value_shape) -> !shape.shape {
649      %0 = call @shape_foo(%a, %b) :
650        (!shape.value_shape, !shape.value_shape) -> !shape.shape
651      %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
652      %2 = call @shape_bah(%c, %1) :
653        (!shape.value_shape, !shape.value_shape) -> !shape.shape
654      return %2 : !shape.shape
655    }
656    ```
657
658    This op need not be a refinement of the shape. In non-error cases the input
659    ValueShape's value and shape are conformant and so too for the output, but
660    the result may be less specified than `operand`'s shape as `shape` is
661    merely used to construct the new ValueShape. If join behavior is desired
662    then a join op should be used.
663  }];
664
665  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
666                       Shape_ShapeOrExtentTensorType:$shape);
667  let results = (outs Shape_ValueShapeType:$result);
668
  // `operands` prints all operands as one list; only the two operand types
  // follow the colon.
669  let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
670}
671
// `shape.yield`: terminator that hands a variadic list of values back to its
// parent, which must be a `ReduceOp` or a `FunctionLibraryOp` (`HasParent`).
// Traits: `Pure`, `ReturnLike`, `Terminator`.
// Requests a verifier whose C++ implementation is not part of this file.
672def Shape_YieldOp : Shape_Op<"yield",
673    [HasParent<"ReduceOp, FunctionLibraryOp">,
674     Pure,
675     ReturnLike,
676     Terminator]> {
677  let summary = "Returns the value to parent op";
678
679  let arguments = (ins Variadic<AnyType>:$operands);
680
  // Zero-argument builder: forwards to another `build` overload, passing
  // `std::nullopt` as the only extra argument.
681  let builders = [OpBuilder<(ins),
682    [{ build($_builder, $_state, std::nullopt); }]>
683  ];
684
  // The operand list and its types form an optional group anchored on
  // `$operands`.
685  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
686  let hasVerifier = 1;
687}
688
689// TODO: Add Ops: if_static, if_ranked
690
691// For testing usage.
// `shape.debug_print`: takes a `Shape_ShapeOrSizeType` input and returns a
// value of the same type constraint (the description says the input is passed
// through). The trait list is empty, so the op is not marked `Pure`. No
// `assemblyFormat`, folder, canonicalizer or verifier is declared.
692def Shape_DebugPrintOp : Shape_Op<"debug_print", []> {
693  let summary = "Prints the input shape or size";
694  let description = [{
695    Prints the input dim or shape and passes through input.
696
697    Note: This is intended for testing and debugging only.
698  }];
699
700  let arguments = (ins Shape_ShapeOrSizeType:$input);
701  let results =  (outs Shape_ShapeOrSizeType:$output);
702}
703
// `shape.split_at`: splits a shape or extent tensor at position `index` into
// two results, `head` and `tail`. Negative indices count from the back (see
// the examples in the description). Trait: `Pure`.
// Requests a folder whose C++ implementation is not part of this file; no
// `assemblyFormat` is declared for this op.
704def Shape_SplitAtOp : Shape_Op<"split_at", [Pure]> {
705  let summary = "Splits a shape at a given index";
706  let description = [{
707    Splits a shape at a given dimension `index`, returning two shapes. If
708    `index` is negative, it is treated as indexing from the back of the shape.
709    This negative-handling behavior is important when handling unranked shapes,
710    where the positive index is not necessarily knowable due to a dynamic
711    number of leading dimensions. If the result is in extent tensor form out of
712    bounds indices result in undefined behavior.
713
714    Examples:
715    - split_at([4,5,6], index=0) -> [], [4,5,6]
716    - split_at([4,5,6], index=1) -> [4], [5,6]
717    - split_at([4,5,6], index=2) -> [4,5], [6]
718    - split_at([4,5,6], index=3) -> [4,5,6], []
719    - split_at([4,5,6], index=4) -> error
720    - split_at([4,5,6], index=-1) -> [4,5], [6]
721    - split_at([4,5,6], index=-2) -> [4], [5,6]
722    - split_at([4,5,6], index=-3) -> [], [4,5,6]
723    - split_at([4,5,6], index=-4) -> error
724
725    Requires:
726    - `index` is in the range [-rank(operand),rank(operand)]
727  }];
728
729  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
730                       Shape_SizeOrIndexType:$index);
731  let results = (outs Shape_ShapeOrExtentTensorType:$head,
732                      Shape_ShapeOrExtentTensorType:$tail);
733  let hasFolder = 1;
734}
735
// `shape.concat`: concatenates the extents of `lhs` followed by those of
// `rhs`; operands and result are shapes or extent tensors. Trait: `Pure`.
// It is not marked `Commutative`, consistent with the ordered result.
// Requests a folder whose C++ implementation is not part of this file.
736def Shape_ConcatOp : Shape_Op<"concat", [Pure]> {
737  let summary = "Concatenates two shapes";
738  let description = [{
739    Creates a shape whose dimensions consist of first the dimensions from `lhs`
740    followed by the dimensions of `rhs`.
741
742    Example:
743    concat([2,3], [4,5]) -> [2,3,4,5]
744    concat([], []) -> []
745    concat([], [4,5,6]) -> [4,5,6]
746  }];
747
748  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
749                       Shape_ShapeOrExtentTensorType:$rhs);
750  let results = (outs Shape_ShapeOrExtentTensorType:$result);
751
752  let assemblyFormat = [{
753    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
754  }];
755
756  let hasFolder = 1;
757}
758
759//===----------------------------------------------------------------------===//
760// Shape constraint related ops.
761//===----------------------------------------------------------------------===//
762
763// TODO: Move the code below and witnesses to a different file.
// `shape.any`: takes a variadic list of shapes or extent tensors and returns
// some combination of their dimensions (see the description for the cases in
// which the result is undefined). Traits: `Commutative`, `Pure`.
// Requests a folder whose C++ implementation is not part of this file.
764def Shape_AnyOp : Shape_Op<"any", [Commutative,
765                                   Pure]> {
766  let summary = "Return any combination of the input shapes";
767  let description = [{
768    This operation takes multiple input shapes or extent tensors and returns
769    some combination of their dimensions. This can be best seen with examples
770    below.
771
772    The result is undefined, but still side-effect free, in cases where the
773    inputs have differing ranks or differ in extents of shared dimensions.
774
775    Example:
776    ```mlir
777    %s0 = shape.any [2,?], [?,3] // [2,3]
778    %s1 = shape.any [?,?], [1,2] // [1,2]
779    ```
780  }];
781
782  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
783  let results = (outs Shape_ShapeOrExtentTensorType:$result);
784
785  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";
786
787  let hasFolder = 1;
788}
789
// `shape.assuming_all`: combines a variadic list of `Shape_WitnessType` values
// into a single witness (a logical AND, per the summary).
// Traits: `Commutative`, `Pure`.
// Requests a folder, canonicalization patterns and a verifier; their C++
// implementations are not part of this file.
790def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, Pure]> {
791  let summary = "Return a logical AND of all witnesses";
792  let description = [{
793    Used to simplify constraints as any single failing precondition is enough
794    to prevent execution.
795
796    "assuming" operations represent an execution order restriction to the
797    compiler, information for dependent code to rely on (by assuming), and
798    nothing else. They should not exist after a program is fully lowered and
799    ready to execute.
800
801    Example:
802    ```mlir
803    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
804    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
805    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
806    %wf = shape.assuming_all %w0, %w1 // Failure
807    %wt = shape.assuming_all %w0, %w2 // Passing
808    ```
809  }];
810
811  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
812  let results = (outs Shape_WitnessType:$result);
813
  // No types are printed: every operand and the result are witnesses.
814  let assemblyFormat = "$inputs attr-dict";
815
816  let hasFolder = 1;
817  let hasCanonicalizer = 1;
818  let hasVerifier = 1;
819}
820
// `shape.assuming`: region op guarded by a single `Shape_WitnessType` operand.
// It owns one single-block region (`doRegion`) that is implicitly terminated
// by `shape.assuming_yield`, and returns a variadic list of results.
// Declares the `RegionBranchOpInterface` methods for implementation in C++ and
// carries `RecursiveMemoryEffects`.
// Parsing/printing is custom and canonicalization patterns are requested;
// those C++ pieces are not part of this file.
821def Shape_AssumingOp : Shape_Op<"assuming", [
822    SingleBlockImplicitTerminator<"AssumingYieldOp">,
823    DeclareOpInterfaceMethods<RegionBranchOpInterface>,
824    RecursiveMemoryEffects]> {
825  let summary = "Execute the region";
826  let description = [{
827    Executes the region assuming all witnesses are true.
828
829    "assuming" operations represent an execution order restriction to the
830    compiler, information for dependent code to rely on (by assuming), and
831    nothing else. They should not exist after a program is fully lowered and
832    ready to execute.
833  }];
834  let arguments = (ins Shape_WitnessType:$witness);
835  let regions = (region SizedRegion<1>:$doRegion);
836  let results = (outs Variadic<AnyType>:$results);
837
  // Extra static helper declared on the generated op class; defined outside
  // this file.
838  let extraClassDeclaration = [{
839    // Inline the region into the region containing the AssumingOp and delete
840    // the AssumingOp.
841    //
842    // This does no checks on the inputs to the AssumingOp.
843    static void inlineRegionIntoParent(AssumingOp &op,
844      PatternRewriter &rewriter);
845  }];
846
  // Builder taking the witness plus a callback of type
  // `function_ref<SmallVector<Value, 2>(OpBuilder &, Location)>`. The callback
  // parameter is left unnamed here and the builder has no inline body, so its
  // definition lives outside this file.
847  let builders = [
848    OpBuilder<(ins "Value":$witness,
849        CArg<"function_ref<SmallVector<Value, 2>(OpBuilder &, Location)>">)>
850  ];
851
852  let hasCanonicalizer = 1;
853  let hasCustomAssemblyFormat = 1;
854}
855
// `shape.assuming_yield`: terminator that may only appear directly inside an
// `AssumingOp` (enforced by the `HasParent` trait).
def Shape_AssumingYieldOp
    : Shape_Op<"assuming_yield", [
        Pure,
        ReturnLike,
        Terminator,
        HasParent<"AssumingOp">
      ]> {
  let summary = "Yield operation";
  let description = [{
    This yield operation represents a return operation within the
    `shape.assuming` operation region. The operation takes variable number of
    operands and produces no results. The operand number and types must match
    the number and types of parent `shape.assuming` results.
  }];

  // Any number of operands of any type; no results are declared.
  let arguments = (ins Variadic<AnyType>:$operands);

  // The operand list and its types form an optional group in the textual form.
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";

  // Additional builder with no parameters and an intentionally empty body.
  let builders = [
    OpBuilder<(ins), [{ /* nothing to do */ }]>,
  ];
}
874
// `shape.cstr_broadcastable`: maps a variadic list of shapes / extent tensors
// to a witness. Only `Commutative` is attached as a trait.
def Shape_CstrBroadcastableOp
    : Shape_Op<"cstr_broadcastable", [Commutative]> {
  let summary = "Determines if 2+ shapes can be successfully broadcasted";
  let description = [{
    Given input shapes or extent tensors, return a witness specifying if they
    are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    ```
  }];

  let arguments = (ins
    Variadic<Shape_ShapeOrExtentTensorType>:$shapes
  );
  let results = (outs
    Shape_WitnessType:$result
  );

  // Textual form: operands, attribute dictionary, then the operand types.
  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";

  let builders = [
    // Two-operand convenience builder; forwards both values as a list to
    // another `build` overload.
    OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
      [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
  ];

  // Hooks requested from ODS; their C++ definitions are not part of this file.
  let hasVerifier = 1;
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
905
// `shape.cstr_eq`: maps a variadic list of shapes / extent tensors to a
// witness. Only `Commutative` is attached as a trait.
def Shape_CstrEqOp
    : Shape_Op<"cstr_eq", [Commutative]> {
  let summary = "Determines if all input shapes are equal";
  let description = [{
    Given 1 or more input shapes, determine if all shapes are the exact same.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %w1 = shape.cstr_eq [2,2], [1,2] // Failure
    ```
  }];

  let arguments = (ins
    Variadic<Shape_ShapeOrExtentTensorType>:$shapes
  );
  let results = (outs
    Shape_WitnessType:$result
  );

  // Textual form: operands, attribute dictionary, then the operand types.
  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";

  // Folder and canonicalization patterns are requested; no verifier is.
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
927
// `shape.const_witness`: constant-like op that materialises a witness from a
// boolean attribute (`$passing`).
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, Pure]> {
  let summary = "An operation that returns a statically known witness value";
  // The example previously defined `%w2` in terms of itself; the second operand
  // of the final `assuming_all` is meant to be `%w1`.
  let description = [{
    This operation represents a statically known witness result. This can be
    often used to canonicalize/fold constraint and assuming code that will
    always pass.

    ```mlir
    %0 = shape.const_shape [1,2,3]
    %1 = shape.const_shape [1,2,3]
    %w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true"
    %w1 = shape.const_witness true
    %w2 = shape.assuming_all(%w0, %w1) // Can be folded to "const_witness true"
    ```
  }];
  // No operands: the value comes entirely from the `passing` attribute.
  let arguments = (ins BoolAttr:$passing);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$passing attr-dict";

  let hasFolder = 1;
}
950
// `shape.cstr_require`: turns an `i1` predicate plus a message string into a
// witness. No traits are attached to this op.
def Shape_CstrRequireOp
    : Shape_Op<"cstr_require", []> {
  let summary = "Represents a runtime assertion that an i1 is `true`";
  let description = [{
    Represents a runtime assertion that an i1 is true. It returns a
    !shape.witness to order this assertion.

    For simplicity, prefer using other cstr_* ops if they are available for a
    given constraint.

    Example:
    ```mlir
    %bool = ...
    %w0 = shape.cstr_require %bool, "msg" // Passing if `%bool` is true.
    ```

    Since this op can be used to express many different possible assertions
    (depending on whatever computation calculated `pred`), the `msg`
    should clarify the nature of the assertion for users.
  }];

  // One `i1` operand and one string attribute; a single witness result.
  let arguments = (ins
    I1:$pred,
    StrAttr:$msg
  );
  let results = (outs
    Shape_WitnessType:$result
  );

  // Textual form: predicate, a comma, the message, then the attribute dict.
  let assemblyFormat = "$pred `,` $msg attr-dict";

  let hasFolder = 1;
}
977
978//===----------------------------------------------------------------------===//
979// Shape collection ops.
980//===----------------------------------------------------------------------===//
981
// `shape.function_library`: symbol-table op with a single region (`$body`) and
// a dictionary attribute (`$mapping`).
def Shape_FunctionLibraryOp
    : Shape_Op<"function_library", [
        AffineScope,
        IsolatedFromAbove,
        NoRegionArguments,
        SymbolTable,
        Symbol,
        NoTerminator,
        OpAsmOpInterface,
        SingleBlock
      ]> {
  let summary = "Represents shape functions and corresponding ops";
  let description = [{
    Represents a list of shape functions and the ops whose shape transfer
    functions they represent.

    Example:

    ```mlir
    shape.function_library {
      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
        %0 = shape_of %arg : !shape.value_shape -> !shape.shape
        return %0 : !shape.shape
      }
    } mapping {
      std.atan = @same_result_shape
    }
    ```
  }];

  // Symbol name, optional visibility string, and the `mapping` dictionary.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    OptionalAttr<StrAttr>:$sym_visibility,
    DictionaryAttr:$mapping
  );
  let regions = (region AnyRegion:$body);

  let extraClassDeclaration = [{
    /// Looks up the shape function associated with `op`, if one is defined.
    FuncOp getShapeFunction(Operation *op);

    //===------------------------------------------------------------------===//
    // OpAsmOpInterface
    //===------------------------------------------------------------------===//

    // Operations nested in the body may be written without their `shape.`
    // prefix because "shape" is the default dialect here.
    static StringRef getDefaultDialect() { return "shape"; }
  }];

  // Only the name-taking builder is generated, and the assembly syntax is
  // implemented in C++.
  let skipDefaultBuilders = 1;
  let builders = [
    OpBuilder<(ins "StringRef":$name)>
  ];
  let hasCustomAssemblyFormat = 1;
}
1026
// `shape.func`: function-like op implementing `FunctionOpInterface` and
// `OpAsmOpInterface`, with a single region (`$body`).
def Shape_FuncOp
    : Shape_Op<"func", [
        AffineScope,
        AutomaticAllocationScope,
        FunctionOpInterface,
        IsolatedFromAbove,
        OpAsmOpInterface
      ]> {
  let summary = "Shape function";
  let description = [{
    An operation with a name containing a single `SSACFG` region which
    represents a shape transfer function or helper function for shape transfer
    function.
  }];

  // Symbol name and function type are required; per-argument attributes,
  // per-result attributes and visibility are optional.
  let arguments = (ins
    SymbolNameAttr:$sym_name,
    TypeAttrOf<FunctionType>:$function_type,
    OptionalAttr<DictArrayAttr>:$arg_attrs,
    OptionalAttr<DictArrayAttr>:$res_attrs,
    OptionalAttr<StrAttr>:$sym_visibility
  );
  let regions = (region AnyRegion:$body);

  // Builder from a name and type; both attribute lists default to `{}`.
  let builders = [
    OpBuilder<(ins "StringRef":$name,
                   "FunctionType":$type,
                   CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
                   CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
  ];

  let extraClassDeclaration = [{
    // Static factory overloads; they differ only in how attributes are passed.
    static FuncOp create(Location location, StringRef name, FunctionType type,
                         ArrayRef<NamedAttribute> attrs = {});
    static FuncOp create(Location location, StringRef name, FunctionType type,
                         Operation::dialect_attr_range attrs);
    static FuncOp create(Location location, StringRef name, FunctionType type,
                         ArrayRef<NamedAttribute> attrs,
                         ArrayRef<DictionaryAttr> argAttrs);

    //===------------------------------------------------------------------===//
    // FunctionOpInterface Methods
    //===------------------------------------------------------------------===//

    /// Returns the callable region of this operation: the body, or null when
    /// the function is external.
    ::mlir::Region *getCallableRegion() {
      return isExternal() ? nullptr : &getBody();
    }

    /// Returns the input types recorded in the function type.
    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }

    /// Returns the result types recorded in the function type.
    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }

    //===------------------------------------------------------------------===//
    // OpAsmOpInterface
    //===------------------------------------------------------------------===//

    // Operations nested in the body may be written without their `shape.`
    // prefix because "shape" is the default dialect here.
    static StringRef getDefaultDialect() { return "shape"; }

    //===------------------------------------------------------------------===//
    // SymbolOpInterface Methods
    //===------------------------------------------------------------------===//

    // A function is a declaration exactly when it is external.
    bool isDeclaration() { return isExternal(); }
  }];

  // Parser and printer are implemented in C++.
  let hasCustomAssemblyFormat = 1;
}
1091
// `shape.return`: terminator that may only appear directly inside a `FuncOp`
// (enforced by the `HasParent` trait).
def Shape_ReturnOp
    : Shape_Op<"return", [
        Pure,
        HasParent<"FuncOp">,
        ReturnLike,
        Terminator
      ]> {
  let summary = "Shape function return operation";
  let description = [{
    The `shape.return` operation represents a return operation within a
    function.  The operation takes variable number of operands and produces no
    results.
  }];

  // Any number of operands of any type; no results are declared.
  let arguments = (ins Variadic<AnyType>:$operands);

  // The operand list and its types form an optional group in the textual form.
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";

  // TODO: Tighten verification.
}
1107
1108#endif // SHAPE_OPS
1109