//===- LinalgInterfaces.td - Linalg Interfaces Declaration -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the structured interfaces for Linalg ops.
//
//===----------------------------------------------------------------------===//

#ifndef LINALG_IR_LINALGINTERFACES
#define LINALG_IR_LINALGINTERFACES

include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpBase.td"

// The 'LinalgContractionOpInterface' provides access to the
// 'ContractionOpInterface'.
def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
  let description = [{
   A Linalg contraction is defined in general terms:
     1. Has 2 input and 1 output shapes.
     2. Has at least one reduction dimension.
     3. Has only projected permutation indexing maps.
     4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
     (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
     operations that may change the type (e.g. for mixed-precision).
   As a consequence, when vectorization of such an op occurs, the only special
   behavior is that the (unique) MulOpType is vectorized into a
   `vector.contract`. All other ops are handled in a generic fashion.
   In the future, we may wish to allow more input arguments and elementwise and
   constant operations that do not involve the reduction dimension(s).
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyContractionInterface($_op); }];
  let verifyWithRegions = 1;
  let methods = [
    InterfaceMethod<
      /*desc=*/"Returns the left-hand side operand.",
      /*retTy=*/"Value",
      /*methodName=*/"lhs",
      /*args=*/(ins),
      /*methodBody=*/[{
        return $_op.getOperation()->getOperand(0);
      }]>,
    InterfaceMethod<
      /*desc=*/"Returns the right-hand side operand.",
      /*retTy=*/"Value",
      /*methodName=*/"rhs",
      /*args=*/(ins),
      /*methodBody=*/[{
        return $_op.getOperation()->getOperand(1);
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        row-major matmul operation.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isRowMajorMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isRowMajorMatmul($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        column-major matmul operation.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isColumnMajorMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isColumnMajorMatmul($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        row-major batch matmul operation.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isRowMajorBatchMatmul",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isRowMajorBatchMatmul($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        vector-matrix multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isVecmat",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isVecmat($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        batched vector-matrix multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isBatchVecmat",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isBatchVecmat($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        matrix-vector multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isMatvec",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isMatvec($_op.getIndexingMaps());
      }]>,
    InterfaceMethod<
      /*desc=*/[{
        Returns whether the given op has indexing maps that correspond to a
        batched matrix-vector multiplication.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isBatchMatvec",
      /*args=*/(ins),
      /*methodBody=*/[{
        return mlir::isBatchMatvec($_op.getIndexingMaps());
      }]>,
  ];
}

// The 'LinalgConvolutionOpInterface' provides access to the
// 'ConvolutionOpInterface'.
def LinalgConvolutionOpInterface : OpInterface<"ConvolutionOpInterface"> {
  let description = [{
    A convolution is defined in general terms:
    1. Has an `image` and a `filter` operand.
    2. Has one `output` operand.
    3. The indexing maps of the input have expressions that satisfy
    ```
       AffineExpr ::== AffineDimExpr | ConvolvedExpr
       ConvolvedExpr ::== MulExpr (`+` MulExpr)+
       MulExpr ::== AffineDimExpr (`*` (AffineConstantExpr | AffineSymbolExpr))?
    ```
    4. The filter and the output have projected permutation maps.
    5. Each of the loops can be qualified as one of,
       - Loop over batch dimension,
       - Loop over output image dimensions,
       - Loop over output channel dimensions,
       - Loop over convolved filter dimensions,
       - Loop over input channel dimension.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyConvolutionInterface($_op); }];
  let methods = [
    InterfaceMethod<
      /*desc=*/"Return the image operand.",
      /*retTy=*/"Value",
      /*methodName=*/"image",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // By default the image is the first operand.
        return $_op.getOperation()->getOperand(0);
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the filter operand.",
      /*retTy=*/"Value",
      /*methodName=*/"filter",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // By default the filter is the second operand.
        return $_op.getOperation()->getOperand(1);
      }]
    >,
  ];
}

// The 'LinalgFillOpInterface' provides access to the 'FillOpInterface'.
def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
  let description = [{
    A fill operation is defined in general terms:
    1. Has a scalar `value` operand.
    2. Has one `output` operand.
  }];
  let cppNamespace = "::mlir::linalg";
  let verify = [{ return detail::verifyFillInterface($_op); }];
  let methods = [
    InterfaceMethod<
      /*desc=*/"Return the fill value.",
      /*retTy=*/"Value",
      /*methodName=*/"value",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // By default the fill value is the first operand.
        return $_op.getOperation()->getOperand(0);
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the output operand.",
      /*retTy=*/"Value",
      /*methodName=*/"output",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // By default the output is the second operand.
        return $_op.getOperation()->getOperand(1);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the result, or a null Value when the op has no results.
      }],
      /*retTy=*/"Value",
      /*methodName=*/"result",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // An op without results (e.g. one operating on buffers) has no
        // result to return.
        if ($_op.getOperation()->getResults().empty())
          return nullptr;
        return $_op.getOperation()->getResults().front();
      }]
    >,
  ];
}

// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
  let cppNamespace = "::mlir::linalg";
  let methods = [
    //===------------------------------------------------------------------===//
    // Loop types handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the number of parallel loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumParallelLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::count($_op.getIteratorTypesArray(),
                           utils::IteratorType::parallel);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if all loops are parallel.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isAllParallelLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumParallelLoops() == getNumLoops();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return, through the `res` out-parameter, the dims that are parallel
        loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getParallelDims",
      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return findPositionsOfType($_op.getIteratorTypesArray(),
                                   utils::IteratorType::parallel, res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of reduction loops.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumReductionLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::count($_op.getIteratorTypesArray(),
                           utils::IteratorType::reduction);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return, through the `res` out-parameter, the dims that are reduction
        loops.
      }],
      /*retTy=*/"void",
      /*methodName=*/"getReductionDims",
      /*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return findPositionsOfType($_op.getIteratorTypesArray(),
                                   utils::IteratorType::reduction, res);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the total number of loops within the current operation.
      }],
      /*retTy=*/"unsigned",
      /*methodName=*/"getNumLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getIteratorTypesArray().size();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true if the current operation has only one loop and it's a
        reduction loop.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasSingleReductionLoop",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto iters = $_op.getIteratorTypesArray();
        return iters.size() == 1 &&
               llvm::count(iters, utils::IteratorType::reduction) == 1;
      }]>,
    //===------------------------------------------------------------------===//
    // Input and Init arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return true if the payload uses the value loaded from `opOperand`. This
        is useful to avoid loading from "write-only" memory that may be
        uninitialized, as well as properly cloning "read-write" operands.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"payloadUsesValueFromOperand",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // The default assumes one block argument per operand, in operand
        // order: the payload uses the operand iff its block argument has uses.
        unsigned bbArgNumber = opOperand->getOperandNumber();
        return !getBlock()->getArgument(bbArgNumber).use_empty();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true only if linalgOp takes exactly one input and one init
        (output) operand.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isSingleInputOutput",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an init tensor. This is true when it is
        an output tensor operand whose value is used in the payload region.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isInitTensor",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!$_op.isDpsInit(opOperand))
          return false;
        return payloadUsesValueFromOperand(opOperand);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` rank or zero for scalars or vectors not wrapped
        within a tensor or a memref.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getRank",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        Type t = opOperand->get().getType();
        // A VectorType is an elemental type, do not consider its rank for the
        // operand.
        if (isa<VectorType>(t))
          return 0;
        // Tensor and Memref container types have a rank.
        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
          // Failsafe.
          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
                 "expected a ranked tensor or memref in LinalgInterface::getRank");
          return shapedType.getRank();
        }
        return 0;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input block arguments of the region.
      }],
      /*retTy=*/"Block::BlockArgListType",
      /*methodName=*/"getRegionInputArgs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getBlock()->getArguments().take_front($_op.getNumDpsInputs());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the output block arguments of the region.
      }],
      /*retTy=*/"Block::BlockArgListType",
      /*methodName=*/"getRegionOutputArgs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getBlock()->getArguments().take_back($_op.getNumDpsInits());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `opOperand` shape or an empty vector for scalars or vectors
        not wrapped within a tensor or a memref.
      }],
      /*retTy=*/"ArrayRef<int64_t>",
      /*methodName=*/"getShape",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        Type t = opOperand->get().getType();
        // A VectorType is an elemental type, do not consider its shape for the
        // operand.
        if (isa<VectorType>(t))
          return {};
        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
          // Failsafe.
          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
                 "expected a ranked tensor or memref in LinalgInterface::getShape");
          return shapedType.getShape();
        }
        return {};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the block argument for an `opOperand`.
      }],
      /*retTy=*/"BlockArgument",
      /*methodName=*/"getMatchingBlockArgument",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        return getBlock()->getArgument(opOperand->getOperandNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the operand for a `blockArgument`.
      }],
      /*retTy=*/"OpOperand *",
      /*methodName=*/"getMatchingOpOperand",
      /*args=*/(ins "BlockArgument":$blockArgument),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(blockArgument.getOwner() == getBlock());
        return &this->getOperation()->getOpOperand(
            blockArgument.getArgNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the input or output indexing map for `opOperand`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getMatchingIndexingMap",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        auto indexingMaps =
          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
        return *(indexingMaps.begin() + opOperand->getOperandNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing map for a `result`.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getIndexingMapMatchingResult",
      /*args=*/(ins "OpResult":$result),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(result.getOwner() == this->getOperation());
        auto indexingMaps =
          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
        return *(indexingMaps.begin() + $_op.getNumDpsInputs() +
                 result.getResultNumber());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the operand of the region terminator (i.e. the yielded value)
        that corresponds to the output (init) `opOperand`.
      }],
      /*retTy=*/"OpOperand *",
      /*methodName=*/"getMatchingYieldValue",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        // Inits follow the inputs in the operand list, so the position among
        // the inits is the index of the matching yield operand.
        int64_t resultIndex =
            opOperand->getOperandNumber() - $_op.getNumDpsInputs();
        // Bound by the number of inits rather than the number of results: ops
        // with buffer semantics have no results but are still expected to
        // yield one value per init.
        assert(resultIndex >= 0 && resultIndex < $_op.getNumDpsInits());
        Operation *yieldOp = getBlock()->getTerminator();
        return &yieldOp->getOpOperand(resultIndex);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the single block constituting the body of the operation by
        calling the getBody method on the concrete operation.
      }],
      /*retTy=*/"Block*",
      /*methodName=*/"getBlock",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Assume the concrete operation implements the
        // SingleBlockImplicitTerminator trait.
        return $_op.getBody();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return iterator types in the current operation.

        Default implementation assumes that the operation has an attribute
        `iterator_types`, but it's not always the case. Sometimes iterator types
        can be inferred from other parameters and in such cases the default
        getIteratorTypesArray should be overridden.
      }],
      /*retTy=*/"SmallVector<utils::IteratorType>",
      /*methodName=*/"getIteratorTypesArray",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = $_op.getIteratorTypes()
                         .template getAsValueRange<IteratorTypeAttr,
                                                   utils::IteratorType>();
        return {range.begin(), range.end()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the indexing map is depending on the current op instance.
        This means that the indexing map is dynamically synthesized by using the
        op instance's concrete attributes, instead of being static for all
        instances of the same op kind.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasDynamicIndexingMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return false; }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Verify all attributes used by indexing maps are valid.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"verifyIndexingMapRequiredAttributes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return success(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps attribute within the current operation.
      }],
      /*retTy=*/"ArrayAttr",
      /*methodName=*/"getIndexingMaps"
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the indexing maps within the current operation.
      }],
      /*retTy=*/"SmallVector<AffineMap>",
      /*methodName=*/"getIndexingMapsArray",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto range = $_op.getIndexingMaps()
                         .template getAsValueRange<AffineMapAttr>();
        return {range.begin(), range.end()};
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if any of the operands has a dynamic shape.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasDynamicShape",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
      }]
    >,
    // NOTE(review): the default forwards to `$_op.getLibraryCallName()`, so it
    // presumably relies on every concrete op providing its own definition —
    // confirm before relying on this default.
    InterfaceMethod<
      /*desc=*/[{
        Return the name registered for this op when lowering to an external
        library call.
      }],
      /*retTy=*/"std::string",
      /*methodName=*/"getLibraryCallName",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getLibraryCallName();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op accesses the iteration indices.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasIndexSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/""
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return op operands that have a corresponding argument in the basic block.
        By default, the block should have an argument for each operand, but
        there are exceptions. For example, in `map` the output operand isn't
        used in the block.
      }],
      /*retTy=*/"::llvm::SmallVector<OpOperand *>",
      /*methodName=*/"getOpOperandsMatchingBBargs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        ::llvm::SmallVector<OpOperand *> result;
        result.reserve($_op->getNumOperands());
        llvm::transform(
          this->getOperation()->getOpOperands(),
          std::back_inserter(result),
          [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Given a dimension of the iteration space of a Linalg operation, finds an
        operand in the operation that is defined on such dimension. Returns
        whether such operand was found or not. If found, also returns the
        operand value and the dimension position within the operand.
      }],
      /*retTy=*/"LogicalResult",
      /*methodName=*/"mapIterationSpaceDimToOperandDim",
      /*args=*/(ins "unsigned":$dimPos,
                    "::mlir::Value &":$operand,
                    "unsigned &":$operandDimPos),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Retrieve the operand and its dimension position from the first
        // operand with a permutation map that is defined on such dimension.
        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
          if (idxMap.isProjectedPermutation()) {
            if (auto mayOperandDim = idxMap.getResultPosition(
                getAffineDimExpr(dimPos, idxMap.getContext()))) {
              operand = $_op->getOperand(i);
              operandDimPos = *mayOperandDim;
              return success();
            }
          }
        }

        return failure();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Given a dimension of the iteration space of a Linalg operation, finds
        all the operands in the operation that are defined on such dimension.
        Returns all the operand values found and their dimension positions in
        `operandDimPairs`.
      }],
      /*retTy=*/"void",
      /*methodName=*/"mapIterationSpaceDimToAllOperandDims",
      /*args=*/(ins "unsigned":$dimPos,
                    "mlir::SmallVectorImpl<std::pair<Value, unsigned>>&":$operandDimPairs),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Only operands with projected-permutation maps are considered, as
        // for `mapIterationSpaceDimToOperandDim`.
        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
          if (idxMap.isProjectedPermutation()) {
            if (auto mayOperandDim = idxMap.getResultPosition(
                getAffineDimExpr(dimPos, idxMap.getContext()))) {
              operandDimPairs.push_back({$_op->getOperand(i), *mayOperandDim});
            }
          }
        }
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the user has supplied explicit indexing maps for this op.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasUserDefinedMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{ return false; }]
    >,
    //===------------------------------------------------------------------===//
    // Linalg generalization hooks.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to compute all the operand
        subshapes given loop bounds. This is used to answer the question: "given
        an iteration space over the codomain, what are the subshapes of the
        operands involved in the computation".
        The default behavior is to just concatenate all the indexing maps.
        A custom AffineMap allows providing a map that can be used to
        compute subshapes even in cases where the concatenation of indexing maps
        (i.e. the data traversal order) is not a simple permutation of the loop
        traversal order. It is then possible to define ops with skewed data
        traversal order for which we can still easily compute hyperrectangular
        loop bounds and subviews.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getLoopsToShapesMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto maps =  $_op.getIndexingMapsArray();
        return concatAffineMaps(maps, $_op.getContext());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Hook to provide a custom AffineMap used to construct the
        hyperrectangular loop iteration space given all the operand subshapes.
        This is used to answer the question:
        "Given a list of operand ranges, what is the subportion of the iteration
        space involved in the computation".
        This is the inverse problem of `getLoopsToShapesMap`.
        Return the empty AffineMap when such an AffineMap cannot be constructed.
        The default behavior is based on a very simple inference procedure that
        only works with permutation affine maps.
        A more advanced Tensor-Comprehension like inference is possible but has
        proven to be ambiguous in unfavorable cases.
        A safer and more robust alternative is to allow each op to define
        its own AffineMap.
      }],
      /*retTy=*/"AffineMap",
      /*methodName=*/"getShapesToLoopsMap",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return inversePermutation(getLoopsToShapesMap());
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Checks if the given operands can be dropped, and the remaining
        operands can still compute the bounds of the op.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"canOpOperandsBeDropped",
      /*args=*/(ins "ArrayRef<OpOperand *>":$droppedOperands),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Like `getShape`, but only returns statically-known information, without
        generating any new IR. For each shape dimension, returns >=0 if that
        dimension is statically known, or ShapedType::kDynamic otherwise.
      }],
      /*retTy=*/"SmallVector<int64_t>",
      /*methodName=*/"getStaticShape",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        // Concatenate the shapes of all operands, in operand order.
        SmallVector<int64_t> res;
        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
          llvm::append_range(res, getShape(&opOperand));
        return res;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns the statically-known loop ranges. Composes
        `getShapesToLoopsMap()` with the result of `getStaticShape`.
        Returns ShapedType::kDynamic for non-statically-known loop ranges.
        This is expected to be called by a valid Linalg op.
      }],
      /*retTy=*/"SmallVector<int64_t, 4>",
      /*methodName=*/"getStaticLoopRanges",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<int64_t> viewSizes = getStaticShape();
        AffineMap invertedMap = getShapesToLoopsMap();
        assert(invertedMap && "expected a valid Linalg op to call the method");
        return invertedMap.compose(viewSizes);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other static interface methods.
    //===------------------------------------------------------------------===//
    StaticInterfaceMethod<
      /*desc=*/[{
        Returns the region builder for constructing the body for linalg.generic.
        Returns a null function if this named op does not define a region
        builder.
      }],
      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
      /*methodName=*/"getRegionBuilder",
      /*args=*/(ins),
      /*methodBody=*/[{ return ConcreteOp::getRegionBuilder(); }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if all the indexing maps are projected permutations.
        Otherwise return false.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasOnlyProjectedPermutations",
      /*args=*/(ins),
      /*methodBody=*/[{
        return llvm::all_of($_op.getIndexingMapsArray(),
                            [](AffineMap map) { return map.isProjectedPermutation(); });
      }]
    >
  ];

  let extraClassDeclaration = [{
    /// Return the flat list of all operand dimension sizes in the order they
    /// appear in the operands.
    SmallVector<OpFoldResult> createFlatListOfOperandDims(OpBuilder &, Location);

    /// Return the flat list of all operands' static dimension sizes in the
    /// order they appear in the operands. All operand dimension sizes have to
    /// be statically known.
    SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();

    /// Create the loop ranges to materialize the computation over the current
    /// operands. This is done by applying `getShapesToLoopsMap` to
    /// `createFlatListOfOperandDims`.
    SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);

    /// Compute the static loop sizes necessary to vectorize the computation.
    /// This is done by applying `getShapesToLoopsMap` to
    /// `createFlatListOfOperandStaticDims`.
    SmallVector<int64_t, 4> computeStaticLoopSizes();

    /// Returns the value that expresses the shape of the output in terms of
    /// shape of the input operands where possible.
    LogicalResult reifyResultShapes(OpBuilder &b,
        ReifiedRankedShapedTypeDims &reifiedReturnShapes);

    /// Return the index in the indexingMaps vector that corresponds to this
    /// `opOperand`.
    int64_t getIndexingMapIndex(OpOperand *opOperand);
  }];

  let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
  let verifyWithRegions = 1;
}

def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
  let description = [{
    Interface for decomposing aggregated operations into a sequence of simpler
    ops.
  }];
  let cppNamespace = "::mlir::linalg";
  let methods = [
      InterfaceMethod<
        /*desc=*/[{
          Method to decompose the operation into simpler operations.

          On success, this method returns one `Value` per result in the
          original operation.
          The order of the returned values must match the order of the
          original values.
          In other words, the returned vector can be used directly with
          `RewriterBase::replaceOp(this, returnedValues)`.
        }],
        /*retTy=*/"FailureOr<SmallVector<Value>>",
        /*methodName=*/"decomposeOperation",
        /*args=*/(ins
            "OpBuilder &":$b),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          // A default-constructed FailureOr signals failure, i.e. the op is
          // not decomposed unless it provides its own implementation.
          return {};
        }]
      >
  ];
}

#endif // LINALG_IR_LINALGINTERFACES