//===- ShapeOps.td - Shape operations definition -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the operation definition file for Shape dialect operations.
//
//===----------------------------------------------------------------------===//

#ifndef SHAPE_OPS
#define SHAPE_OPS

include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Shape op definitions
//===----------------------------------------------------------------------===//

// Base class for the operation in this dialect
class Shape_Op<string mnemonic, list<Trait> traits = []> :
    Op<ShapeDialect, mnemonic, traits>;

def Shape_AddOp : Shape_Op<"add",
    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Addition of sizes and indices";
  let description = [{
    Adds two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type
    `size`, the result must be of type `size`. If error propagation is not
    possible because both operands are of type `index` then the result may be
    of type `size` or `index`.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
  let summary = "Returns the broadcasted output shape of two or more inputs";
  let description = [{
    Returns the broadcasted shape for input shapes or extent tensors. The rest
    of this description is simplified for the 2 input case but can be extended
    to more inputs. Both operands can be of type `shape.shape` or
    `tensor<?xindex>`. The result is of type `shape.shape` and, if both
    operands are tensors, may be of type `tensor<?xindex>`.

    If the two operand shapes are of different rank the smaller one is padded
    with 1's from the left. The resulting broadcasted shape is then defined as

        result[i] = lhs[i] if lhs[i] == rhs[i]
                  = lhs[i] if rhs[i] == 1
                  = rhs[i] if lhs[i] == 1.

    In case the resulting shape is undefined, i.e. if corresponding extents are
    different from each other but none is 1, the result is an error shape.
    Likewise error values are propagated if any of the operands holds an error
    value. If the result type is an extent tensor (and can therefore not hold
    the error value) the behavior may be undefined. The optional string
    attribute can be used to describe the error case.
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
                       OptionalAttr<StrAttr>:$error);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = [{
    $shapes attr-dict `:` type($shapes) `->` type($result)
  }];

  // Convenience builder for the binary case: `lhs` and `rhs` are forwarded as
  // the variadic `shapes` operand list. Keep a single `builders` assignment
  // here; a second `let builders` in this body would replace the first.
  let builders = [OpBuilder<(ins "::mlir::Type":$result,
                                "::mlir::Value":$lhs, "::mlir::Value":$rhs,
                                "/*optional*/ ::mlir::StringAttr":$error), [{
      build($_builder, $_state, result, ::llvm::ArrayRef({lhs, rhs}),
        error);
    }]>
  ];

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def Shape_ConstShapeOp : Shape_Op<"const_shape",
    [ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Creates a constant shape or extent tensor";
  let description = [{
    Creates a constant shape or extent tensor. The individual extents are given
    as the `shape` attribute. The number of these values equals the shape's
    rank.

    ```mlir
    %0 = shape.const_shape [] : !shape.shape
    %1 = shape.const_shape [1, 2, 3] : !shape.shape
    %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
    ```
  }];
  let arguments = (ins IndexElementsAttr:$shape);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_ConstSizeOp : Shape_Op<"const_size", [
    ConstantLike,
    Pure,
    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
  ]> {
  let summary = "Creates a constant of type `shape.size`";
  let description = [{
    Creates a `shape.size` type representing the constant size given by `value`.

    ```mlir
    %x = shape.const_size 10
    ```
  }];

  let arguments = (ins IndexAttr:$value);
  let results = (outs Shape_SizeType:$result);

  let builders = [OpBuilder<(ins "int64_t":$value)>];

  let assemblyFormat = "$value attr-dict";
  let hasFolder = 1;
}

def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Division of sizes and indices";
  let description = [{
    Divides two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`.
    If at least one of the operands can hold an error, i.e. if it is of type
    `size`, the result must be of type `size`. If error propagation is not
    possible because both operands are of type `index` then the result may be
    of type `size` or `index`. If both operands and result are of type
    `index`, their runtime values could be negative. The result is rounded
    toward negative infinity, i.e. floor(lhs / rhs), such that

        div(lhs, rhs) * rhs + mod(lhs, rhs) = lhs

    always holds. If any of the values is of type `size`, the behavior for
    negative values is undefined.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs,
                       Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
  let summary = "Returns whether the input shapes or extent tensors are equal";
  let description = [{
    Takes one or more shape or extent tensor operands and determines whether
    they are equal. When extent tensors are compared to shapes they are
    regarded as their equivalent non-error shapes. Error shapes can be tested
    for equality like any other shape value, meaning that the error value is
    equal to itself.
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
  let results = (outs I1:$result);

  // Convenience builder alias for the binary version.
  let builders = [
    OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
    [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
  ];

  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
  let hasFolder = 1;
}

def Shape_FromExtentsOp : Shape_Op<"from_extents", [Pure]> {
  let summary = "Creates a shape from extents";
  let description = [{
    Creates a shape from multiple SSA values representing the extents of
    the shape.

    ```mlir
    // Rank 2 shape.
    %s0 = shape.from_extents %a, %b
    // Rank 0 shape.
    %s1 = shape.from_extents
    ```
  }];
  let arguments = (ins Variadic<Shape_SizeOrIndexType>:$extents);
  let results = (outs Shape_ShapeType:$shape);

  let assemblyFormat = "$extents attr-dict `:` type($extents)";

  let hasFolder = 1;
}

def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [Pure]> {
  let summary = "Creates a shape from a tensor of extents";
  let description = [{
    Creates a shape from a 1D integral tensor of extents. The rank of the
    resulting shape equals the number of elements in the tensor, and the
    extents match the values of the elements.
  }];

  let arguments = (ins 1DTensorOf<[Index]>:$input);
  let results = (outs Shape_ShapeType:$result);

  let assemblyFormat = "$input attr-dict `:` type($input)";
}

def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
  let summary = "Determines if 2+ shapes can be successfully broadcasted";
  let description = [{
    Given multiple input shapes or extent tensors, return a predicate
    specifying if they are broadcastable. This broadcastable follows the same
    logic as what shape.broadcast documents.

    Concretely, shape.is_broadcastable returning true implies that
    shape.broadcast will not give an error, and shape.cstr_broadcastable will
    not result in an assertion failure. Similarly, false implies an error or
    assertion failure.

    Example:
    ```mlir
    %true = shape.is_broadcastable [2,2], [3,1,2]
    %false = shape.is_broadcastable [2,2], [3,2]
    ```
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
  let results = (outs I1:$result);

  // Convenience builder alias for the binary version.
  let builders = [
    OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
    [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
  ];

  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
}

def Shape_RankOp : Shape_Op<"rank",
    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Gets the rank of a shape";
  let description = [{
    Returns the rank of the shape or extent tensor, i.e. the number of extents.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_SizeOrIndexType:$rank);

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($rank)";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
    DeclareOpInterfaceMethods<CastOpInterface>, Pure
  ]> {
  let summary = "Creates a dimension tensor from a shape";
  let description = [{
    Converts a shape to a 1D integral tensor of extents. The number of elements
    in the tensor equals the rank of the shape, and the elements equal the
    extents of the shape.

    If the shape represents an error, this op's behavior is undefined.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
  let results = (outs IndexTensor:$result);

  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";

  let hasFolder = 1;
}

def Shape_DimOp : Shape_Op<"dim",
    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Gets the specified extent from the shape of a shaped input";
  let description = [{
    Gets the extent indexed by `index` from the shape of the `value` operand.
    If the index is error or out-of-bound then it returns an invalid size if
    the return type carries error information else the behavior is undefined.

    This is a convenience op that performs the equivalent of getting the extent
    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
  }];
  let arguments = (ins AnyShaped:$value,
                       Shape_SizeOrIndexType:$index);
  let results = (outs Shape_SizeOrIndexType:$extent);
  let assemblyFormat = "$value `,` $index attr-dict `:` type($value) `,` "
                       "type($index) `->` type($extent)";

  let extraClassDeclaration = [{
    /// Get the `index` value as integer if it is constant.
    std::optional<int64_t> getConstantIndex();
  }];

  let hasFolder = 1;
}

def Shape_GetExtentOp : Shape_Op<"get_extent",
    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Gets the specified extent from a shape or extent tensor";
  let description = [{
    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
    an error then it returns an invalid size.
  }];
  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Shape_SizeOrIndexType:$dim);
  let results = (outs Shape_SizeOrIndexType:$extent);
  let assemblyFormat = "$shape `,` $dim attr-dict `:` type($shape) `,` "
                       "type($dim) `->` type($extent)";

  let builders = [
    // Builder that allows passing a constant dimension as a simple integer.
    OpBuilder<(ins "Value":$shape, "int64_t":$dim)>
  ];

  let extraClassDeclaration = [{
    /// Get the `dim` value as integer if it is constant.
    std::optional<int64_t> getConstantDim();
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
  let summary = "Converts a standard index to a shape size";
  let description = [{
    Converts a standard index to a `shape.size`. This operation and its
    inverse, `size_to_index`, facilitate index conversion between the standard
    and the shape dialect.

    The behavior is undefined for negative indices.
  }];

  let arguments = (ins Index:$arg);
  let results = (outs Shape_SizeType:$result);

  let assemblyFormat = "$arg attr-dict";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_MaxOp : Shape_Op<"max",
    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Elementwise maximum";
  let description = [{
    Computes the elementwise maximum of two sizes or shapes with equal ranks.
    If either operand is an error, then an error will be propagated to the
    result. If the input types mismatch or the ranks do not match, then the
    result is an error.
  }];

  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
  let results = (outs Shape_ShapeOrSizeType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
}

def Shape_MeetOp : Shape_Op<"meet",
    [Commutative, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Returns the least general shape or size of its operands";
  let description = [{
    An operation that computes the least general shape or dim of input operands.
    This effectively asserts that corresponding static dimensions are equal.
    The behavior is to match each element of the shape/size and propagate the
    most restrictive information, returning an invalid shape if there are
    contradictory requirements. E.g., using pseudo code

    ```
    shape.meet([*], [*]) -> [*]
    shape.meet([*], [1, ?]) -> [1, ?]
    shape.meet([1, 2], [1, ?]) -> [1, 2]
    shape.meet([*], [1, 2]) -> [1, 2]
    shape.meet([], []) -> []
    shape.meet([], [*]) -> []
    shape.meet([], [?, ?]) -> [invalid]
    shape.meet([1, ?], [2, ?, ?]) -> [invalid]
    ```

    `shape.meet` also allows specifying an optional error string, that may be
    used to return an error to the user upon mismatch of dimensions.

    ```mlir
    %c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
    ```
  }];

  let arguments = (ins
    Shape_AnyShapeOrSizeType:$arg0,
    Shape_AnyShapeOrSizeType:$arg1,
    OptionalAttr<StrAttr>:$error);
  let results = (outs Shape_AnyShapeOrSizeType:$result);

  let assemblyFormat = [{
    $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
      type($arg0) `,` type($arg1) `->` type($result)
  }];
}

def Shape_MinOp : Shape_Op<"min",
    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Elementwise minimum";
  let description = [{
    Computes the elementwise minimum of two sizes or shapes with equal ranks.
    If either operand is an error, then an error will be propagated to the
    result. If the input types mismatch or the ranks do not match, then the
    result is an error.
  }];

  let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
  let results = (outs Shape_ShapeOrSizeType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
}

def Shape_MulOp : Shape_Op<"mul",
    [Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Multiplication of sizes and indices";
  let description = [{
    Multiplies two sizes or indices. If either operand is an error it will be
    propagated to the result. The operands can be of type `size` or `index`. If
    at least one of the operands can hold an error, i.e. if it is of type
    `size`, the result must be of type `size`. If error propagation is not
    possible because both operands are of type `index` then the result may be
    of type `size` or `index`.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_NumElementsOp : Shape_Op<"num_elements",
    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Returns the number of elements for a given shape";
  let description = [{
    Returns the number of elements for a given shape which is the product of
    its extents. If the argument is of type `shape` then the result will be of
    type `size` and potential errors will be propagated. Otherwise, if the
    argument is an extent tensor `tensor<?xindex>` then the result will be of
    type `index`.
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_SizeOrIndexType:$result);

  let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_ReduceOp : Shape_Op<"reduce",
    [SingleBlockImplicitTerminator<"YieldOp">]> {
  let summary = "Returns an expression reduced over a shape or extent tensor";
  let description = [{
    An operation that takes as input a shape or extent tensor, and a number of
    initial values. This operation has a region that is applied repeatedly for
    every extent of the input. Starting with the initial values, the individual
    extents are then aggregated as defined by the associated region.

    Conceptually this op performs the following reduction:

    ```
    res[] = init;
    for (int i = 0; i < shape.rank(); i++) {
      res = reduce(i, shape[i], res[0], ..., res[n]);
    }
    ```

    Where `reduce` represents the region attached and the result of the reduce
    op is the last computed output of the reduce region. As an example, the
    number of elements can be computed as follows:

    ```mlir
    func.func @reduce(%shape : !shape.shape, %init : !shape.size) ->
        !shape.size {
      %num_elements = shape.reduce(%shape, %init) -> !shape.size {
        ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
          %updated_acc = "shape.mul"(%acc, %dim) :
            (!shape.size, !shape.size) -> !shape.size
          shape.yield %updated_acc : !shape.size
      }
      return %num_elements : !shape.size
    }
    ```
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
                       Variadic<AnyType>:$initVals);
  let results = (outs Variadic<AnyType>:$result);
  let regions = (region SizedRegion<1>:$region);

  let builders = [OpBuilder<(ins "Value":$shape, "ValueRange":$initVals)>];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}

def Shape_ShapeOfOp : Shape_Op<"shape_of",
    [Pure, InferTypeOpAdaptorWithIsCompatible]> {
  let summary = "Returns shape of a value or shaped type operand";

  let description = [{
    The operation takes a value or a shaped operand as an argument and it
    returns a shape or extent tensor.
  }];

  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";

  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {
  let summary = "Returns value of a !shape.value_shape operand";

  let description = [{
    The operation takes !shape.value_shape, a.k.a. (value, shape) tuple as an
    argument, and returns its value. The behavior is undefined for unknown and
    invalid arguments.
  }];

  let arguments = (ins Shape_ValueShapeType:$arg);
  let results = (outs AnyShaped:$result);

  let assemblyFormat = "$arg attr-dict `:` type($result)";
}

def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
    DeclareOpInterfaceMethods<CastOpInterface>, Pure
  ]> {
  let summary = "Casts between index types of the shape and standard dialect";
  let description = [{
    Converts a `shape.size` to a standard index. This operation and its
    inverse, `index_to_size`, facilitate index conversion between the standard
    and the shape dialect. The behavior is undefined for unknown and invalid
    arguments.
  }];

  let arguments = (ins Shape_SizeOrIndexType:$arg);
  let results = (outs Index:$result);

  let assemblyFormat = "$arg attr-dict `:` type($arg)";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Shape_ValueAsShapeOp : Shape_Op<"value_as_shape", [Pure]> {
  let summary = "Returns value as a shape";

  let description = [{
    The operation takes a ValueShape and returns a Shape corresponding to the
    value. If the input value cannot be interpreted as a shape (e.g., it is not
    a 1D tensor of integral values representing sizes) then this propagates
    the error shape. E.g.,

    ```mlir
    // The following
    %0 = arith.constant dense<[1,2]> : tensor<2xi32>
    %shape = shape.value_as_shape %0 : tensor<2xi32> -> !shape.shape
    // is equivalent to
    %shape' = shape.const_shape [1, 2] : !shape.shape
    ```

    This operation is the complement of `shape_of` wrt ValueShape values.
  }];

  let arguments = (ins AnyTypeOf<[1DTensorOf<[AnyInteger, Index]>,
                                  Shape_ValueShapeType]>:$arg);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
}

def Shape_WithOp : Shape_Op<"with_shape", [Pure]> {
  let summary = "Returns ValueShape with given shape";
  let description = [{
    Returns ValueShape with the shape updated to match the shape operand. That
    is a new ValueShape tuple is created with value equal to `operand`'s
    value and shape equal to `shape`. If the ValueShape and given `shape` are
    non-conformant, then the returned ValueShape will represent an error of
    this mismatch. Similarly if either inputs are in an error state, then an
    error is propagated.

    Usage:
      %0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape

    This is used, for example, where one combines shape function calculations
    and/or call one shape function from another. E.g.,

    ```mlir
    func.func @shape_foobah(%a: !shape.value_shape,
                            %b: !shape.value_shape,
                            %c: !shape.value_shape) -> !shape.shape {
      %0 = call @shape_foo(%a, %b) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
      %2 = call @shape_bah(%c, %1) :
        (!shape.value_shape, !shape.value_shape) -> !shape.shape
      return %2 : !shape.shape
    }
    ```

    This op need not be a refinement of the shape. In non-error cases the input
    ValueShape's value and shape are conformant and so too for the output, but
    the result may be less specified than `operand`'s shape as `shape` is
    merely used to construct the new ValueShape. If join behavior is desired
    then a join op should be used.
  }];

  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
                       Shape_ShapeOrExtentTensorType:$shape);
  let results = (outs Shape_ValueShapeType:$result);

  let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
}

def Shape_YieldOp : Shape_Op<"yield",
    [HasParent<"ReduceOp, FunctionLibraryOp">,
     Pure,
     ReturnLike,
     Terminator]> {
  let summary = "Returns the value to parent op";

  let arguments = (ins Variadic<AnyType>:$operands);

  let builders = [OpBuilder<(ins),
    [{ build($_builder, $_state, std::nullopt); }]>
  ];

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
  let hasVerifier = 1;
}

// TODO: Add Ops: if_static, if_ranked

// For testing usage.
def Shape_DebugPrintOp : Shape_Op<"debug_print", []> {
  let summary = "Prints the input shape or size";
  let description = [{
    Prints the input dim or shape and passes through input.

    Note: This is intended for testing and debugging only.
  }];

  let arguments = (ins Shape_ShapeOrSizeType:$input);
  let results = (outs Shape_ShapeOrSizeType:$output);
}

def Shape_SplitAtOp : Shape_Op<"split_at", [Pure]> {
  let summary = "Splits a shape at a given index";
  let description = [{
    Splits a shape at a given dimension `index`, returning two shapes. If
    `index` is negative, it is treated as indexing from the back of the shape.
    This negative-handling behavior is important when handling unranked shapes,
    where the positive index is not necessarily knowable due to a dynamic
    number of leading dimensions. If the result is in extent tensor form,
    out of bounds indices result in undefined behavior.

    Examples:
    - split_at([4,5,6], index=0) -> [], [4,5,6]
    - split_at([4,5,6], index=1) -> [4], [5,6]
    - split_at([4,5,6], index=2) -> [4,5], [6]
    - split_at([4,5,6], index=3) -> [4,5,6], []
    - split_at([4,5,6], index=4) -> error
    - split_at([4,5,6], index=-1) -> [4,5], [6]
    - split_at([4,5,6], index=-2) -> [4], [5,6]
    - split_at([4,5,6], index=-3) -> [], [4,5,6]
    - split_at([4,5,6], index=-4) -> error

    Requires:
    - `index` is in the range [-rank(operand),rank(operand)]
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
                       Shape_SizeOrIndexType:$index);
  let results = (outs Shape_ShapeOrExtentTensorType:$head,
                      Shape_ShapeOrExtentTensorType:$tail);
  let hasFolder = 1;
}

def Shape_ConcatOp : Shape_Op<"concat", [Pure]> {
  let summary = "Concatenates two shapes";
  let description = [{
    Creates a shape whose dimensions consist of first the dimensions from `lhs`
    followed by the dimensions of `rhs`.

    Example:
    concat([2,3], [4,5]) -> [2,3,4,5]
    concat([], []) -> []
    concat([], [4,5,6]) -> [4,5,6]
  }];

  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
                       Shape_ShapeOrExtentTensorType:$rhs);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];

  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Shape constraint related ops.
//===----------------------------------------------------------------------===//

// TODO: Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any", [Commutative,
                                   Pure]> {
  let summary = "Return any combination of the input shapes";
  let description = [{
    This operation takes multiple input shapes or extent tensors and returns
    some combination of their dimensions. This can be best seen with examples
    below.

    The result is undefined, but still side-effect free, in cases where the
    inputs have differing ranks or differ in extents of shared dimensions.

    Example:
    ```mlir
    %s0 = shape.any [2,?], [?,3] // [2,3]
    %s1 = shape.any [?,?], [1,2] // [1,2]
    ```
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
  let results = (outs Shape_ShapeOrExtentTensorType:$result);

  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";

  let hasFolder = 1;
}

def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, Pure]> {
  let summary = "Return a logical AND of all witnesses";
  let description = [{
    Used to simplify constraints as any single failing precondition is enough
    to prevent execution.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    %w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %wf = shape.assuming_all %w0, %w1 // Failure
    %wt = shape.assuming_all %w0, %w2 // Passing
    ```
  }];

  let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$inputs attr-dict";

  let hasFolder = 1;
  let hasCanonicalizer = 1;
  let hasVerifier = 1;
}

def Shape_AssumingOp : Shape_Op<"assuming", [
    SingleBlockImplicitTerminator<"AssumingYieldOp">,
    DeclareOpInterfaceMethods<RegionBranchOpInterface>,
    RecursiveMemoryEffects]> {
  let summary = "Execute the region";
  let description = [{
    Executes the region assuming all witnesses are true.

    "assuming" operations represent an execution order restriction to the
    compiler, information for dependent code to rely on (by assuming), and
    nothing else. They should not exist after a program is fully lowered and
    ready to execute.
  }];
  let arguments = (ins Shape_WitnessType:$witness);
  let regions = (region SizedRegion<1>:$doRegion);
  let results = (outs Variadic<AnyType>:$results);

  let extraClassDeclaration = [{
    // Inline the region into the region containing the AssumingOp and delete
    // the AssumingOp.
    //
    // This does no checks on the inputs to the AssumingOp.
    static void inlineRegionIntoParent(AssumingOp &op,
      PatternRewriter &rewriter);
  }];

  let builders = [
    OpBuilder<(ins "Value":$witness,
        CArg<"function_ref<SmallVector<Value, 2>(OpBuilder &, Location)>">)>
  ];

  let hasCanonicalizer = 1;
  let hasCustomAssemblyFormat = 1;
}

def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
    [Pure, ReturnLike, Terminator, HasParent<"AssumingOp">]> {
  let summary = "Yield operation";
  let description = [{
    This yield operation represents a return operation within the
    `shape.assuming` operation region. The operation takes variable number of
    operands and produces no results. The operand number and types must match
    the number and types of parent `shape.assuming` results.
  }];

  let arguments = (ins Variadic<AnyType>:$operands);

  let builders = [
    OpBuilder<(ins), [{ /* nothing to do */ }]>,
  ];

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
  let summary = "Determines if 2+ shapes can be successfully broadcasted";
  let description = [{
    Given input shapes or extent tensors, return a witness specifying if they
    are broadcastable. This broadcastable follows the same logic as what
    shape.broadcast documents.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
    %w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
    ```
  }];

  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";

  // Convenience builder alias for the binary version.
  let builders = [
    OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
    [{ build($_builder, $_state, ::llvm::ArrayRef({lhs, rhs})); }]>,
  ];

  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
  let summary = "Determines if all input shapes are equal";
  let description = [{
    Given 1 or more input shapes, determine if all shapes are the exact same.

    "cstr" operations represent runtime assertions.

    Example:
    ```mlir
    %w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
    %w1 = shape.cstr_eq [2,2], [1,2] // Failure
    ```
  }];
  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, Pure]> {
  let summary = "An operation that returns a statically known witness value";
  let description = [{
    This operation represents a statically known witness result. This can
    often be used to canonicalize/fold constraint and assuming code that will
    always pass.

    ```mlir
    %0 = shape.const_shape [1,2,3]
    %1 = shape.const_shape [1,2,3]
    %w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true"
    %w1 = shape.const_witness true
    %w2 = shape.assuming_all(%w0, %w1) // Can be folded to "const_witness true"
    ```
  }];
  let arguments = (ins BoolAttr:$passing);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$passing attr-dict";

  let hasFolder = 1;
}

def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
  let summary = "Represents a runtime assertion that an i1 is `true`";
  let description = [{
    Represents a runtime assertion that an i1 is true. It returns a
    !shape.witness to order this assertion.

    For simplicity, prefer using other cstr_* ops if they are available for a
    given constraint.

    Example:
    ```mlir
    %bool = ...
    %w0 = shape.cstr_require %bool, "msg" // Passing if `%bool` is true.
    ```

    Since this op can be used to express many different possible assertions
    (depending on whatever computation calculated `pred`), the `msg`
    should clarify the nature of the assertion for users.
  }];
  let arguments = (ins I1:$pred, StrAttr:$msg);
  let results = (outs Shape_WitnessType:$result);

  let assemblyFormat = "$pred `,` $msg attr-dict";

  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Shape collection ops.
//===----------------------------------------------------------------------===//

def Shape_FunctionLibraryOp : Shape_Op<"function_library",
    [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
     NoTerminator, OpAsmOpInterface, SingleBlock]> {
  let summary = "Represents shape functions and corresponding ops";
  let description = [{
    Represents a list of shape functions and the ops whose shape transfer
    functions they represent.
989 990 Example: 991 992 ```mlir 993 shape.function_library { 994 func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape { 995 %0 = shape_of %arg : !shape.value_shape -> !shape.shape 996 return %0 : !shape.shape 997 } 998 } mapping { 999 std.atan = @same_result_shape 1000 } 1001 ``` 1002 }]; 1003 1004 let arguments = (ins SymbolNameAttr:$sym_name, 1005 OptionalAttr<StrAttr>:$sym_visibility, 1006 DictionaryAttr:$mapping); 1007 let regions = (region AnyRegion:$body); 1008 1009 let extraClassDeclaration = [{ 1010 /// Returns an associated shape function for an operation if defined. 1011 FuncOp getShapeFunction(Operation *op); 1012 1013 //===------------------------------------------------------------------===// 1014 // OpAsmOpInterface 1015 //===------------------------------------------------------------------===// 1016 1017 // This will filter the `shape.` prefix in front of operations inside the 1018 // func body. 1019 static StringRef getDefaultDialect() { return "shape";} 1020 }]; 1021 1022 let builders = [OpBuilder<(ins "StringRef":$name)>]; 1023 let skipDefaultBuilders = 1; 1024 let hasCustomAssemblyFormat = 1; 1025} 1026 1027def Shape_FuncOp : Shape_Op<"func", 1028 [AffineScope, AutomaticAllocationScope, 1029 FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> { 1030 let summary = "Shape function"; 1031 let description = [{ 1032 An operation with a name containing a single `SSACFG` region which 1033 represents a shape transfer function or helper function for shape transfer 1034 function. 
1035 }]; 1036 1037 let arguments = (ins SymbolNameAttr:$sym_name, 1038 TypeAttrOf<FunctionType>:$function_type, 1039 OptionalAttr<DictArrayAttr>:$arg_attrs, 1040 OptionalAttr<DictArrayAttr>:$res_attrs, 1041 OptionalAttr<StrAttr>:$sym_visibility); 1042 let regions = (region AnyRegion:$body); 1043 1044 let builders = [OpBuilder<(ins 1045 "StringRef":$name, "FunctionType":$type, 1046 CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, 1047 CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs) 1048 >]; 1049 1050 let extraClassDeclaration = [{ 1051 static FuncOp create(Location location, StringRef name, FunctionType type, 1052 ArrayRef<NamedAttribute> attrs = {}); 1053 static FuncOp create(Location location, StringRef name, FunctionType type, 1054 Operation::dialect_attr_range attrs); 1055 static FuncOp create(Location location, StringRef name, FunctionType type, 1056 ArrayRef<NamedAttribute> attrs, 1057 ArrayRef<DictionaryAttr> argAttrs); 1058 //===------------------------------------------------------------------===// 1059 // FunctionOpInterface Methods 1060 //===------------------------------------------------------------------===// 1061 1062 /// Returns the region on the current operation that is callable. This may 1063 /// return null in the case of an external callable object, e.g. an external 1064 /// function. 1065 ::mlir::Region *getCallableRegion() { 1066 return isExternal() ? nullptr : &getBody(); 1067 } 1068 1069 /// Returns the argument types of this function. 1070 ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); } 1071 1072 /// Returns the result types of this function. 1073 ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); } 1074 1075 //===------------------------------------------------------------------===// 1076 // OpAsmOpInterface 1077 //===------------------------------------------------------------------===// 1078 1079 // This will filter the `shape.` prefix in front of operations inside the 1080 // func body. 
1081 static StringRef getDefaultDialect() { return "shape";} 1082 1083 //===------------------------------------------------------------------===// 1084 // SymbolOpInterface Methods 1085 //===------------------------------------------------------------------===// 1086 1087 bool isDeclaration() { return isExternal(); } 1088 }]; 1089 let hasCustomAssemblyFormat = 1; 1090} 1091 1092def Shape_ReturnOp : Shape_Op<"return", 1093 [Pure, HasParent<"FuncOp">, ReturnLike, Terminator]> { 1094 let summary = "Shape function return operation"; 1095 let description = [{ 1096 The `shape.return` operation represents a return operation within a 1097 function. The operation takes variable number of operands and produces no 1098 results. 1099 }]; 1100 1101 let arguments = (ins Variadic<AnyType>:$operands); 1102 1103 let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; 1104 1105 // TODO: Tighten verification. 1106} 1107 1108#endif // SHAPE_OPS 1109