1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Arith/IR/Arith.h" 11 #include "mlir/Dialect/Arith/Utils/Utils.h" 12 #include "mlir/Dialect/Complex/IR/Complex.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/Dialect/Utils/IndexingUtils.h" 15 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 16 #include "mlir/Dialect/Utils/StaticValueUtils.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/BuiltinAttributeInterfaces.h" 19 #include "mlir/IR/BuiltinTypeInterfaces.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/IRMapping.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/OpDefinition.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Support/LLVM.h" 28 #include "llvm/ADT/DenseSet.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/ADT/SmallBitVector.h" 31 #include "llvm/ADT/StringRef.h" 32 #include "llvm/Support/MathExtras.h" 33 #include <algorithm> 34 #include <optional> 35 36 using namespace mlir; 37 using namespace mlir::tensor; 38 39 using llvm::divideCeilSigned; 40 using llvm::divideFloorSigned; 41 using llvm::mod; 42 43 /// Materialize a single constant operation from a given attribute value with 44 /// the desired resultant type. 45 Operation *TensorDialect::materializeConstant(OpBuilder &builder, 46 Attribute value, Type type, 47 Location loc) { 48 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) 49 return op; 50 if (complex::ConstantOp::isBuildableWith(value, type)) 51 return builder.create<complex::ConstantOp>(loc, type, 52 llvm::cast<ArrayAttr>(value)); 53 return nullptr; 54 } 55 56 OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value, 57 int64_t dim) { 58 auto tensorType = llvm::cast<RankedTensorType>(value.getType()); 59 SmallVector<OpFoldResult> result; 60 if (tensorType.isDynamicDim(dim)) 61 return builder.createOrFold<tensor::DimOp>(loc, value, dim); 62 63 return builder.getIndexAttr(tensorType.getDimSize(dim)); 64 } 65 66 SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder, 67 Location loc, Value value) { 68 auto tensorType = llvm::cast<RankedTensorType>(value.getType()); 69 SmallVector<OpFoldResult> result; 70 for (int64_t i = 0; i < tensorType.getRank(); ++i) 71 result.push_back(getMixedSize(builder, loc, value, i)); 72 return result; 73 } 74 75 FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc, 76 OpResult opResult) { 77 auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType()); 78 assert(tensorType && "expected tensor type"); 79 80 // If the op has a destination, it implements DestinationStyleOpInterface and 81 // we can query the destination operand from that interface. 82 auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>(); 83 if (destOp) 84 return destOp.getTiedOpOperand(opResult)->get(); 85 86 // Otherwise, create a new destination tensor with the same shape. 87 OpBuilder::InsertionGuard g(b); 88 b.setInsertionPoint(opResult.getDefiningOp()); 89 90 // Compute sizes. 
91 SmallVector<OpFoldResult> mixedSizes; 92 if (!tensorType.hasStaticShape()) { 93 // Dynamic shape: Query ReifyRankedShapedTypeOpInterface. 94 ReifiedRankedShapedTypeDims reifiedShapes; 95 if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes))) 96 return failure(); 97 mixedSizes = reifiedShapes[opResult.getResultNumber()]; 98 } else { 99 // Static shape: Take static sizes directly. 100 for (int64_t sz : tensorType.getShape()) 101 mixedSizes.push_back(b.getIndexAttr(sz)); 102 } 103 104 // Create empty tensor. 105 Value emptyTensor = 106 b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType()); 107 return emptyTensor; 108 } 109 110 LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc, 111 Operation *op, 112 SmallVector<Value> &result) { 113 for (OpResult opResult : op->getResults()) { 114 if (llvm::isa<TensorType>(opResult.getType())) { 115 FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult); 116 if (failed(destination)) 117 return failure(); 118 result.push_back(*destination); 119 } 120 } 121 return success(); 122 } 123 124 bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) { 125 if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) { 126 if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2)) 127 return rtp1.getShape() == rtp2.getShape() && 128 rtp1.getElementType() == rtp2.getElementType(); 129 return false; 130 } 131 return tp1 == tp2; // default implementation 132 } 133 134 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or 135 /// rank-extending tensor.insert_slice op. 136 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape, 137 ArrayRef<OpFoldResult> mixedSizes) { 138 llvm::SmallBitVector droppedDims(mixedSizes.size()); 139 int64_t shapePos = reducedShape.size() - 1; 140 141 for (const auto &size : enumerate(llvm::reverse(mixedSizes))) { 142 size_t idx = mixedSizes.size() - size.index() - 1; 143 // Rank-reduced dims must have a static unit dimension. 144 bool isStaticUnitSize = 145 isa<Attribute>(size.value()) && 146 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1; 147 148 if (shapePos < 0) { 149 // There are no more dims in the reduced shape. All remaining sizes must 150 // be rank-reduced dims. 151 assert(isStaticUnitSize && "expected unit dim"); 152 droppedDims.set(idx); 153 continue; 154 } 155 156 // Dim is preserved if the size is not a static 1. 157 if (!isStaticUnitSize) { 158 --shapePos; 159 continue; 160 } 161 162 // Dim is preserved if the reduced shape dim is also 1. 163 if (reducedShape[shapePos] == 1) { 164 --shapePos; 165 continue; 166 } 167 168 // Otherwise: Dim is dropped. 169 droppedDims.set(idx); 170 } 171 172 assert(shapePos < 0 && "dimension mismatch"); 173 return droppedDims; 174 } 175 176 /// Given a ranked tensor type and a range of values that defines its dynamic 177 /// dimension sizes, turn all dynamic sizes that have a constant value into 178 /// static dimension sizes. 179 static RankedTensorType 180 foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, 181 SmallVector<Value> &foldedDynamicSizes) { 182 SmallVector<int64_t> staticShape(type.getShape()); 183 assert(type.getNumDynamicDims() == dynamicSizes.size() && 184 "incorrect number of dynamic sizes"); 185 186 // Compute new static and dynamic sizes. 
187 unsigned ctr = 0; 188 for (int64_t i = 0, e = type.getRank(); i < e; ++i) { 189 if (type.isDynamicDim(i)) { 190 Value dynamicSize = dynamicSizes[ctr++]; 191 std::optional<int64_t> cst = getConstantIntValue(dynamicSize); 192 if (cst.has_value()) { 193 // Dynamic size must be non-negative. 194 if (cst.value() < 0) { 195 foldedDynamicSizes.push_back(dynamicSize); 196 continue; 197 } 198 staticShape[i] = *cst; 199 } else { 200 foldedDynamicSizes.push_back(dynamicSize); 201 } 202 } 203 } 204 205 return RankedTensorType::get(staticShape, type.getElementType(), 206 type.getEncoding()); 207 } 208 209 //===----------------------------------------------------------------------===// 210 // BitcastOp 211 //===----------------------------------------------------------------------===// 212 213 bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 214 if (inputs.size() != 1 || outputs.size() != 1) 215 return false; 216 Type a = inputs.front(), b = outputs.front(); 217 auto aT = dyn_cast<TensorType>(a); 218 auto bT = dyn_cast<TensorType>(b); 219 if (!aT || !bT) 220 return false; 221 222 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth()) 223 return false; 224 225 return succeeded(verifyCompatibleShape(aT, bT)); 226 } 227 228 namespace { 229 230 /// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast 231 /// operation. 232 struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> { 233 using OpRewritePattern<BitcastOp>::OpRewritePattern; 234 235 LogicalResult matchAndRewrite(BitcastOp tensorBitcast, 236 PatternRewriter &rewriter) const final { 237 auto tensorBitcastOperand = 238 tensorBitcast.getOperand().getDefiningOp<BitcastOp>(); 239 if (!tensorBitcastOperand) 240 return failure(); 241 242 auto resultType = cast<TensorType>(tensorBitcast.getType()); 243 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType, 244 tensorBitcastOperand.getOperand()); 245 return success(); 246 } 247 }; 248 249 } // namespace 250 251 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, 252 MLIRContext *context) { 253 results.add<ChainedTensorBitcast>(context); 254 } 255 256 //===----------------------------------------------------------------------===// 257 // CastOp 258 //===----------------------------------------------------------------------===// 259 260 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 261 setNameFn(getResult(), "cast"); 262 } 263 264 /// Returns true if `target` is a ranked tensor type that preserves static 265 /// information available in the `source` ranked tensor type. 266 bool mlir::tensor::preservesStaticInformation(Type source, Type target) { 267 auto sourceType = llvm::dyn_cast<RankedTensorType>(source); 268 auto targetType = llvm::dyn_cast<RankedTensorType>(target); 269 270 // Requires RankedTensorType. 271 if (!sourceType || !targetType) 272 return false; 273 274 // Requires same elemental type. 275 if (sourceType.getElementType() != targetType.getElementType()) 276 return false; 277 278 // Requires same rank. 279 if (sourceType.getRank() != targetType.getRank()) 280 return false; 281 282 // Requires same encoding. 283 if (sourceType.getEncoding() != targetType.getEncoding()) 284 return false; 285 286 // If cast is towards more static sizes along any dimension, don't fold. 
287 for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { 288 if (!ShapedType::isDynamic(std::get<0>(t)) && 289 ShapedType::isDynamic(std::get<1>(t))) 290 return false; 291 } 292 293 return true; 294 } 295 296 /// Determines whether tensor::CastOp casts to a more dynamic version of the 297 /// source tensor. This is useful to fold a tensor.cast into a consuming op and 298 /// implement canonicalization patterns for ops in different dialects that may 299 /// consume the results of tensor.cast operations. Such foldable tensor.cast 300 /// operations are typically inserted as `slice` ops and are canonicalized, 301 /// to preserve the type compatibility of their uses. 302 /// 303 /// Returns true when all conditions are met: 304 /// 1. source and result are ranked tensors with same element type and rank. 305 /// 2. the tensor type has more static information than the result 306 /// 307 /// Example: 308 /// ```mlir 309 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 310 /// %2 = consumer %1 ... : tensor<?x?xf32> ... 311 /// ``` 312 /// 313 /// folds into: 314 /// 315 /// ```mlir 316 /// %2 = consumer %0 ... : tensor<8x16xf32> ... 317 /// ``` 318 bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) { 319 if (!castOp) 320 return false; 321 322 // Can fold if the source of cast has at least as much static information as 323 // its results. 324 return preservesStaticInformation(castOp.getType(), 325 castOp.getSource().getType()); 326 } 327 328 /// Determines whether the tensor::CastOp casts to a more static version of the 329 /// source tensor. This is useful to fold into a producing op and implement 330 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer 331 /// being from different dialects. Returns true when all conditions are met: 332 /// 1. source and result and ranked tensors with same element type and rank. 333 /// 2. the result type has more static information than the source. 334 /// 335 /// Example: 336 /// ```mlir 337 /// %1 = producer ... : tensor<?x?xf32> 338 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32> 339 /// ``` 340 /// 341 /// can be canonicalized to : 342 /// 343 /// ```mlir 344 /// %2 = producer ... : tensor<8x16xf32> 345 /// ``` 346 /// Not all ops might be canonicalizable this way, but for those that can be, 347 /// this method provides a check that it is worth doing the canonicalization. 348 bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) { 349 if (!castOp) 350 return false; 351 return preservesStaticInformation(castOp.getSource().getType(), 352 castOp.getType()); 353 } 354 355 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp 356 /// that can be folded. 
357 LogicalResult mlir::tensor::foldTensorCast(Operation *op) { 358 bool folded = false; 359 for (OpOperand &operand : op->getOpOperands()) { 360 auto castOp = operand.get().getDefiningOp<tensor::CastOp>(); 361 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { 362 operand.set(castOp.getOperand()); 363 folded = true; 364 } 365 } 366 return success(folded); 367 } 368 369 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { 370 if (inputs.size() != 1 || outputs.size() != 1) 371 return false; 372 Type a = inputs.front(), b = outputs.front(); 373 auto aT = llvm::dyn_cast<TensorType>(a); 374 auto bT = llvm::dyn_cast<TensorType>(b); 375 if (!aT || !bT) 376 return false; 377 378 if (aT.getElementType() != bT.getElementType()) 379 return false; 380 381 return succeeded(verifyCompatibleShape(aT, bT)); 382 } 383 384 /// Compute a TensorType that has the joined shape knowledge of the two 385 /// given TensorTypes. The element types need to match. 386 static TensorType joinShapes(TensorType one, TensorType two) { 387 assert(one.getElementType() == two.getElementType()); 388 389 if (!one.hasRank()) 390 return two; 391 if (!two.hasRank()) 392 return one; 393 394 int64_t rank = one.getRank(); 395 if (rank != two.getRank()) 396 return {}; 397 398 SmallVector<int64_t, 4> join; 399 join.reserve(rank); 400 for (int64_t i = 0; i < rank; ++i) { 401 if (one.isDynamicDim(i)) { 402 join.push_back(two.getDimSize(i)); 403 continue; 404 } 405 if (two.isDynamicDim(i)) { 406 join.push_back(one.getDimSize(i)); 407 continue; 408 } 409 if (one.getDimSize(i) != two.getDimSize(i)) 410 return {}; 411 join.push_back(one.getDimSize(i)); 412 } 413 return RankedTensorType::get(join, one.getElementType()); 414 } 415 416 namespace { 417 418 /// Replaces chains of two tensor.cast operations by a single tensor.cast 419 /// operation if doing so does not remove runtime constraints. 420 struct ChainedTensorCast : public OpRewritePattern<CastOp> { 421 using OpRewritePattern<CastOp>::OpRewritePattern; 422 423 LogicalResult matchAndRewrite(CastOp tensorCast, 424 PatternRewriter &rewriter) const final { 425 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>(); 426 427 if (!tensorCastOperand) 428 return failure(); 429 430 auto sourceType = 431 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType()); 432 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType()); 433 auto resultType = llvm::cast<TensorType>(tensorCast.getType()); 434 435 // We can remove the intermediate cast if joining all three produces the 436 // same result as just joining the source and result shapes. 437 auto firstJoin = 438 joinShapes(joinShapes(sourceType, intermediateType), resultType); 439 440 // The join might not exist if the cast sequence would fail at runtime. 441 if (!firstJoin) 442 return failure(); 443 444 // The newJoin always exists if the above join exists, it might just contain 445 // less information. If so, we cannot drop the intermediate cast, as doing 446 // so would remove runtime checks. 447 auto newJoin = joinShapes(sourceType, resultType); 448 if (firstJoin != newJoin) 449 return failure(); 450 451 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType, 452 tensorCastOperand.getOperand()); 453 return success(); 454 } 455 }; 456 457 /// Fold tensor.cast into tesor.extract_slice producer. 
458 /// Example: 459 /// ``` 460 /// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : 461 /// tensor<128x512xf32> to tensor<?x512xf32> 462 /// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32> 463 /// ``` 464 /// -> 465 /// ``` 466 /// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] : 467 /// tensor<128x512xf32> to tensor<16x512xf32> 468 /// ``` 469 struct TensorCastExtractSlice : public OpRewritePattern<CastOp> { 470 using OpRewritePattern<CastOp>::OpRewritePattern; 471 472 LogicalResult matchAndRewrite(CastOp tensorCast, 473 PatternRewriter &rewriter) const final { 474 auto extractOperand = 475 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>(); 476 477 // Cannot fold cast to unranked tensor. 478 auto rankedResultType = 479 llvm::dyn_cast<RankedTensorType>(tensorCast.getType()); 480 if (!rankedResultType) 481 return failure(); 482 483 if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || 484 rankedResultType.getShape() == 485 llvm::cast<RankedTensorType>(tensorCast.getSource().getType()) 486 .getShape()) 487 return failure(); 488 489 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes(); 490 auto dimMask = computeRankReductionMask( 491 extractOperand.getStaticSizes(), extractOperand.getType().getShape()); 492 size_t dimIndex = 0; 493 for (size_t i = 0, e = sizes.size(); i < e; i++) { 494 if (dimMask && dimMask->count(i)) 495 continue; 496 int64_t dim = rankedResultType.getShape()[dimIndex++]; 497 if (ShapedType::isDynamic(dim)) 498 continue; 499 sizes[i] = rewriter.getIndexAttr(dim); 500 } 501 502 rewriter.replaceOpWithNewOp<ExtractSliceOp>( 503 tensorCast, rankedResultType, extractOperand.getSource(), 504 extractOperand.getMixedOffsets(), sizes, 505 extractOperand.getMixedStrides()); 506 return success(); 507 } 508 }; 509 510 } // namespace 511 512 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, 513 MLIRContext *context) { 514 results.add<ChainedTensorCast, TensorCastExtractSlice>(context); 515 } 516 517 //===----------------------------------------------------------------------===// 518 // ConcatOp 519 //===----------------------------------------------------------------------===// 520 521 RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) { 522 assert(!inputTypes.empty() && "cannot concatenate 0 tensors"); 523 auto tensorTypes = 524 llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) { 525 return llvm::cast<RankedTensorType>(type); 526 })); 527 int64_t concatRank = tensorTypes[0].getRank(); 528 529 // The concatenation dim must be in the range [0, rank). 
530 assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim"); 531 532 SmallVector<int64_t> sizes(concatRank); 533 for (int64_t i = 0, e = concatRank; i < e; ++i) { 534 if (i == dim) 535 continue; 536 SaturatedInteger size; 537 for (auto tensorType : tensorTypes) 538 size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i))); 539 sizes[i] = size.asInteger(); 540 } 541 auto concatSize = SaturatedInteger::wrap(0); 542 for (auto tensorType : tensorTypes) 543 concatSize = 544 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim)); 545 sizes[dim] = concatSize.asInteger(); 546 return RankedTensorType::get(sizes, tensorTypes[0].getElementType()); 547 } 548 549 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim, 550 ValueRange inputs) { 551 FailureOr<RankedTensorType> resultType = 552 inferResultType(dim, inputs.getTypes()); 553 assert(succeeded(resultType) && "failed to infer concatenation result type"); 554 build(builder, result, *resultType, dim, inputs); 555 } 556 557 LogicalResult ConcatOp::verify() { 558 if (getInputs().size() < 1) 559 return emitOpError("requires at least one input"); 560 561 SmallVector<RankedTensorType> inputTypes; 562 for (auto input : getInputs()) 563 inputTypes.push_back(cast<RankedTensorType>(input.getType())); 564 565 RankedTensorType resultType = getResultType(); 566 int64_t resultRank = getRank(); 567 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) { 568 return type.getRank() != resultRank; 569 })) 570 return emitOpError("rank of concatenated inputs must match result rank"); 571 572 Type resultElementType = resultType.getElementType(); 573 if (llvm::any_of(inputTypes, [&](RankedTensorType type) { 574 return type.getElementType() != resultElementType; 575 })) 576 return emitOpError("inputs and result element type must match"); 577 578 int64_t dim = getDim(); 579 if (dim >= resultRank) 580 return emitOpError("concatenation dim must be less than the tensor rank"); 581 582 SmallVector<int64_t> sizes(resultRank); 583 for (int64_t i = 0, e = resultRank; i < e; ++i) { 584 if (i == dim) 585 continue; 586 SaturatedInteger size; 587 for (auto tensorType : inputTypes) { 588 FailureOr<SaturatedInteger> maybeSize = 589 size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i))); 590 if (failed(maybeSize)) 591 return emitOpError("static concatenation size mismatch along ") 592 << "non-concatenated dimension " << i; 593 size = *maybeSize; 594 } 595 sizes[i] = size.asInteger(); 596 } 597 auto concatSize = SaturatedInteger::wrap(0); 598 for (auto tensorType : inputTypes) 599 concatSize = 600 concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim)); 601 sizes[dim] = concatSize.asInteger(); 602 auto inferredResultType = 603 RankedTensorType::get(sizes, inputTypes[0].getElementType()); 604 605 for (auto [inferredSize, actualSize] : 606 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) { 607 bool hasDynamic = ShapedType::isDynamic(inferredSize) || 608 ShapedType::isDynamic(actualSize); 609 if (!hasDynamic && inferredSize != actualSize) 610 return emitOpError("result type ") 611 << resultType << "does not match inferred shape " 612 << inferredResultType << " static sizes"; 613 } 614 615 return success(); 616 } 617 618 FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) { 619 size_t numInputs = getInputs().size(); 620 uint64_t concatDim = getDim(); 621 622 SmallVector<SmallVector<OpFoldResult>> inputShapes; 623 inputShapes.reserve(numInputs); 624 
SmallVector<OpFoldResult> concatOffsets; 625 concatOffsets.reserve(numInputs); 626 SmallVector<OpFoldResult> outputShape; 627 628 AffineExpr addExpr = 629 builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1); 630 OpFoldResult zero = builder.getIndexAttr(0); 631 Location loc = getLoc(); 632 for (auto [index, input] : llvm::enumerate(getInputs())) { 633 SmallVector<OpFoldResult> inputShape = 634 tensor::getMixedSizes(builder, input.getLoc(), input); 635 if (index == 0) { 636 outputShape = inputShape; 637 concatOffsets.push_back(zero); 638 } else { 639 concatOffsets.push_back(outputShape[concatDim]); 640 outputShape[concatDim] = affine::makeComposedFoldedAffineApply( 641 builder, loc, addExpr, 642 {outputShape[concatDim], inputShape[concatDim]}); 643 } 644 inputShapes.emplace_back(std::move(inputShape)); 645 } 646 647 Value replacement = builder.create<tensor::EmptyOp>( 648 loc, outputShape, getType().getElementType()); 649 650 int64_t rank = getType().getRank(); 651 OpFoldResult one = builder.getIndexAttr(1); 652 SmallVector<OpFoldResult> strides(rank, one); 653 SmallVector<OpFoldResult> offsets(rank, zero); 654 for (auto [index, input] : llvm::enumerate(getInputs())) { 655 offsets[concatDim] = concatOffsets[index]; 656 auto insertSlice = builder.create<tensor::InsertSliceOp>( 657 loc, input, replacement, offsets, inputShapes[index], strides); 658 replacement = insertSlice.getResult(); 659 } 660 if (replacement.getType() != getType()) { 661 replacement = builder.create<tensor::CastOp>(loc, getType(), replacement); 662 } 663 return SmallVector<Value>{replacement}; 664 } 665 666 LogicalResult 667 ConcatOp::reifyResultShapes(OpBuilder &builder, 668 ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 669 ValueRange inputs = getInputs(); 670 int64_t dim = getDim(); 671 RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes()); 672 673 Value init = inputs[0]; 674 int64_t rank = getType().getRank(); 675 676 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank)); 677 678 // Pre-populate the result sizes with as much static information as possible 679 // from the given result type, as well as the inferred result type, otherwise 680 // use the dim sizes from the first input. 681 for (int64_t i = 0; i < rank; ++i) { 682 if (i == dim) 683 continue; 684 if (!getType().isDynamicDim(i)) { 685 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i)); 686 } else if (!inferredResultType.isDynamicDim(i)) { 687 reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp( 688 builder, getLoc(), 689 builder.getIndexAttr(inferredResultType.getDimSize(i))); 690 } else { 691 reifiedReturnShapes[0][i] = 692 builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult(); 693 } 694 } 695 696 if (getType().isDynamicDim(dim)) { 697 // Take the sum of the input sizes along the concatenated dim. 698 AffineExpr sum = builder.getAffineDimExpr(0); 699 SmallVector<OpFoldResult> sizes = { 700 builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)}; 701 for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) { 702 sum = sum + builder.getAffineDimExpr(idx + 1); 703 sizes.push_back( 704 builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim)); 705 } 706 reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp( 707 builder, getLoc(), 708 affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes)); 709 } else { 710 // If the result shape is static along the concatenated dim, use the static 711 // shape. 
712 reifiedReturnShapes[0][dim] = 713 builder.getIndexAttr(getType().getDimSize(dim)); 714 } 715 return success(); 716 } 717 718 void ConcatOp::getAsmResultNames( 719 function_ref<void(Value, StringRef)> setNameFn) { 720 setNameFn(getResult(), "concat"); 721 } 722 723 OpFoldResult ConcatOp::fold(FoldAdaptor) { 724 ValueRange inputs = getInputs(); 725 if (inputs.size() == 1 && inputs[0].getType() == getResultType()) 726 return inputs[0]; 727 return {}; 728 } 729 730 namespace { 731 /// Fold a concat op with a single input to a cast. 732 struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> { 733 using OpRewritePattern<ConcatOp>::OpRewritePattern; 734 735 LogicalResult matchAndRewrite(ConcatOp concatOp, 736 PatternRewriter &rewriter) const override { 737 if (concatOp.getInputs().size() != 1) 738 return failure(); 739 rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(), 740 concatOp.getInputs()[0]); 741 return success(); 742 } 743 }; 744 } // namespace 745 746 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, 747 MLIRContext *context) { 748 results.add<SingleInputConcatOp>(context); 749 } 750 751 //===----------------------------------------------------------------------===// 752 // DimOp 753 //===----------------------------------------------------------------------===// 754 755 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 756 setNameFn(getResult(), "dim"); 757 } 758 759 void DimOp::build(OpBuilder &builder, OperationState &result, Value source, 760 int64_t index) { 761 auto loc = result.location; 762 Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); 763 build(builder, result, source, indexValue); 764 } 765 766 std::optional<int64_t> DimOp::getConstantIndex() { 767 return getConstantIntValue(getIndex()); 768 } 769 770 Speculation::Speculatability DimOp::getSpeculatability() { 771 auto constantIndex = getConstantIndex(); 772 if (!constantIndex) 773 return Speculation::NotSpeculatable; 774 775 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType()); 776 if (!rankedSourceType) 777 return Speculation::NotSpeculatable; 778 779 if (rankedSourceType.getRank() <= constantIndex) 780 return Speculation::NotSpeculatable; 781 782 return Speculation::Speculatable; 783 } 784 785 OpFoldResult DimOp::fold(FoldAdaptor adaptor) { 786 // All forms of folding require a known index. 787 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); 788 if (!index) 789 return {}; 790 791 // Folding for unranked types (UnrankedTensorType) is not supported. 792 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType()); 793 if (!tensorType) 794 return {}; 795 796 // Out of bound indices produce undefined behavior but are still valid IR. 797 // Don't choke on them. 798 int64_t indexVal = index.getInt(); 799 if (indexVal < 0 || indexVal >= tensorType.getRank()) 800 return {}; 801 802 // Fold if the shape extent along the given index is known. 803 if (!tensorType.isDynamicDim(index.getInt())) { 804 Builder builder(getContext()); 805 return builder.getIndexAttr(tensorType.getShape()[index.getInt()]); 806 } 807 808 Operation *definingOp = getSource().getDefiningOp(); 809 810 // Fold dim to the operand of tensor.generate. 
811 if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) { 812 auto resultType = 813 llvm::cast<RankedTensorType>(fromElements.getResult().getType()); 814 // The case where the type encodes the size of the dimension is handled 815 // above. 816 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()])); 817 818 // Find the operand of the fromElements that corresponds to this index. 819 auto dynExtents = fromElements.getDynamicExtents().begin(); 820 for (auto dim : resultType.getShape().take_front(index.getInt())) 821 if (ShapedType::isDynamic(dim)) 822 dynExtents++; 823 824 return Value{*dynExtents}; 825 } 826 827 // The size at the given index is now known to be a dynamic size. 828 unsigned unsignedIndex = index.getValue().getZExtValue(); 829 830 if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) { 831 // Fold only for non-rank reduced ops. For the rank-reduced version, rely on 832 // `resolve-shaped-type-result-dims` pass. 833 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() && 834 sliceOp.isDynamicSize(unsignedIndex)) { 835 return {sliceOp.getDynamicSize(unsignedIndex)}; 836 } 837 } 838 839 // dim(cast) -> dim 840 if (succeeded(foldTensorCast(*this))) 841 return getResult(); 842 843 return {}; 844 } 845 846 namespace { 847 /// Fold dim of a cast into the dim of the source of the tensor cast. 848 struct DimOfCastOp : public OpRewritePattern<DimOp> { 849 using OpRewritePattern<DimOp>::OpRewritePattern; 850 851 LogicalResult matchAndRewrite(DimOp dimOp, 852 PatternRewriter &rewriter) const override { 853 auto castOp = dimOp.getSource().getDefiningOp<CastOp>(); 854 if (!castOp) 855 return failure(); 856 Value newSource = castOp.getOperand(); 857 rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex()); 858 return success(); 859 } 860 }; 861 862 /// Fold dim of a destination passing style op into the dim of the corresponding 863 /// init. 864 struct DimOfDestStyleOp : public OpRewritePattern<DimOp> { 865 using OpRewritePattern<DimOp>::OpRewritePattern; 866 867 LogicalResult matchAndRewrite(DimOp dimOp, 868 PatternRewriter &rewriter) const override { 869 auto source = dimOp.getSource(); 870 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>(); 871 if (!destOp) 872 return failure(); 873 874 auto resultIndex = cast<OpResult>(source).getResultNumber(); 875 auto *initOperand = destOp.getDpsInitOperand(resultIndex); 876 877 rewriter.modifyOpInPlace( 878 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); }); 879 return success(); 880 } 881 }; 882 883 /// Fold dim of a tensor reshape operation to a extract into the reshape's shape 884 /// operand. 
885 struct DimOfReshapeOp : public OpRewritePattern<DimOp> { 886 using OpRewritePattern<DimOp>::OpRewritePattern; 887 888 LogicalResult matchAndRewrite(DimOp dim, 889 PatternRewriter &rewriter) const override { 890 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>(); 891 892 if (!reshape) 893 return failure(); 894 895 // Since tensors are immutable we don't need to worry about where to place 896 // the extract call 897 rewriter.setInsertionPointAfter(dim); 898 Location loc = dim.getLoc(); 899 Value extract = 900 rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex()); 901 if (extract.getType() != dim.getType()) 902 extract = 903 rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract); 904 rewriter.replaceOp(dim, extract); 905 return success(); 906 } 907 }; 908 } // namespace 909 910 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, 911 MLIRContext *context) { 912 results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context); 913 } 914 915 //===----------------------------------------------------------------------===// 916 // EmptyOp 917 //===----------------------------------------------------------------------===// 918 919 void EmptyOp::build(OpBuilder &builder, OperationState &result, 920 ArrayRef<int64_t> staticShape, Type elementType, 921 Attribute encoding) { 922 assert(all_of(staticShape, 923 [](int64_t sz) { return !ShapedType::isDynamic(sz); }) && 924 "expected only static sizes"); 925 build(builder, result, staticShape, elementType, ValueRange{}, encoding); 926 } 927 928 void EmptyOp::build(OpBuilder &builder, OperationState &result, 929 ArrayRef<int64_t> staticShape, Type elementType, 930 ValueRange dynamicSizes, Attribute encoding) { 931 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding); 932 build(builder, result, tensorType, dynamicSizes); 933 } 934 935 void EmptyOp::build(OpBuilder &builder, OperationState &result, 936 ArrayRef<OpFoldResult> sizes, Type elementType, 937 Attribute encoding) { 938 SmallVector<int64_t> staticShape; 939 SmallVector<Value> dynamicSizes; 940 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape); 941 build(builder, result, staticShape, elementType, dynamicSizes, encoding); 942 } 943 944 LogicalResult EmptyOp::verify() { 945 if (getType().getNumDynamicDims() != getDynamicSizes().size()) 946 return emitOpError("incorrect number of dynamic sizes, has ") 947 << getDynamicSizes().size() << ", expected " 948 << getType().getNumDynamicDims(); 949 return success(); 950 } 951 952 LogicalResult 953 EmptyOp::reifyResultShapes(OpBuilder &builder, 954 ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 955 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank())); 956 unsigned ctr = 0; 957 for (int64_t i = 0; i < getType().getRank(); ++i) { 958 if (getType().isDynamicDim(i)) { 959 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++]; 960 } else { 961 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i)); 962 } 963 } 964 return success(); 965 } 966 967 Value EmptyOp::getDynamicSize(unsigned idx) { 968 assert(getType().isDynamicDim(idx) && "expected dynamic dim"); 969 unsigned ctr = 0; 970 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i) 971 if (getType().isDynamicDim(i)) 972 ++ctr; 973 return getDynamicSizes()[ctr]; 974 } 975 976 SmallVector<OpFoldResult> EmptyOp::getMixedSizes() { 977 SmallVector<OpFoldResult> result; 978 unsigned ctr = 0; 979 OpBuilder b(getContext()); 980 for (int64_t i = 0; i < getType().getRank(); ++i) { 981 if 
(getType().isDynamicDim(i)) { 982 result.push_back(getDynamicSizes()[ctr++]); 983 } else { 984 result.push_back(b.getIndexAttr(getType().getShape()[i])); 985 } 986 } 987 return result; 988 } 989 990 namespace { 991 /// Change the type of the result of a `tensor.empty` by making the result 992 /// type statically sized along dimensions that in the original operation were 993 /// defined as dynamic, but the size was defined using a `constant` op. For 994 /// example 995 /// 996 /// %c5 = arith.constant 5: index 997 /// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32> 998 /// 999 /// to 1000 /// 1001 /// %0 = tensor.empty(%arg0) : tensor<?x5xf32> 1002 struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> { 1003 using OpRewritePattern<EmptyOp>::OpRewritePattern; 1004 1005 LogicalResult matchAndRewrite(EmptyOp op, 1006 PatternRewriter &rewriter) const override { 1007 SmallVector<Value> foldedDynamicSizes; 1008 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes( 1009 op.getType(), op.getDynamicSizes(), foldedDynamicSizes); 1010 1011 // Stop here if no dynamic size was promoted to static. 1012 if (foldedTensorType == op.getType()) 1013 return failure(); 1014 1015 auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType, 1016 foldedDynamicSizes); 1017 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); 1018 return success(); 1019 } 1020 }; 1021 1022 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> { 1023 using OpRewritePattern<DimOp>::OpRewritePattern; 1024 1025 LogicalResult matchAndRewrite(tensor::DimOp dimOp, 1026 PatternRewriter &rewriter) const override { 1027 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex(); 1028 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>(); 1029 if (!emptyTensorOp || !maybeConstantIndex) 1030 return failure(); 1031 auto emptyTensorType = emptyTensorOp.getType(); 1032 if (*maybeConstantIndex < 0 || 1033 *maybeConstantIndex >= emptyTensorType.getRank() || 1034 !emptyTensorType.isDynamicDim(*maybeConstantIndex)) 1035 return failure(); 1036 rewriter.replaceOp(dimOp, 1037 emptyTensorOp.getDynamicSize(*maybeConstantIndex)); 1038 return success(); 1039 } 1040 }; 1041 1042 /// Canonicalize 1043 /// 1044 /// ```mlir 1045 /// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32> 1046 /// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32> 1047 /// ``` 1048 /// 1049 /// into 1050 /// 1051 /// ```mlir 1052 /// %0 = tensor.empty(%d1) : tensor<4x?xf32> 1053 /// ``` 1054 /// 1055 /// This assumes the input program is correct in terms of its shape. So it is 1056 /// safe to assume that `%d0` is in fact 4. 
1057 struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> { 1058 using OpRewritePattern<CastOp>::OpRewritePattern; 1059 1060 LogicalResult matchAndRewrite(CastOp castOp, 1061 PatternRewriter &rewriter) const override { 1062 if (!canFoldIntoProducerOp(castOp)) 1063 return failure(); 1064 auto producer = castOp.getSource().getDefiningOp<EmptyOp>(); 1065 if (!producer) 1066 return failure(); 1067 1068 auto resultType = 1069 llvm::cast<RankedTensorType>(castOp->getResult(0).getType()); 1070 ArrayRef<int64_t> resultShape = resultType.getShape(); 1071 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes(); 1072 SmallVector<OpFoldResult> newMixedSizes; 1073 newMixedSizes.reserve(currMixedSizes.size()); 1074 assert(resultShape.size() == currMixedSizes.size() && 1075 "mismatch in result shape and sizes of empty op"); 1076 for (auto it : llvm::zip(resultShape, currMixedSizes)) { 1077 int64_t newDim = std::get<0>(it); 1078 OpFoldResult currDim = std::get<1>(it); 1079 // Case 1: The empty tensor dim is static. Check that the tensor cast 1080 // result dim matches. 1081 if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) { 1082 if (ShapedType::isDynamic(newDim) || 1083 newDim != llvm::cast<IntegerAttr>(attr).getInt()) { 1084 // Something is off, the cast result shape cannot be more dynamic 1085 // than the empty tensor result shape (enforced by 1086 // `canFoldIntoProducer`). Abort for now. 1087 return rewriter.notifyMatchFailure( 1088 producer, "mismatch in static value of shape of empty tensor " 1089 "result and cast result"); 1090 } 1091 newMixedSizes.push_back(attr); 1092 continue; 1093 } 1094 1095 // Case 2 : The tensor cast shape is static, but empty tensor result 1096 // shape is dynamic. 1097 if (!ShapedType::isDynamic(newDim)) { 1098 newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); 1099 continue; 1100 } 1101 1102 // Case 3 : The tensor cast shape is dynamic and empty tensor result 1103 // shape is dynamic. Use the dynamic value from the empty tensor op. 1104 newMixedSizes.push_back(currDim); 1105 } 1106 1107 // TODO: Do not drop tensor encoding. 1108 rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes, 1109 resultType.getElementType()); 1110 return success(); 1111 } 1112 }; 1113 1114 } // namespace 1115 1116 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results, 1117 MLIRContext *context) { 1118 results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp, 1119 ReplaceEmptyTensorStaticShapeDims>(context); 1120 } 1121 1122 /// Try to remove a tensor operation if it would only reshape a constant. 1123 /// Removes the op and replaces the constant with a new constant of the result 1124 /// shape. When an optional cst attribute is passed, it is reshaped only if the 1125 /// splat value matches the value in the attribute. 
1126 static OpFoldResult 1127 reshapeConstantSource(DenseElementsAttr source, TensorType result, 1128 std::optional<Attribute> cst = std::nullopt) { 1129 if (source && source.isSplat() && result.hasStaticShape() && 1130 (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value())) 1131 return source.resizeSplat(result); 1132 1133 return {}; 1134 } 1135 1136 //===----------------------------------------------------------------------===// 1137 // ExtractOp 1138 //===----------------------------------------------------------------------===// 1139 1140 namespace { 1141 1142 /// Canonicalizes the pattern of the form 1143 /// 1144 /// %val = tensor.cast %source : : tensor<?xi32> to tensor<2xi32> 1145 /// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32> 1146 /// 1147 /// to 1148 /// 1149 /// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32> 1150 struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> { 1151 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 1152 1153 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 1154 PatternRewriter &rewriter) const final { 1155 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>(); 1156 if (!tensorCast) 1157 return failure(); 1158 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType())) 1159 return failure(); 1160 rewriter.replaceOpWithNewOp<tensor::ExtractOp>( 1161 extract, tensorCast.getSource(), extract.getIndices()); 1162 return success(); 1163 } 1164 }; 1165 1166 } // namespace 1167 1168 void ExtractOp::getAsmResultNames( 1169 function_ref<void(Value, StringRef)> setNameFn) { 1170 setNameFn(getResult(), "extracted"); 1171 } 1172 1173 LogicalResult ExtractOp::verify() { 1174 // Verify the # indices match if we have a ranked type. 1175 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType()); 1176 if (tensorType.getRank() != static_cast<int64_t>(getIndices().size())) 1177 return emitOpError("incorrect number of indices for extract_element"); 1178 return success(); 1179 } 1180 1181 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { 1182 if (Attribute tensor = adaptor.getTensor()) { 1183 // If this is a splat elements attribute, simply return the value. 1184 // All of the elements of a splat attribute are the same. 1185 if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor)) 1186 return splatTensor.getSplatValue<Attribute>(); 1187 1188 // If this is a dense resource elements attribute, return. 1189 if (isa<DenseResourceElementsAttr>(tensor)) 1190 return {}; 1191 } 1192 1193 // Collect the constant indices into the tensor. 1194 SmallVector<uint64_t, 8> indices; 1195 for (Attribute indice : adaptor.getIndices()) { 1196 if (!indice || !llvm::isa<IntegerAttr>(indice)) 1197 return {}; 1198 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt()); 1199 } 1200 1201 // Fold extract(from_elements(...)). 1202 if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) { 1203 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType()); 1204 auto rank = tensorType.getRank(); 1205 assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() && 1206 "rank mismatch"); 1207 int flatIndex = 0; 1208 int stride = 1; 1209 for (int i = rank - 1; i >= 0; --i) { 1210 flatIndex += indices[i] * stride; 1211 stride *= tensorType.getDimSize(i); 1212 } 1213 // Prevent out of bounds accesses. This can happen in invalid code that 1214 // will never execute. 
1215 if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex || 1216 flatIndex < 0) 1217 return {}; 1218 return fromElementsOp.getElements()[flatIndex]; 1219 } 1220 1221 // If this is an elements attribute, query the value at the given indices. 1222 if (Attribute tensor = adaptor.getTensor()) { 1223 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor); 1224 if (elementsAttr && elementsAttr.isValidIndex(indices)) 1225 return elementsAttr.getValues<Attribute>()[indices]; 1226 } 1227 1228 return {}; 1229 } 1230 1231 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, 1232 MLIRContext *context) { 1233 results.add<ExtractFromTensorCast>(context); 1234 } 1235 1236 //===----------------------------------------------------------------------===// 1237 // FromElementsOp 1238 //===----------------------------------------------------------------------===// 1239 1240 void FromElementsOp::getAsmResultNames( 1241 function_ref<void(Value, StringRef)> setNameFn) { 1242 setNameFn(getResult(), "from_elements"); 1243 } 1244 1245 void FromElementsOp::build(OpBuilder &builder, OperationState &result, 1246 ValueRange elements) { 1247 assert(!elements.empty() && "expected at least one element"); 1248 Type resultType = RankedTensorType::get( 1249 {static_cast<int64_t>(elements.size())}, elements.front().getType()); 1250 build(builder, result, resultType, elements); 1251 } 1252 1253 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { 1254 if (!llvm::is_contained(adaptor.getElements(), nullptr)) 1255 return DenseElementsAttr::get(getType(), adaptor.getElements()); 1256 return {}; 1257 } 1258 1259 namespace { 1260 1261 // Pushes the index_casts that occur before extractions to after the extract. 1262 // This minimizes type conversion in some cases and enables the extract 1263 // canonicalizer. This changes: 1264 // 1265 // %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex> 1266 // %extract = tensor.extract %cast[%index] : tensor<1xindex> 1267 // 1268 // to the following: 1269 // 1270 // %extract = tensor.extract %tensor[%index] : tensor<1xindex> 1271 // %cast = arith.index_cast %extract : i32 to index 1272 // 1273 // to just %element. 1274 // 1275 // Consider expanding this to a template and handle all tensor cast 1276 // operations. 
1277 struct ExtractElementFromIndexCast 1278 : public OpRewritePattern<tensor::ExtractOp> { 1279 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 1280 1281 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 1282 PatternRewriter &rewriter) const final { 1283 Location loc = extract.getLoc(); 1284 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>(); 1285 if (!indexCast) 1286 return failure(); 1287 1288 Type elementTy = getElementTypeOrSelf(indexCast.getIn()); 1289 1290 auto newExtract = rewriter.create<tensor::ExtractOp>( 1291 loc, elementTy, indexCast.getIn(), extract.getIndices()); 1292 1293 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(), 1294 newExtract); 1295 1296 return success(); 1297 } 1298 }; 1299 1300 } // namespace 1301 1302 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, 1303 MLIRContext *context) { 1304 results.add<ExtractElementFromIndexCast>(context); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // GatherOp 1309 //===----------------------------------------------------------------------===// 1310 1311 void GatherOp::getAsmResultNames( 1312 function_ref<void(Value, StringRef)> setNameFn) { 1313 setNameFn(getResult(), "gather"); 1314 } 1315 1316 /// Return the inferred result type for a gatherOp where: 1317 /// - sourceType is the type of the source tensor gathered from 1318 /// - indicesType is the type of the indices used to gather 1319 /// - gatherDims are the dims along which the gather occurs. 1320 /// Return a full rank or ranked-reduced variant of the type depending on 1321 /// the value of rankReduced. 1322 /// 1323 /// The leading dimensions of the index tensor give the result tensor its 1324 /// leading dimensions. 1325 /// The trailing dimensions of the result tensor are obtained from the source 1326 /// tensor by setting the dimensions specified in gather_dims to `1` (if 1327 /// rankedReduced is false), or skipping them (otherwise). 
1328 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType, 1329 RankedTensorType indicesType, 1330 ArrayRef<int64_t> gatherDims, 1331 bool rankReduced) { 1332 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back()); 1333 resultShape.reserve(resultShape.size() + sourceType.getRank()); 1334 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) { 1335 if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) { 1336 if (!rankReduced) 1337 resultShape.push_back(1); 1338 continue; 1339 } 1340 resultShape.push_back(sourceType.getDimSize(idx)); 1341 } 1342 return RankedTensorType::Builder(sourceType).setShape(resultShape); 1343 } 1344 1345 static LogicalResult 1346 verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, 1347 ArrayRef<int64_t> indices, int64_t rank, 1348 StringRef gatherOrScatter, StringRef sourceOrDest) { 1349 if (dims.empty()) 1350 return op->emitOpError(gatherOrScatter) << "_dims must be non-empty"; 1351 1352 int64_t numGatherDims = dims.size(); 1353 if (numGatherDims > rank) 1354 return op->emitOpError(gatherOrScatter) 1355 << "_dims overflow " << sourceOrDest << " rank"; 1356 if (indices.empty() || indices.back() != numGatherDims) 1357 return op->emitOpError(gatherOrScatter) 1358 << "_dims length must match the size of last dimension of indices"; 1359 for (int64_t val : dims) { 1360 if (val < 0) 1361 return op->emitOpError(gatherOrScatter) 1362 << "_dims value must be non-negative"; 1363 if (val >= rank) 1364 return op->emitOpError(gatherOrScatter) 1365 << "_dims value must be smaller than " << sourceOrDest << " rank"; 1366 } 1367 for (int64_t i = 1; i < numGatherDims; ++i) { 1368 if (dims[i - 1] >= dims[i]) 1369 return op->emitOpError(gatherOrScatter) 1370 << "_dims values must be strictly increasing"; 1371 } 1372 return success(); 1373 } 1374 1375 LogicalResult GatherOp::verify() { 1376 int64_t sourceRank = getSourceType().getRank(); 1377 ArrayRef<int64_t> gatherDims = getGatherDims(); 1378 if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, 1379 getIndicesType().getShape(), sourceRank, 1380 "gather", "source"))) 1381 return failure(); 1382 1383 RankedTensorType expectedResultType = GatherOp::inferResultType( 1384 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false); 1385 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType( 1386 getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true); 1387 if (getResultType() != expectedResultType && 1388 getResultType() != expectedRankReducedResultType) { 1389 return emitOpError("result type " 1390 "mismatch: " 1391 "expected ") 1392 << expectedResultType << " or its rank-reduced variant " 1393 << expectedRankReducedResultType << " (got: " << getResultType() 1394 << ")"; 1395 } 1396 1397 return success(); 1398 } 1399 1400 OpFoldResult GatherOp::fold(FoldAdaptor adaptor) { 1401 if (OpFoldResult reshapedSource = reshapeConstantSource( 1402 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()), 1403 getResult().getType())) 1404 return reshapedSource; 1405 return {}; 1406 } 1407 1408 //===----------------------------------------------------------------------===// 1409 // InsertOp 1410 //===----------------------------------------------------------------------===// 1411 1412 void InsertOp::getAsmResultNames( 1413 function_ref<void(Value, StringRef)> setNameFn) { 1414 setNameFn(getResult(), "inserted"); 1415 } 1416 1417 LogicalResult InsertOp::verify() { 1418 // Verify the # indices match if we have a ranked 
type. 1419 auto destType = llvm::cast<RankedTensorType>(getDest().getType()); 1420 if (destType.getRank() != static_cast<int64_t>(getIndices().size())) 1421 return emitOpError("incorrect number of indices"); 1422 return success(); 1423 } 1424 1425 OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { 1426 Attribute scalar = adaptor.getScalar(); 1427 Attribute dest = adaptor.getDest(); 1428 if (scalar && dest) 1429 if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest)) 1430 if (scalar == splatDest.getSplatValue<Attribute>()) 1431 return dest; 1432 return {}; 1433 } 1434 1435 //===----------------------------------------------------------------------===// 1436 // GenerateOp 1437 //===----------------------------------------------------------------------===// 1438 1439 void GenerateOp::getAsmResultNames( 1440 function_ref<void(Value, StringRef)> setNameFn) { 1441 setNameFn(getResult(), "generated"); 1442 } 1443 1444 LogicalResult GenerateOp::reifyResultShapes( 1445 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 1446 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank())); 1447 int idx = 0; 1448 for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) { 1449 if (getType().isDynamicDim(dim)) { 1450 reifiedReturnShapes[0][dim] = getOperand(idx++); 1451 } else { 1452 reifiedReturnShapes[0][dim] = 1453 builder.getIndexAttr(getType().getDimSize(dim)); 1454 } 1455 } 1456 return success(); 1457 } 1458 1459 LogicalResult GenerateOp::verify() { 1460 // Ensure that the tensor type has as many dynamic dimensions as are 1461 // specified by the operands. 1462 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType()); 1463 if (getNumOperands() != resultType.getNumDynamicDims()) 1464 return emitError("must have as many index operands as dynamic extents " 1465 "in the result type"); 1466 return success(); 1467 } 1468 1469 LogicalResult GenerateOp::verifyRegions() { 1470 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType()); 1471 // Ensure that region arguments span the index space. 1472 if (!llvm::all_of(getBody().getArgumentTypes(), 1473 [](Type ty) { return ty.isIndex(); })) 1474 return emitError("all body arguments must be index"); 1475 if (getBody().getNumArguments() != resultTy.getRank()) 1476 return emitError("must have one body argument per input dimension"); 1477 1478 // Ensure that the region yields an element of the right type. 1479 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator()); 1480 1481 if (yieldOp.getValue().getType() != resultTy.getElementType()) 1482 return emitOpError( 1483 "body must be terminated with a `yield` operation of the tensor " 1484 "element type"); 1485 1486 return success(); 1487 } 1488 1489 void GenerateOp::build( 1490 OpBuilder &b, OperationState &result, Type resultTy, 1491 ValueRange dynamicExtents, 1492 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { 1493 build(b, result, resultTy, dynamicExtents); 1494 1495 // Build and populate body. 
1496 OpBuilder::InsertionGuard guard(b); 1497 Region *bodyRegion = result.regions.front().get(); 1498 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank(); 1499 SmallVector<Type, 2> argumentTypes(rank, b.getIndexType()); 1500 SmallVector<Location, 2> argumentLocs(rank, result.location); 1501 Block *bodyBlock = 1502 b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs); 1503 bodyBuilder(b, result.location, bodyBlock->getArguments()); 1504 } 1505 1506 namespace { 1507 1508 /// Canonicalizes tensor.generate operations with a constant 1509 /// operand into the equivalent operation with the operand expressed in the 1510 /// result type, instead. We also insert a type cast to make sure that the 1511 /// resulting IR is still well-typed. 1512 struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> { 1513 using OpRewritePattern<GenerateOp>::OpRewritePattern; 1514 1515 LogicalResult matchAndRewrite(GenerateOp generateOp, 1516 PatternRewriter &rewriter) const final { 1517 SmallVector<Value> foldedDynamicSizes; 1518 RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes( 1519 generateOp.getType(), generateOp.getDynamicExtents(), 1520 foldedDynamicSizes); 1521 1522 // Stop here if no dynamic size was promoted to static. 1523 if (foldedTensorType == generateOp.getType()) 1524 return failure(); 1525 1526 auto loc = generateOp.getLoc(); 1527 auto newOp = 1528 rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes); 1529 rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(), 1530 newOp.getBody().begin()); 1531 rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, 1532 generateOp.getType(), newOp); 1533 return success(); 1534 } 1535 }; 1536 1537 /// Canonicalizes the pattern of the form 1538 /// 1539 /// %tensor = tensor.generate %x { 1540 /// ^bb0(%arg0: index): 1541 /// <computation> 1542 /// yield %1 : index 1543 /// } : tensor<?xindex> 1544 /// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xi32> 1545 /// 1546 /// to just <computation> with %arg0 replaced by %c0. We only do this if the 1547 /// tensor.generate operation has no side-effects. 1548 struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> { 1549 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; 1550 1551 LogicalResult matchAndRewrite(tensor::ExtractOp extract, 1552 PatternRewriter &rewriter) const final { 1553 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>(); 1554 if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) 1555 return failure(); 1556 1557 IRMapping mapping; 1558 Block *body = &tensorFromElements.getBody().front(); 1559 mapping.map(body->getArguments(), extract.getIndices()); 1560 for (auto &op : body->without_terminator()) 1561 rewriter.clone(op, mapping); 1562 1563 auto yield = cast<YieldOp>(body->getTerminator()); 1564 1565 rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue())); 1566 return success(); 1567 } 1568 }; 1569 1570 } // namespace 1571 1572 void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results, 1573 MLIRContext *context) { 1574 // TODO: Move extract pattern to tensor::ExtractOp. 
1575 results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context); 1576 } 1577 1578 //===----------------------------------------------------------------------===// 1579 // RankOp 1580 //===----------------------------------------------------------------------===// 1581 1582 void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 1583 setNameFn(getResult(), "rank"); 1584 } 1585 1586 OpFoldResult RankOp::fold(FoldAdaptor adaptor) { 1587 // Constant fold rank when the rank of the operand is known. 1588 auto type = getOperand().getType(); 1589 auto shapedType = llvm::dyn_cast<ShapedType>(type); 1590 if (shapedType && shapedType.hasRank()) 1591 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); 1592 return IntegerAttr(); 1593 } 1594 1595 //===----------------------------------------------------------------------===// 1596 // ReshapeOp 1597 //===----------------------------------------------------------------------===// 1598 1599 void ReshapeOp::getAsmResultNames( 1600 function_ref<void(Value, StringRef)> setNameFn) { 1601 setNameFn(getResult(), "reshape"); 1602 } 1603 1604 static int64_t getNumElements(ShapedType type) { 1605 int64_t numElements = 1; 1606 for (auto dim : type.getShape()) 1607 numElements *= dim; 1608 return numElements; 1609 } 1610 1611 LogicalResult ReshapeOp::verify() { 1612 TensorType operandType = llvm::cast<TensorType>(getSource().getType()); 1613 TensorType resultType = llvm::cast<TensorType>(getResult().getType()); 1614 1615 if (operandType.getElementType() != resultType.getElementType()) 1616 return emitOpError("element types of source and destination tensor " 1617 "types should be the same"); 1618 1619 int64_t shapeSize = 1620 llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0); 1621 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType); 1622 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType); 1623 1624 if (resultRankedType) { 1625 if (operandRankedType && resultRankedType.hasStaticShape() && 1626 operandRankedType.hasStaticShape()) { 1627 if (getNumElements(operandRankedType) != getNumElements(resultRankedType)) 1628 return emitOpError("source and destination tensor should have the " 1629 "same number of elements"); 1630 } 1631 if (ShapedType::isDynamic(shapeSize)) 1632 return emitOpError("cannot use shape operand with dynamic length to " 1633 "reshape to statically-ranked tensor type"); 1634 if (shapeSize != resultRankedType.getRank()) 1635 return emitOpError( 1636 "length of shape operand differs from the result's tensor rank"); 1637 } 1638 return success(); 1639 } 1640 1641 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { 1642 if (OpFoldResult reshapedSource = reshapeConstantSource( 1643 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()), 1644 getResult().getType())) 1645 return reshapedSource; 1646 1647 // If the producer of operand 'source' is another 'tensor.reshape' op, use the 1648 // producer's input instead as the original tensor to reshape. This could 1649 // render such producer dead code. 
1650 if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) { 1651 getSourceMutable().assign(reshapeOpProducer.getSource()); 1652 return getResult(); 1653 } 1654 1655 auto source = getSource(); 1656 auto sourceTy = dyn_cast<RankedTensorType>(source.getType()); 1657 auto resultTy = dyn_cast<RankedTensorType>(getType()); 1658 if (!sourceTy || !resultTy || sourceTy != resultTy) 1659 return {}; 1660 1661 // If the source and result are both 1D tensors and have the same type, the 1662 // reshape has no effect, even if the tensor is dynamically shaped. 1663 if (sourceTy.getRank() == 1) 1664 return source; 1665 1666 if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) { 1667 auto elements = fromElements.getElements(); 1668 bool dynamicNoop = 1669 sourceTy.getRank() == static_cast<int64_t>(elements.size()); 1670 for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) { 1671 auto element = elements[id]; 1672 1673 if (auto cst = getConstantIntValue(element)) { 1674 dynamicNoop &= cst.value() == sourceTy.getDimSize(id); 1675 continue; 1676 } 1677 1678 if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) { 1679 dynamicNoop &= dimOp.getSource() == source; 1680 1681 APSInt dim; 1682 auto cst = getConstantIntValue(dimOp.getIndex()); 1683 dynamicNoop &= 1684 cst.has_value() && cst.value() == static_cast<int64_t>(id); 1685 continue; 1686 } 1687 1688 dynamicNoop = false; 1689 break; 1690 } 1691 1692 if (dynamicNoop) 1693 return source; 1694 } 1695 1696 return {}; 1697 } 1698 1699 //===----------------------------------------------------------------------===// 1700 // Reassociative reshape ops 1701 //===----------------------------------------------------------------------===// 1702 1703 void CollapseShapeOp::getAsmResultNames( 1704 function_ref<void(Value, StringRef)> setNameFn) { 1705 setNameFn(getResult(), "collapsed"); 1706 } 1707 1708 void ExpandShapeOp::getAsmResultNames( 1709 function_ref<void(Value, StringRef)> setNameFn) { 1710 setNameFn(getResult(), "expanded"); 1711 } 1712 1713 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) { 1714 assert(resultDim >= 0 && resultDim < getResultType().getRank() && 1715 "invalid resultDim"); 1716 for (const auto &it : llvm::enumerate(getReassociationIndices())) 1717 if (llvm::is_contained(it.value(), resultDim)) 1718 return it.index(); 1719 llvm_unreachable("could not find reassociation group"); 1720 } 1721 1722 FailureOr<SmallVector<OpFoldResult>> 1723 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc, 1724 RankedTensorType expandedType, 1725 ArrayRef<ReassociationIndices> reassociation, 1726 ArrayRef<OpFoldResult> inputShape) { 1727 std::optional<SmallVector<OpFoldResult>> outputShape = 1728 inferExpandShapeOutputShape(b, loc, expandedType, reassociation, 1729 inputShape); 1730 if (!outputShape) 1731 return failure(); 1732 return *outputShape; 1733 } 1734 1735 SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() { 1736 return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext()); 1737 } 1738 1739 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 1740 Type resultType, Value src, 1741 ArrayRef<ReassociationIndices> reassociation, 1742 ArrayRef<OpFoldResult> outputShape) { 1743 auto [staticOutputShape, dynamicOutputShape] = 1744 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape)); 1745 build(builder, result, cast<RankedTensorType>(resultType), src, 1746 getReassociationIndicesAttribute(builder, reassociation), 1747 
dynamicOutputShape, staticOutputShape); 1748 } 1749 1750 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, 1751 Type resultType, Value src, 1752 ArrayRef<ReassociationIndices> reassociation) { 1753 SmallVector<OpFoldResult> inputShape = 1754 getMixedSizes(builder, result.location, src); 1755 auto tensorResultTy = cast<RankedTensorType>(resultType); 1756 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape( 1757 builder, result.location, tensorResultTy, reassociation, inputShape); 1758 SmallVector<OpFoldResult> outputShapeOrEmpty; 1759 if (succeeded(outputShape)) { 1760 outputShapeOrEmpty = *outputShape; 1761 } 1762 build(builder, result, tensorResultTy, src, reassociation, 1763 outputShapeOrEmpty); 1764 } 1765 1766 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() { 1767 return getSymbolLessAffineMaps(getReassociationExprs()); 1768 } 1769 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() { 1770 return convertReassociationIndicesToExprs(getContext(), 1771 getReassociationIndices()); 1772 } 1773 1774 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() { 1775 return getSymbolLessAffineMaps(getReassociationExprs()); 1776 } 1777 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() { 1778 return convertReassociationIndicesToExprs(getContext(), 1779 getReassociationIndices()); 1780 } 1781 1782 RankedTensorType CollapseShapeOp::inferCollapsedType( 1783 RankedTensorType type, SmallVector<ReassociationIndices> reassociation) { 1784 return inferCollapsedType( 1785 type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( 1786 type.getContext(), reassociation))); 1787 } 1788 1789 /// Compute the RankedTensorType obtained by applying `reassociation` to 1790 /// `type`. 1791 RankedTensorType 1792 CollapseShapeOp::inferCollapsedType(RankedTensorType type, 1793 ArrayRef<AffineMap> reassociation) { 1794 auto shape = type.getShape(); 1795 SmallVector<int64_t, 4> newShape; 1796 newShape.reserve(reassociation.size()); 1797 1798 // Use the fact that reassociation is valid to simplify the logic: only use 1799 // each map's rank. 
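// For example (illustrative): reassociation [[0, 1], [2]] applied to
// tensor<2x3x4xf32> yields tensor<6x4xf32>; if any dimension within a group
// is dynamic, the corresponding collapsed dimension is dynamic as well.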
1800 assert(isReassociationValid(reassociation) && "invalid reassociation"); 1801 unsigned currentDim = 0; 1802 for (AffineMap m : reassociation) { 1803 unsigned dim = m.getNumResults(); 1804 auto band = shape.slice(currentDim, dim); 1805 int64_t size = 1; 1806 if (llvm::is_contained(band, ShapedType::kDynamic)) 1807 size = ShapedType::kDynamic; 1808 else 1809 for (unsigned d = 0; d < dim; ++d) 1810 size *= shape[currentDim + d]; 1811 newShape.push_back(size); 1812 currentDim += dim; 1813 } 1814 1815 return RankedTensorType::get(newShape, type.getElementType()); 1816 } 1817 1818 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, 1819 ArrayRef<ReassociationIndices> reassociation, 1820 ArrayRef<NamedAttribute> attrs) { 1821 auto resultType = inferCollapsedType( 1822 llvm::cast<RankedTensorType>(src.getType()), 1823 getSymbolLessAffineMaps( 1824 convertReassociationIndicesToExprs(b.getContext(), reassociation))); 1825 result.addAttribute(getReassociationAttrStrName(), 1826 getReassociationIndicesAttribute(b, reassociation)); 1827 build(b, result, resultType, src, attrs); 1828 } 1829 1830 template <typename TensorReshapeOp, bool isExpansion = std::is_same< 1831 TensorReshapeOp, ExpandShapeOp>::value> 1832 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, 1833 RankedTensorType expandedType, 1834 RankedTensorType collapsedType) { 1835 if (failed( 1836 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) 1837 return failure(); 1838 1839 auto maps = op.getReassociationMaps(); 1840 RankedTensorType expectedType = 1841 CollapseShapeOp::inferCollapsedType(expandedType, maps); 1842 if (!isSameTypeWithoutEncoding(collapsedType, expectedType)) 1843 return op.emitOpError("expected collapsed type to be ") 1844 << expectedType << ", but got " << collapsedType; 1845 return success(); 1846 } 1847 1848 LogicalResult ExpandShapeOp::verify() { 1849 auto srcType = getSrcType(); 1850 auto resultType = getResultType(); 1851 1852 if ((int64_t)getStaticOutputShape().size() != resultType.getRank()) 1853 return emitOpError("expected number of static shape dims to be equal to " 1854 "the output rank (") 1855 << resultType.getRank() << ") but found " 1856 << getStaticOutputShape().size() << " inputs instead"; 1857 1858 if ((int64_t)getOutputShape().size() != 1859 llvm::count(getStaticOutputShape(), ShapedType::kDynamic)) 1860 return emitOpError("mismatch in dynamic dims in output_shape and " 1861 "static_output_shape: static_output_shape has ") 1862 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic) 1863 << " dynamic dims while output_shape has " << getOutputShape().size() 1864 << " values"; 1865 1866 return verifyTensorReshapeOp(*this, resultType, srcType); 1867 } 1868 1869 LogicalResult CollapseShapeOp::verify() { 1870 return verifyTensorReshapeOp(*this, getSrcType(), getResultType()); 1871 } 1872 1873 namespace { 1874 /// Reshape of a splat constant can be replaced with a constant of the result 1875 /// type. 
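///
/// A minimal illustrative example (shapes and values are hypothetical):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<2x4xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
/// ```
///
/// canonicalizes to:
///
/// ```mlir
/// %0 = arith.constant dense<1.0> : tensor<8xf32>
/// ```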
1876 template <typename TensorReshapeOp> 1877 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> { 1878 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1879 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1880 PatternRewriter &rewriter) const override { 1881 DenseElementsAttr attr; 1882 if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr))) 1883 return failure(); 1884 if (!attr || !attr.isSplat()) 1885 return failure(); 1886 DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( 1887 reshapeOp.getResultType(), attr.getRawData()); 1888 rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr); 1889 return success(); 1890 } 1891 }; 1892 1893 // Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type. 1894 template <typename TensorReshapeOp> 1895 class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> { 1896 public: 1897 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1898 1899 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1900 PatternRewriter &rewriter) const override { 1901 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>(); 1902 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape()) 1903 return failure(); 1904 1905 rewriter.replaceOpWithNewOp<tensor::SplatOp>( 1906 reshapeOp, reshapeOp.getResultType(), splatOp.getInput()); 1907 return success(); 1908 } 1909 }; 1910 1911 /// Reshape of a FromElements can be replaced with a FromElements of the 1912 /// result type 1913 template <typename TensorReshapeOp> 1914 struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> { 1915 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; 1916 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, 1917 PatternRewriter &rewriter) const override { 1918 auto fromElements = 1919 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>(); 1920 if (!fromElements) 1921 return failure(); 1922 1923 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType()); 1924 1925 if (!shapedTy.hasStaticShape()) 1926 return failure(); 1927 1928 rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(), 1929 fromElements.getElements()); 1930 return success(); 1931 } 1932 }; 1933 1934 // Fold CastOp into CollapseShapeOp when adding static information. 
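//
// Illustrative example (shapes are hypothetical):
//
//   %0 = tensor.cast %t : tensor<2x4xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
//
// becomes
//
//   %2 = tensor.collapse_shape %t [[0, 1]] : tensor<2x4xf32> into tensor<8xf32>
//   %1 = tensor.cast %2 : tensor<8xf32> to tensor<?xf32>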
1935 struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> { 1936 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern; 1937 1938 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp, 1939 PatternRewriter &rewriter) const override { 1940 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>(); 1941 if (!tensor::canFoldIntoConsumerOp(castOp)) 1942 return failure(); 1943 1944 RankedTensorType srcType = 1945 llvm::cast<RankedTensorType>(castOp.getSource().getType()); 1946 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType( 1947 srcType, collapseShapeOp.getReassociationMaps()); 1948 1949 if (newResultType == collapseShapeOp.getResultType()) { 1950 rewriter.modifyOpInPlace(collapseShapeOp, [&]() { 1951 collapseShapeOp.getSrcMutable().assign(castOp.getSource()); 1952 }); 1953 } else { 1954 auto newOp = rewriter.create<CollapseShapeOp>( 1955 collapseShapeOp.getLoc(), newResultType, castOp.getSource(), 1956 collapseShapeOp.getReassociation()); 1957 rewriter.replaceOpWithNewOp<tensor::CastOp>( 1958 collapseShapeOp, collapseShapeOp.getResultType(), newOp); 1959 } 1960 return success(); 1961 } 1962 }; 1963 1964 struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> { 1965 using OpRewritePattern<DimOp>::OpRewritePattern; 1966 1967 LogicalResult matchAndRewrite(DimOp dimOp, 1968 PatternRewriter &rewriter) const override { 1969 auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>(); 1970 if (!expandShapeOp) 1971 return failure(); 1972 1973 // Only constant dimension values are supported. 1974 std::optional<int64_t> dim = dimOp.getConstantIndex(); 1975 if (!dim.has_value()) 1976 return failure(); 1977 1978 // Skip static dims. These are folded to constant ops. 1979 RankedTensorType resultType = expandShapeOp.getResultType(); 1980 if (!resultType.isDynamicDim(*dim)) 1981 return failure(); 1982 1983 // Find reassociation group that contains this result dimension. 1984 int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim); 1985 1986 // `dim` is the only dynamic dimension in `group`. (Otherwise, the 1987 // ExpandShapeOp would be ambiguous.) 1988 int64_t product = 1; 1989 ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim]; 1990 for (int64_t d : grp) { 1991 if (d != dim) { 1992 assert(!resultType.isDynamicDim(d) && "expected static dim"); 1993 product *= resultType.getDimSize(d); 1994 } 1995 } 1996 1997 // result dim size = src dim size / (product(other dims in reassoc group)) 1998 Value srcDimSz = 1999 rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim); 2000 AffineExpr expr; 2001 bindSymbols(dimOp.getContext(), expr); 2002 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>( 2003 dimOp, expr.floorDiv(product), srcDimSz); 2004 return success(); 2005 } 2006 }; 2007 2008 struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> { 2009 using OpRewritePattern<DimOp>::OpRewritePattern; 2010 2011 LogicalResult matchAndRewrite(DimOp dimOp, 2012 PatternRewriter &rewriter) const override { 2013 auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>(); 2014 if (!collapseShapeOp) 2015 return failure(); 2016 2017 // Only constant dimension values are supported. 2018 std::optional<int64_t> dim = dimOp.getConstantIndex(); 2019 if (!dim.has_value() || 2020 dim.value() >= collapseShapeOp.getResultType().getRank()) 2021 return failure(); 2022 2023 // Skip static dims. These are folded to constant ops. 
2024 RankedTensorType resultType = collapseShapeOp.getResultType(); 2025 if (!resultType.isDynamicDim(*dim)) 2026 return failure(); 2027 2028 // Get reassociation group of the result dimension. 2029 ReassociationIndices group = 2030 collapseShapeOp.getReassociationIndices()[*dim]; 2031 2032 // result dim size = product(dims in reassoc group) 2033 SmallVector<Value> srcDimSizes; 2034 SmallVector<AffineExpr> syms; 2035 AffineExpr product; 2036 for (const auto &it : llvm::enumerate(group)) { 2037 srcDimSizes.push_back(rewriter.create<DimOp>( 2038 dimOp.getLoc(), collapseShapeOp.getSrc(), it.value())); 2039 syms.push_back(rewriter.getAffineSymbolExpr(it.index())); 2040 product = product ? product * syms.back() : syms.back(); 2041 } 2042 rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product, 2043 srcDimSizes); 2044 return success(); 2045 } 2046 }; 2047 2048 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by 2049 /// matching constant output_shape operands of the expand. This makes the 2050 /// `tensor.expand_shape` more static and creates a consumer cast that can be 2051 /// propagated further. 2052 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> { 2053 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; 2054 2055 LogicalResult matchAndRewrite(ExpandShapeOp expandOp, 2056 PatternRewriter &rewriter) const override { 2057 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>(); 2058 if (!canFoldIntoConsumerOp(castOp)) 2059 return failure(); 2060 2061 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape(); 2062 SmallVector<ReassociationIndices, 4> reassoc = 2063 expandOp.getReassociationIndices(); 2064 2065 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape()); 2066 SmallVector<Value> dynamicOutputShape; 2067 auto outputIt = expandOp.getOutputShape().begin(); 2068 2069 for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) { 2070 for (uint64_t outDim : innerReassoc) { 2071 if (!ShapedType::isDynamic(newOutputShape[outDim])) 2072 continue; 2073 2074 // If the cast's src type is dynamic, don't infer any of the 2075 // corresponding expanded dimensions. `tensor.expand_shape` requires at 2076 // least one of the expanded dimensions to be dynamic if the input is 2077 // dynamic. 
2078 Value val = *outputIt; 2079 ++outputIt; 2080 if (ShapedType::isDynamic(castSrcShape[inputDim])) { 2081 dynamicOutputShape.push_back(val); 2082 continue; 2083 } 2084 2085 APInt cst; 2086 if (matchPattern(val, m_ConstantInt(&cst))) { 2087 newOutputShape[outDim] = cst.getSExtValue(); 2088 } else { 2089 dynamicOutputShape.push_back(val); 2090 } 2091 } 2092 } 2093 2094 // Couldn't match any values, nothing to change 2095 if (expandOp.getOutputShape().size() == dynamicOutputShape.size()) 2096 return failure(); 2097 2098 // Calculate the input shape from the output 2099 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l); 2100 for (auto inDim : llvm::seq<int>(0, newInputShape.size())) { 2101 for (auto outDim : reassoc[inDim]) { 2102 auto ofr = newOutputShape[outDim]; 2103 if (ShapedType::isDynamic(ofr)) { 2104 newInputShape[inDim] = ShapedType::kDynamic; 2105 break; 2106 } 2107 newInputShape[inDim] *= ofr; 2108 } 2109 } 2110 2111 SmallVector<OpFoldResult> outputOfr = 2112 getMixedValues(newOutputShape, dynamicOutputShape, rewriter); 2113 auto inputType = RankedTensorType::get( 2114 newInputShape, expandOp.getSrcType().getElementType()); 2115 auto outputType = RankedTensorType::get( 2116 newOutputShape, expandOp.getSrcType().getElementType()); 2117 auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType, 2118 expandOp.getSrc()); 2119 auto newExpand = rewriter.create<ExpandShapeOp>( 2120 expandOp.getLoc(), outputType, inputCast.getResult(), 2121 expandOp.getReassociationIndices(), outputOfr); 2122 rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(), 2123 newExpand.getResult()); 2124 return success(); 2125 } 2126 }; 2127 } // namespace 2128 2129 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2130 MLIRContext *context) { 2131 results.add< 2132 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>, 2133 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>, 2134 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>, 2135 FoldReshapeWithSplat<ExpandShapeOp>, 2136 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape, 2137 FoldDimOfCollapseShape>(context); 2138 } 2139 2140 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, 2141 MLIRContext *context) { 2142 results.add< 2143 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>, 2144 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp, 2145 tensor::DimOp, RankedTensorType>, 2146 FoldReshapeWithConstant<CollapseShapeOp>, 2147 FoldReshapeWithSplat<CollapseShapeOp>, 2148 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>( 2149 context); 2150 } 2151 2152 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { 2153 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, 2154 adaptor.getOperands()); 2155 } 2156 2157 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) { 2158 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, 2159 adaptor.getOperands()); 2160 } 2161 2162 //===----------------------------------------------------------------------===// 2163 // ExtractSliceOp 2164 //===----------------------------------------------------------------------===// 2165 2166 void ExtractSliceOp::getAsmResultNames( 2167 function_ref<void(Value, StringRef)> setNameFn) { 2168 setNameFn(getResult(), "extracted_slice"); 2169 } 2170 2171 /// An extract_slice result type can be inferred, when it is not 2172 /// rank-reduced, from the source type and the static 
representation of
2173 /// offsets, sizes and strides. Special sentinels encode the dynamic case.
2174 RankedTensorType ExtractSliceOp::inferResultType(
2175 RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
2176 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
2177 // An extract_slice op may specify only a leading subset of offsets/sizes/
2178 // strides, in which case we complete with offset=0, sizes from the source
2179 // tensor type, and strides=1.
2180 assert(static_cast<int64_t>(staticSizes.size()) ==
2181 sourceTensorType.getRank() &&
2182 "unexpected staticSizes not equal to rank of source");
2183 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2184 sourceTensorType.getEncoding());
2185 }
2186 
2187 RankedTensorType ExtractSliceOp::inferResultType(
2188 RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
2189 ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
2190 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2191 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2192 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2193 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2194 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2195 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2196 staticSizes, staticStrides);
2197 }
2198 
2199 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the
2200 /// number of sizes), drop as many size-1 dimensions as needed to produce an
2201 /// inferred type with the desired rank.
2202 ///
2203 /// Note that there may be multiple ways to compute this rank-reduced type:
2204 /// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
2205 ///
2206 /// To disambiguate, this function always drops the first occurrences of
2207 /// size-1 dimensions.
2207 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2208 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2209 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2210 ArrayRef<int64_t> strides) {
2211 // Type inferred in the absence of rank-reducing behavior.
2212 auto inferredType = llvm::cast<RankedTensorType>(
2213 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2214 int rankDiff = inferredType.getRank() - desiredResultRank;
2215 if (rankDiff > 0) {
2216 auto shape = inferredType.getShape();
2217 llvm::SmallBitVector dimsToProject =
2218 getPositionsOfShapeOne(rankDiff, shape);
2219 SmallVector<int64_t> projectedShape;
2220 // Best effort rank-reducing: drop 1s in order.
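// For example (illustrative), reducing 1x6x1 to rank 2 drops the leading
// unit dimension and produces 6x1.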
2221 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) 2222 if (!dimsToProject.test(pos)) 2223 projectedShape.push_back(shape[pos]); 2224 inferredType = 2225 RankedTensorType::get(projectedShape, inferredType.getElementType()); 2226 } 2227 return inferredType; 2228 } 2229 2230 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType( 2231 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType, 2232 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 2233 ArrayRef<OpFoldResult> strides) { 2234 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2235 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2236 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2237 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2238 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2239 return ExtractSliceOp::inferCanonicalRankReducedResultType( 2240 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes, 2241 staticStrides); 2242 } 2243 2244 /// Build an ExtractSliceOp with mixed static and dynamic entries and custom 2245 /// result type. If the type passed is nullptr, it is inferred. 2246 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 2247 RankedTensorType resultType, Value source, 2248 ArrayRef<OpFoldResult> offsets, 2249 ArrayRef<OpFoldResult> sizes, 2250 ArrayRef<OpFoldResult> strides, 2251 ArrayRef<NamedAttribute> attrs) { 2252 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2253 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2254 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2255 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2256 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2257 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType()); 2258 // Structuring implementation this way avoids duplication between builders. 2259 if (!resultType) { 2260 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType( 2261 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides)); 2262 } 2263 result.addAttributes(attrs); 2264 build(b, result, resultType, source, dynamicOffsets, dynamicSizes, 2265 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), 2266 b.getDenseI64ArrayAttr(staticSizes), 2267 b.getDenseI64ArrayAttr(staticStrides)); 2268 } 2269 2270 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred 2271 /// result type. 2272 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2273 ArrayRef<OpFoldResult> offsets, 2274 ArrayRef<OpFoldResult> sizes, 2275 ArrayRef<OpFoldResult> strides, 2276 ArrayRef<NamedAttribute> attrs) { 2277 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 2278 } 2279 2280 /// Build an ExtractSliceOp with mixed static and dynamic entries packed into 2281 /// a Range vector. 2282 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2283 ArrayRef<Range> ranges, 2284 ArrayRef<NamedAttribute> attrs) { 2285 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); 2286 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 2287 } 2288 2289 /// Build an ExtractSliceOp with dynamic entries and custom result type. If 2290 /// the type passed is nullptr, it is inferred. 
2291 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, 2292 RankedTensorType resultType, Value source, 2293 ValueRange offsets, ValueRange sizes, 2294 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 2295 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2296 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 2297 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 2298 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 2299 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2300 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 2301 build(b, result, resultType, source, offsetValues, sizeValues, strideValues); 2302 } 2303 2304 /// Build an ExtractSliceOp with dynamic entries and inferred result type. 2305 void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2306 ValueRange offsets, ValueRange sizes, 2307 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 2308 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); 2309 } 2310 2311 static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, 2312 Operation *op, 2313 RankedTensorType expectedType) { 2314 switch (result) { 2315 case SliceVerificationResult::Success: 2316 return success(); 2317 case SliceVerificationResult::RankTooLarge: 2318 return op->emitError("expected rank to be smaller or equal to ") 2319 << "the other rank. "; 2320 case SliceVerificationResult::SizeMismatch: 2321 return op->emitError("expected type to be ") 2322 << expectedType << " or a rank-reduced version. (size mismatch) "; 2323 case SliceVerificationResult::ElemTypeMismatch: 2324 return op->emitError("expected element type to be ") 2325 << expectedType.getElementType(); 2326 default: 2327 llvm_unreachable("unexpected extract_slice op verification result"); 2328 } 2329 } 2330 2331 /// Verifier for ExtractSliceOp. 2332 LogicalResult ExtractSliceOp::verify() { 2333 // Verify result type against inferred type. 
2334 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2335 getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
2336 SliceVerificationResult result = isRankReducedType(expectedType, getType());
2337 return produceSliceErrorMsg(result, *this, expectedType);
2338 }
2339 
2340 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2341 return ::getDroppedDims(getType().getShape(), getMixedSizes());
2342 }
2343 
2344 FailureOr<Value>
2345 ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
2346 ArrayRef<int64_t> desiredShape) {
2347 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
2348 assert(sourceTensorType && "not a ranked tensor type");
2349 auto sourceShape = sourceTensorType.getShape();
2350 if (sourceShape.equals(desiredShape))
2351 return value;
2352 auto maybeRankReductionMask =
2353 mlir::computeRankReductionMask(sourceShape, desiredShape);
2354 if (!maybeRankReductionMask)
2355 return failure();
2356 return createCanonicalRankReducingExtractSliceOp(
2357 b, loc, value,
2358 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2359 }
2360 
2361 LogicalResult ExtractSliceOp::reifyResultShapes(
2362 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2363 reifiedReturnShapes.resize(1);
2364 reifiedReturnShapes[0].reserve(getType().getRank());
2365 SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
2366 llvm::SmallBitVector droppedDims = getDroppedDims();
2367 for (const auto &size : enumerate(mixedSizes)) {
2368 if (droppedDims.test(size.index()))
2369 continue;
2370 reifiedReturnShapes[0].push_back(size.value());
2371 }
2372 return success();
2373 }
2374 
2375 namespace {
2376 /// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
2377 /// This essentially pushes the tensor.cast past its consuming slice when
2378 /// `canFoldIntoConsumerOp` is true.
2379 ///
2380 /// Example:
2381 /// ```
2382 /// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
2383 /// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
2384 /// tensor<3x4xf32>
2385 /// ```
2386 /// is rewritten into:
2387 /// ```
2388 /// %1 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
2389 /// tensor<3x4xf32>
2390 /// ```
2391 class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
2392 public:
2393 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2394 
2395 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2396 PatternRewriter &rewriter) const override {
2397 // Any constant operand, just return to let the constant folder kick in.
2398 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2399 return matchPattern(operand, matchConstantIndex());
2400 }))
2401 return failure();
2402 
2403 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2404 if (!castOp)
2405 return failure();
2406 
2407 if (!canFoldIntoConsumerOp(castOp))
2408 return failure();
2409 
2410 // Create folded extract.
2411 Location loc = sliceOp.getLoc(); 2412 Value newResult = rewriter.create<ExtractSliceOp>( 2413 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(), 2414 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), 2415 sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); 2416 if (newResult.getType() != sliceOp.getType()) 2417 newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult); 2418 rewriter.replaceOp(sliceOp, newResult); 2419 return success(); 2420 } 2421 }; 2422 2423 /// Slice elements from `values` into `outValues`. `counts` represents the 2424 /// numbers of elements to stride in the original values for each dimension. 2425 /// The output values can be used to construct a DenseElementsAttr. 2426 template <typename IterTy, typename ElemTy> 2427 static void sliceElements(IterTy values, ArrayRef<int64_t> counts, 2428 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes, 2429 ArrayRef<int64_t> strides, 2430 llvm::SmallVectorImpl<ElemTy> *outValues) { 2431 assert(offsets.size() == sizes.size()); 2432 assert(offsets.size() == strides.size()); 2433 if (offsets.empty()) 2434 return; 2435 2436 int64_t offset = offsets.front(); 2437 int64_t size = sizes.front(); 2438 int64_t stride = strides.front(); 2439 if (offsets.size() == 1) { 2440 for (int64_t i = 0; i < size; ++i, offset += stride) 2441 outValues->push_back(*(values + offset)); 2442 2443 return; 2444 } 2445 2446 for (int64_t i = 0; i < size; ++i, offset += stride) { 2447 auto begin = values + offset * counts.front(); 2448 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(), 2449 offsets.drop_front(), sizes.drop_front(), 2450 strides.drop_front(), outValues); 2451 } 2452 } 2453 2454 /// Fold arith.constant and tensor.extract_slice into arith.constant. The 2455 /// folded operation might introduce more constant data; Users can control 2456 /// their heuristics by the control function. 2457 class ConstantOpExtractSliceFolder final 2458 : public OpRewritePattern<ExtractSliceOp> { 2459 public: 2460 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern; 2461 2462 ConstantOpExtractSliceFolder(MLIRContext *context, 2463 ControlConstantExtractSliceFusionFn controlFn) 2464 : OpRewritePattern<ExtractSliceOp>(context), 2465 controlFn(std::move(controlFn)) {} 2466 2467 LogicalResult matchAndRewrite(ExtractSliceOp op, 2468 PatternRewriter &rewriter) const override { 2469 DenseElementsAttr attr; 2470 if (!matchPattern(op.getSource(), m_Constant(&attr))) 2471 return failure(); 2472 2473 // A constant splat is handled by fold(). 2474 if (attr.isSplat()) 2475 return failure(); 2476 2477 // Dynamic result shape is not supported. 2478 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType()); 2479 auto resultType = llvm::cast<ShapedType>(op.getResult().getType()); 2480 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) 2481 return failure(); 2482 2483 // Customized control over the folding. 2484 if (!controlFn(op)) 2485 return failure(); 2486 2487 int64_t count = sourceType.getNumElements(); 2488 if (count == 0) 2489 return failure(); 2490 2491 // Check if there are any dynamic parts, which are not supported. 
2492 auto offsets = op.getStaticOffsets(); 2493 if (llvm::is_contained(offsets, ShapedType::kDynamic)) 2494 return failure(); 2495 auto sizes = op.getStaticSizes(); 2496 if (llvm::is_contained(sizes, ShapedType::kDynamic)) 2497 return failure(); 2498 auto strides = op.getStaticStrides(); 2499 if (llvm::is_contained(strides, ShapedType::kDynamic)) 2500 return failure(); 2501 2502 // Compute the stride for each dimension. 2503 SmallVector<int64_t> counts; 2504 ArrayRef<int64_t> shape = sourceType.getShape(); 2505 counts.reserve(shape.size()); 2506 for (int64_t v : shape) { 2507 count = count / v; 2508 counts.push_back(count); 2509 } 2510 2511 // New attribute constructed by the sliced values. 2512 DenseElementsAttr newAttr; 2513 2514 if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) { 2515 SmallVector<APInt> outValues; 2516 outValues.reserve(sourceType.getNumElements()); 2517 sliceElements<DenseElementsAttr::IntElementIterator, APInt>( 2518 elems.begin(), counts, offsets, sizes, strides, &outValues); 2519 newAttr = DenseElementsAttr::get(resultType, outValues); 2520 } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) { 2521 SmallVector<APFloat> outValues; 2522 outValues.reserve(sourceType.getNumElements()); 2523 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>( 2524 elems.begin(), counts, offsets, sizes, strides, &outValues); 2525 newAttr = DenseElementsAttr::get(resultType, outValues); 2526 } 2527 2528 if (newAttr) { 2529 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr); 2530 return success(); 2531 } 2532 2533 return failure(); 2534 } 2535 2536 private: 2537 /// This additionally controls whether the fold happens or not. Users can 2538 /// impose their heuristics in the function. 2539 ControlConstantExtractSliceFusionFn controlFn; 2540 }; 2541 2542 } // namespace 2543 2544 void mlir::tensor::populateFoldConstantExtractSlicePatterns( 2545 RewritePatternSet &patterns, 2546 const ControlConstantExtractSliceFusionFn &controlFn) { 2547 patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn); 2548 } 2549 2550 /// Return the canonical type of the result of an extract_slice op. 2551 struct SliceReturnTypeCanonicalizer { 2552 RankedTensorType operator()(ExtractSliceOp op, 2553 ArrayRef<OpFoldResult> mixedOffsets, 2554 ArrayRef<OpFoldResult> mixedSizes, 2555 ArrayRef<OpFoldResult> mixedStrides) { 2556 return ExtractSliceOp::inferCanonicalRankReducedResultType( 2557 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes, 2558 mixedStrides); 2559 } 2560 }; 2561 2562 /// A canonicalizer wrapper to replace ExtractSliceOps. 
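/// If the canonicalized extract_slice has a more static result type than the
/// original op, a tensor.cast back to the original type is inserted so that
/// existing users remain valid. Illustrative sketch (types are hypothetical):
///
///   %1 = tensor.extract_slice ... : tensor<8x8xf32> to tensor<?x4xf32>
///
/// may become
///
///   %0 = tensor.extract_slice ... : tensor<8x8xf32> to tensor<3x4xf32>
///   %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<?x4xf32>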
2563 struct SliceCanonicalizer { 2564 void operator()(PatternRewriter &rewriter, ExtractSliceOp op, 2565 ExtractSliceOp newOp) { 2566 Value replacement = newOp.getResult(); 2567 if (replacement.getType() != op.getType()) 2568 replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), 2569 replacement); 2570 rewriter.replaceOp(op, replacement); 2571 } 2572 }; 2573 2574 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, 2575 MLIRContext *context) { 2576 results.add< 2577 OpWithOffsetSizesAndStridesConstantArgumentFolder< 2578 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>, 2579 ExtractSliceOpCastFolder>(context); 2580 } 2581 2582 // 2583 static LogicalResult 2584 foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, 2585 ShapedType shapedType) { 2586 OpBuilder b(op.getContext()); 2587 for (OpFoldResult ofr : op.getMixedOffsets()) 2588 if (getConstantIntValue(ofr) != static_cast<int64_t>(0)) 2589 return failure(); 2590 // Rank-reducing noops only need to inspect the leading dimensions: 2591 // llvm::zip is appropriate. 2592 auto shape = shapedType.getShape(); 2593 for (auto it : llvm::zip(op.getMixedSizes(), shape)) 2594 if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it)) 2595 return failure(); 2596 for (OpFoldResult ofr : op.getMixedStrides()) 2597 if (getConstantIntValue(ofr) != static_cast<int64_t>(1)) 2598 return failure(); 2599 return success(); 2600 } 2601 2602 /// If we have an ExtractSliceOp consuming an InsertSliceOp with the same 2603 /// slice, we can return the InsertSliceOp's source directly. 2604 // TODO: This only checks the immediate producer; extend to go up the 2605 // insert/extract chain if the slices are disjoint. 2606 static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) { 2607 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>(); 2608 2609 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; 2610 if (insertOp && insertOp.getSource().getType() == extractOp.getType() && 2611 insertOp.isSameAs(extractOp, isSame)) 2612 return insertOp.getSource(); 2613 2614 return {}; 2615 } 2616 2617 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { 2618 if (OpFoldResult reshapedSource = reshapeConstantSource( 2619 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()), 2620 getResult().getType())) 2621 return reshapedSource; 2622 if (getSourceType() == getType() && 2623 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) 2624 return this->getSource(); 2625 if (Value slice = foldExtractAfterInsertSlice(*this)) 2626 return slice; 2627 2628 return OpFoldResult(); 2629 } 2630 2631 Value mlir::tensor::createCanonicalRankReducingExtractSliceOp( 2632 OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) { 2633 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType()); 2634 unsigned rank = rankedTensorType.getRank(); 2635 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 2636 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor); 2637 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 2638 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor, 2639 offsets, sizes, strides); 2640 } 2641 2642 //===----------------------------------------------------------------------===// 2643 // InsertSliceOp 2644 //===----------------------------------------------------------------------===// 2645 2646 void InsertSliceOp::getAsmResultNames( 2647 
function_ref<void(Value, StringRef)> setNameFn) { 2648 setNameFn(getResult(), "inserted_slice"); 2649 } 2650 2651 // Build a InsertSliceOp with mixed static and dynamic entries. 2652 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2653 Value dest, ArrayRef<OpFoldResult> offsets, 2654 ArrayRef<OpFoldResult> sizes, 2655 ArrayRef<OpFoldResult> strides, 2656 ArrayRef<NamedAttribute> attrs) { 2657 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 2658 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 2659 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 2660 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 2661 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 2662 result.addAttributes(attrs); 2663 build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, 2664 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), 2665 b.getDenseI64ArrayAttr(staticSizes), 2666 b.getDenseI64ArrayAttr(staticStrides)); 2667 } 2668 2669 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a 2670 /// Range vector. 2671 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2672 Value dest, ArrayRef<Range> ranges, 2673 ArrayRef<NamedAttribute> attrs) { 2674 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); 2675 build(b, result, source, dest, offsets, sizes, strides, attrs); 2676 } 2677 2678 // Build a InsertSliceOp with dynamic entries. 2679 void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, 2680 Value dest, ValueRange offsets, ValueRange sizes, 2681 ValueRange strides, ArrayRef<NamedAttribute> attrs) { 2682 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 2683 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 2684 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 2685 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 2686 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 2687 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 2688 build(b, result, source, dest, offsetValues, sizeValues, strideValues); 2689 } 2690 2691 /// Rank-reducing type verification for both InsertSliceOp and 2692 /// ParallelInsertSliceOp. 2693 static SliceVerificationResult verifyInsertSliceOp( 2694 RankedTensorType srcType, RankedTensorType dstType, 2695 ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes, 2696 ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) { 2697 // insert_slice is the inverse of extract_slice, use the same type 2698 // inference. 2699 RankedTensorType expected = ExtractSliceOp::inferResultType( 2700 dstType, staticOffsets, staticSizes, staticStrides); 2701 if (expectedType) 2702 *expectedType = expected; 2703 return isRankReducedType(expected, srcType); 2704 } 2705 2706 /// Verifier for InsertSliceOp. 2707 LogicalResult InsertSliceOp::verify() { 2708 RankedTensorType expectedType; 2709 SliceVerificationResult result = 2710 verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(), 2711 getStaticSizes(), getStaticStrides(), &expectedType); 2712 return produceSliceErrorMsg(result, *this, expectedType); 2713 } 2714 2715 /// If we have two consecutive InsertSliceOp writing to the same slice, we 2716 /// can mutate the second InsertSliceOp's destination to the first one's. 
2717 /// 2718 /// Example: 2719 /// 2720 /// ```mlir 2721 /// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1] 2722 /// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1] 2723 /// ``` 2724 /// 2725 /// folds into: 2726 /// 2727 /// ```mlir 2728 /// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1] 2729 /// ``` 2730 /// 2731 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. 2732 static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) { 2733 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>(); 2734 2735 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; 2736 if (!prevInsertOp || 2737 prevInsertOp.getSource().getType() != insertOp.getSource().getType() || 2738 !prevInsertOp.isSameAs(insertOp, isSame)) 2739 return failure(); 2740 2741 insertOp.getDestMutable().assign(prevInsertOp.getDest()); 2742 return success(); 2743 } 2744 2745 /// Folds round-trip extract/insert slice op pairs. 2746 /// Example: 2747 /// ```mlir 2748 /// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] 2749 /// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] 2750 /// ``` 2751 /// can be folded into %val. 2752 static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) { 2753 auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>(); 2754 2755 auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; }; 2756 if (!extractOp || extractOp.getSource() != insertOp.getDest() || 2757 !extractOp.isSameAs(insertOp, isSame)) 2758 return nullptr; 2759 2760 return extractOp.getSource(); 2761 } 2762 2763 OpFoldResult InsertSliceOp::fold(FoldAdaptor) { 2764 if (getSourceType().hasStaticShape() && getType().hasStaticShape() && 2765 getSourceType() == getType() && 2766 succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) 2767 return this->getSource(); 2768 if (succeeded(foldInsertAfterInsertSlice(*this))) 2769 return getResult(); 2770 if (auto result = foldInsertAfterExtractSlice(*this)) 2771 return result; 2772 if (llvm::any_of(getMixedSizes(), 2773 [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); })) 2774 return getDest(); 2775 return OpFoldResult(); 2776 } 2777 2778 LogicalResult InsertSliceOp::reifyResultShapes( 2779 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 2780 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank())); 2781 reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest()); 2782 return success(); 2783 } 2784 2785 namespace { 2786 /// Pattern to rewrite a insert_slice op with constant arguments. 2787 /// 2788 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. 
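///
/// A minimal illustrative sketch (values are hypothetical): if %size is
/// produced by `arith.constant 4 : index`, then
///
/// ```mlir
/// %r = tensor.insert_slice %src into %dst[0] [%size] [1]
///     : tensor<?xf32> into tensor<16xf32>
/// ```
///
/// is rewritten with the constant folded into the static sizes (casting the
/// source when its type can be made more static):
///
/// ```mlir
/// %0 = tensor.cast %src : tensor<?xf32> to tensor<4xf32>
/// %r = tensor.insert_slice %0 into %dst[0] [4] [1]
///     : tensor<4xf32> into tensor<16xf32>
/// ```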
2789 template <typename InsertOpTy> 2790 class InsertSliceOpConstantArgumentFolder final 2791 : public OpRewritePattern<InsertOpTy> { 2792 public: 2793 using OpRewritePattern<InsertOpTy>::OpRewritePattern; 2794 2795 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp, 2796 PatternRewriter &rewriter) const override { 2797 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets()); 2798 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes()); 2799 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides()); 2800 2801 // No constant operands were folded, just return; 2802 if (failed(foldDynamicOffsetSizeList(mixedOffsets)) && 2803 failed(foldDynamicOffsetSizeList(mixedSizes)) && 2804 failed(foldDynamicStrideList(mixedStrides))) 2805 return failure(); 2806 2807 // Create the new op in canonical form. 2808 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( 2809 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), 2810 mixedOffsets, mixedSizes, mixedStrides); 2811 Value toInsert = insertSliceOp.getSource(); 2812 if (sourceType != insertSliceOp.getSourceType()) { 2813 OpBuilder::InsertionGuard g(rewriter); 2814 // The only difference between InsertSliceOp and ParallelInsertSliceOp 2815 // is that the insertion point is just before the ParallelCombiningOp in 2816 // the parallel case. 2817 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value) 2818 rewriter.setInsertionPoint(insertSliceOp->getParentOp()); 2819 toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(), 2820 sourceType, toInsert); 2821 } 2822 rewriter.replaceOpWithNewOp<InsertOpTy>( 2823 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets, 2824 mixedSizes, mixedStrides); 2825 return success(); 2826 } 2827 }; 2828 2829 /// Fold tensor_casts with insert_slice operations. If the source or 2830 /// destination tensor is a tensor_cast that removes static type information, 2831 /// the cast is folded into the insert_slice operation. E.g.: 2832 /// 2833 /// ```mlir 2834 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> 2835 /// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ... 2836 /// ``` 2837 /// 2838 /// folds into: 2839 /// 2840 /// ```mlir 2841 /// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ... 2842 /// ``` 2843 /// 2844 /// Note: When folding a cast on the destination tensor, the result of the 2845 /// insert_slice operation is casted to ensure that the type of the result did 2846 /// not change. 2847 /// 2848 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp. 
2849 template <typename InsertOpTy> 2850 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> { 2851 using OpRewritePattern<InsertOpTy>::OpRewritePattern; 2852 2853 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp, 2854 PatternRewriter &rewriter) const override { 2855 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) { 2856 return matchPattern(operand, matchConstantIndex()); 2857 })) 2858 return failure(); 2859 2860 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> { 2861 auto castOp = v.getDefiningOp<tensor::CastOp>(); 2862 if (!castOp || !canFoldIntoConsumerOp(castOp)) 2863 return std::nullopt; 2864 return castOp.getSource(); 2865 }; 2866 std::optional<Value> sourceCastSource = 2867 getSourceOfCastOp(insertSliceOp.getSource()); 2868 std::optional<Value> destCastSource = 2869 getSourceOfCastOp(insertSliceOp.getDest()); 2870 if (!sourceCastSource && !destCastSource) 2871 return failure(); 2872 2873 auto src = 2874 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource()); 2875 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest()); 2876 auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType()); 2877 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType()); 2878 if (!srcType || !dstType) 2879 return failure(); 2880 2881 // The tensor.cast source could have additional static information not seen 2882 // in the insert slice op static sizes, so we ignore dynamic dims when 2883 // computing the rank reduction mask. 2884 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes()); 2885 auto rankReductionMask = computeRankReductionMask( 2886 staticSizes, srcType.getShape(), /*matchDynamic=*/true); 2887 if (!rankReductionMask.has_value()) 2888 return failure(); 2889 // Replace dimensions in the insert slice op with corresponding static dims 2890 // from the cast source type. If the insert slice sizes have static dims 2891 // that are not static in the tensor.cast source (i.e., when the cast op 2892 // casts a dynamic dim to static), the dim should not be replaced, and the 2893 // pattern will fail later in `verifyInsertSliceOp`. 2894 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes()); 2895 int64_t rankReducedIdx = 0; 2896 for (auto [idx, size] : enumerate(staticSizes)) { 2897 if (!rankReductionMask.value().contains(idx) && 2898 !srcType.isDynamicDim(rankReducedIdx)) { 2899 mixedSizes[idx] = getAsIndexOpFoldResult( 2900 rewriter.getContext(), srcType.getDimSize(rankReducedIdx)); 2901 size = srcType.getDimSize(rankReducedIdx++); 2902 } 2903 } 2904 if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(), 2905 staticSizes, insertSliceOp.getStaticStrides()) != 2906 SliceVerificationResult::Success) 2907 return failure(); 2908 2909 Operation *replacement = rewriter.create<InsertOpTy>( 2910 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(), 2911 mixedSizes, insertSliceOp.getMixedStrides()); 2912 2913 // In the parallel case there is no result and so nothing to cast. 
2914 bool isParallelInsert =
2915 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2916 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2917 replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
2918 insertSliceOp.getDestType(),
2919 replacement->getResult(0));
2920 }
2921 rewriter.replaceOp(insertSliceOp, replacement->getResults());
2922 return success();
2923 }
2924 };
2925 
2926 /// If additional static type information can be deduced from an insert_slice's
2927 /// size operands, insert an explicit cast of the op's source operand. This
2928 /// enables other canonicalization patterns that match tensor.cast
2929 /// ops, such as `ForOpTensorCastFolder` in SCF.
2930 ///
2931 /// Example:
2932 ///
2933 /// ```mlir
2934 /// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
2935 /// : tensor<?x?xf32> into ...
2936 /// ```
2937 ///
2938 /// folds into:
2939 ///
2940 /// ```mlir
2941 /// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
2942 /// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
2943 /// : tensor<64x64xf32> into ...
2944 /// ```
2945 ///
2946 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
2947 template <typename InsertOpTy>
2948 struct InsertSliceOpSourceCastInserter final
2949 : public OpRewritePattern<InsertOpTy> {
2950 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2951 
2952 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2953 PatternRewriter &rewriter) const override {
2954 RankedTensorType srcType = insertSliceOp.getSourceType();
2955 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2956 return failure();
2957 SmallVector<int64_t> newSrcShape(srcType.getShape());
2958 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2959 if (std::optional<int64_t> constInt =
2960 getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
2961 // Bail on invalid IR.
2962 if (*constInt < 0)
2963 return failure();
2964 newSrcShape[i] = *constInt;
2965 }
2966 }
2967 if (!hasValidSizesOffsets(newSrcShape))
2968 return failure();
2969 
2970 RankedTensorType newSrcType = RankedTensorType::get(
2971 newSrcShape, srcType.getElementType(), srcType.getEncoding());
2972 if (srcType == newSrcType ||
2973 !preservesStaticInformation(srcType, newSrcType) ||
2974 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2975 return failure();
2976 
2977 // newSrcType is:
2978 // 1) Different from srcType.
2979 // 2) "More static" than srcType.
2980 // 3) Cast-compatible with srcType.
2981 // Insert the cast.
2982 OpBuilder::InsertionGuard g(rewriter);
2983 // The only difference between InsertSliceOp and ParallelInsertSliceOp is
2984 // that the insertion point is just before the ParallelCombiningOp in the
2985 // parallel case.
2986 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value) 2987 rewriter.setInsertionPoint(insertSliceOp->getParentOp()); 2988 Value cast = rewriter.create<tensor::CastOp>( 2989 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource()); 2990 rewriter.replaceOpWithNewOp<InsertOpTy>( 2991 insertSliceOp, cast, insertSliceOp.getDest(), 2992 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), 2993 insertSliceOp.getMixedStrides()); 2994 return success(); 2995 } 2996 }; 2997 } // namespace 2998 2999 llvm::SmallBitVector InsertSliceOp::getDroppedDims() { 3000 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes()); 3001 } 3002 3003 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, 3004 MLIRContext *context) { 3005 results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>, 3006 InsertSliceOpCastFolder<InsertSliceOp>, 3007 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context); 3008 } 3009 3010 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b, 3011 Location loc, 3012 Value tensor, 3013 Value dest) { 3014 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType()); 3015 unsigned rank = rankedTensorType.getRank(); 3016 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); 3017 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest); 3018 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); 3019 return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets, 3020 sizes, strides); 3021 } 3022 3023 //===----------------------------------------------------------------------===// 3024 // PadOp 3025 //===----------------------------------------------------------------------===// 3026 3027 void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 3028 setNameFn(getResult(), "padded"); 3029 } 3030 3031 // TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it 3032 // supports optional types. 
3033 void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, 3034 Type typeToInfer, Type typeToInferFrom) {} 3035 3036 ParseResult 3037 parseInferType(OpAsmParser &parser, 3038 std::optional<OpAsmParser::UnresolvedOperand> optOperand, 3039 Type &typeToInfer, Type typeToInferFrom) { 3040 if (optOperand) 3041 typeToInfer = typeToInferFrom; 3042 return success(); 3043 } 3044 3045 LogicalResult PadOp::verify() { 3046 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType()); 3047 auto resultType = llvm::cast<RankedTensorType>(getResult().getType()); 3048 auto expectedType = 3049 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh()); 3050 if (!expectedType) { 3051 return emitError("failed to infer expectedType from sourceType ") 3052 << sourceType << ", specified resultType is " << resultType; 3053 } 3054 if (resultType.getRank() != expectedType.getRank()) { 3055 return emitError("specified type ") 3056 << resultType << " does not match the inferred type " 3057 << expectedType; 3058 } 3059 for (int i = 0, e = sourceType.getRank(); i < e; ++i) { 3060 if (resultType.getDimSize(i) == expectedType.getDimSize(i)) 3061 continue; 3062 if (expectedType.isDynamicDim(i)) 3063 continue; 3064 return emitError("specified type ") 3065 << resultType << " does not match the inferred type " 3066 << expectedType; 3067 } 3068 3069 return success(); 3070 } 3071 3072 LogicalResult PadOp::verifyRegions() { 3073 auto ®ion = getRegion(); 3074 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank(); 3075 Block &block = region.front(); 3076 if (block.getNumArguments() != rank) 3077 return emitError("expected the block to have ") << rank << " arguments"; 3078 3079 // Note: the number and type of yield values are checked in the YieldOp. 3080 for (const auto &en : llvm::enumerate(block.getArgumentTypes())) { 3081 if (!en.value().isIndex()) 3082 return emitOpError("expected block argument ") 3083 << (en.index() + 1) << " to be an index"; 3084 } 3085 3086 // Ensure that the region yields an element of the right type. 3087 auto yieldOp = llvm::cast<YieldOp>(block.getTerminator()); 3088 if (yieldOp.getValue().getType() != 3089 llvm::cast<ShapedType>(getType()).getElementType()) 3090 return emitOpError("expected yield type to match shape element type"); 3091 3092 return success(); 3093 } 3094 3095 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType, 3096 ArrayRef<int64_t> staticLow, 3097 ArrayRef<int64_t> staticHigh, 3098 ArrayRef<int64_t> resultShape) { 3099 unsigned rank = sourceType.getRank(); 3100 if (staticLow.size() != rank) 3101 return RankedTensorType(); 3102 if (staticHigh.size() != rank) 3103 return RankedTensorType(); 3104 if (!resultShape.empty() && resultShape.size() != rank) 3105 return RankedTensorType(); 3106 3107 SmallVector<int64_t, 4> inferredShape; 3108 for (auto i : llvm::seq<unsigned>(0, rank)) { 3109 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic || 3110 staticHigh[i] == ShapedType::kDynamic) { 3111 inferredShape.push_back(resultShape.empty() ? 
ShapedType::kDynamic 3112 : resultShape[i]); 3113 } else { 3114 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i]; 3115 assert((resultShape.empty() || size == resultShape[i] || 3116 resultShape[i] == ShapedType::kDynamic) && 3117 "mismatch between inferred shape and result shape"); 3118 inferredShape.push_back(size); 3119 } 3120 } 3121 3122 return RankedTensorType::get(inferredShape, sourceType.getElementType()); 3123 } 3124 3125 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, 3126 Value source, ArrayRef<int64_t> staticLow, 3127 ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high, 3128 bool nofold, ArrayRef<NamedAttribute> attrs) { 3129 auto sourceType = llvm::cast<RankedTensorType>(source.getType()); 3130 if (!resultType) 3131 resultType = inferResultType(sourceType, staticLow, staticHigh); 3132 result.addAttributes(attrs); 3133 build(b, result, resultType, source, low, high, 3134 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), 3135 nofold ? b.getUnitAttr() : UnitAttr()); 3136 } 3137 3138 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, 3139 Value source, ValueRange low, ValueRange high, bool nofold, 3140 ArrayRef<NamedAttribute> attrs) { 3141 auto sourceType = llvm::cast<RankedTensorType>(source.getType()); 3142 unsigned rank = sourceType.getRank(); 3143 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic); 3144 build(b, result, resultType, source, staticVector, staticVector, low, high, 3145 nofold, attrs); 3146 } 3147 3148 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, 3149 Value source, ArrayRef<OpFoldResult> low, 3150 ArrayRef<OpFoldResult> high, bool nofold, 3151 ArrayRef<NamedAttribute> attrs) { 3152 auto sourceType = llvm::cast<RankedTensorType>(source.getType()); 3153 SmallVector<Value, 4> dynamicLow, dynamicHigh; 3154 SmallVector<int64_t, 4> staticLow, staticHigh; 3155 // staticLow and staticHigh have full information of the padding config. 3156 // This will grow staticLow and staticHigh with 1 value. If the config is 3157 // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1 3158 // value as well. 3159 dispatchIndexOpFoldResults(low, dynamicLow, staticLow); 3160 dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh); 3161 if (!resultType) { 3162 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh); 3163 } 3164 assert(llvm::isa<RankedTensorType>(resultType)); 3165 result.addAttributes(attrs); 3166 build(b, result, resultType, source, dynamicLow, dynamicHigh, 3167 b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), 3168 nofold ? b.getUnitAttr() : UnitAttr()); 3169 } 3170 3171 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType, 3172 Value source, ArrayRef<OpFoldResult> low, 3173 ArrayRef<OpFoldResult> high, Value constantPadValue, 3174 bool nofold, ArrayRef<NamedAttribute> attrs) { 3175 build(b, result, resultType, source, low, high, nofold, attrs); 3176 3177 // Add a region and a block to yield the pad value. 3178 Region *region = result.regions[0].get(); 3179 int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank(); 3180 SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType()); 3181 SmallVector<Location> blockArgLocs(sourceRank, result.location); 3182 3183 // `builder.createBlock` changes the insertion point within the block. Create 3184 // a guard to reset the insertion point of the builder after it is destroyed. 
3185 OpBuilder::InsertionGuard guard(b); 3186 b.createBlock(region, region->end(), blockArgTypes, blockArgLocs); 3187 b.create<tensor::YieldOp>(result.location, constantPadValue); 3188 } 3189 3190 llvm::SmallBitVector PadOp::getPaddedDims() { 3191 llvm::SmallBitVector paddedDims(getSourceType().getRank()); 3192 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) { 3193 for (const auto &en : enumerate(paddingWidths)) 3194 if (getConstantIntValue(en.value()) != static_cast<int64_t>(0)) 3195 paddedDims.set(en.index()); 3196 }; 3197 extractPaddedDims(getMixedLowPad()); 3198 extractPaddedDims(getMixedHighPad()); 3199 return paddedDims; 3200 } 3201 3202 namespace { 3203 // Folds tensor.pad when padding is static zeros and the attribute 3204 // doesn't request otherwise. 3205 struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> { 3206 using OpRewritePattern<PadOp>::OpRewritePattern; 3207 3208 LogicalResult matchAndRewrite(PadOp padTensorOp, 3209 PatternRewriter &rewriter) const override { 3210 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad()) 3211 return failure(); 3212 if (padTensorOp.getNofold()) 3213 return failure(); 3214 rewriter.replaceOpWithNewOp<tensor::CastOp>( 3215 padTensorOp, padTensorOp.getResult().getType(), 3216 padTensorOp.getSource()); 3217 return success(); 3218 } 3219 }; 3220 3221 // Fold CastOp into PadOp when adding static information. 3222 struct FoldSourceTensorCast : public OpRewritePattern<PadOp> { 3223 using OpRewritePattern<PadOp>::OpRewritePattern; 3224 3225 LogicalResult matchAndRewrite(PadOp padTensorOp, 3226 PatternRewriter &rewriter) const override { 3227 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>(); 3228 if (!tensor::canFoldIntoConsumerOp(castOp)) 3229 return failure(); 3230 3231 auto newResultType = PadOp::inferResultType( 3232 llvm::cast<RankedTensorType>(castOp.getSource().getType()), 3233 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), 3234 padTensorOp.getResultType().getShape()); 3235 3236 if (newResultType == padTensorOp.getResultType()) { 3237 rewriter.modifyOpInPlace(padTensorOp, [&]() { 3238 padTensorOp.getSourceMutable().assign(castOp.getSource()); 3239 }); 3240 } else { 3241 auto newOp = rewriter.create<PadOp>( 3242 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(), 3243 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), 3244 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(), 3245 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); 3246 IRMapping mapper; 3247 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper); 3248 3249 rewriter.replaceOpWithNewOp<tensor::CastOp>( 3250 padTensorOp, padTensorOp.getResultType(), newOp); 3251 } 3252 return success(); 3253 } 3254 }; 3255 3256 // Fold CastOp using the result of PadOp back into the latter if it adds 3257 // static information. 
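//
// For example (shapes illustrative):
//
//   %padded = tensor.pad %src low[...] high[...] { ... }
//       : tensor<?x?xf32> to tensor<?x?xf32>
//   %0 = tensor.cast %padded : tensor<?x?xf32> to tensor<8x16xf32>
//
// is rewritten into a single tensor.pad that directly produces
// tensor<8x16xf32>.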
3258 struct FoldTargetTensorCast : public OpRewritePattern<PadOp> { 3259 using OpRewritePattern<PadOp>::OpRewritePattern; 3260 3261 LogicalResult matchAndRewrite(PadOp padTensorOp, 3262 PatternRewriter &rewriter) const override { 3263 if (!padTensorOp.getResult().hasOneUse()) 3264 return failure(); 3265 auto tensorCastOp = 3266 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin()); 3267 if (!tensorCastOp) 3268 return failure(); 3269 if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(), 3270 tensorCastOp.getDest().getType())) 3271 return failure(); 3272 3273 auto replacementOp = rewriter.create<PadOp>( 3274 padTensorOp.getLoc(), tensorCastOp.getDest().getType(), 3275 padTensorOp.getSource(), padTensorOp.getStaticLow(), 3276 padTensorOp.getStaticHigh(), padTensorOp.getLow(), 3277 padTensorOp.getHigh(), padTensorOp.getNofold(), 3278 getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames())); 3279 replacementOp.getRegion().takeBody(padTensorOp.getRegion()); 3280 3281 rewriter.replaceOp(padTensorOp, replacementOp.getResult()); 3282 rewriter.replaceOp(tensorCastOp, replacementOp.getResult()); 3283 return success(); 3284 } 3285 }; 3286 3287 /// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad 3288 /// different dimensions. The pattern applies if the following preconditions 3289 /// hold: 3290 /// 1) the tensor::ExtractSliceOps are not rank-reducing, 3291 /// 2) the tensor::ExtractSliceOps have only unit-strides, 3292 /// 3) the tensor::PadOps perform only high-padding, 3293 /// 4) the tensor::PadOps have the same constant padding value, 3294 /// 5) the tensor::PadOps do not have common padding dimensions, 3295 /// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and 3296 /// zero-offset for every dimension. 3297 /// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for 3298 /// the 3299 /// padded source dimensions. 3300 /// 3301 /// Example: 3302 /// 3303 /// ```mlir 3304 /// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1] 3305 /// : tensor<64x64xf32> to tensor<?x64xf32> 3306 /// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ... 3307 /// } : tensor<?x64xf32> to tensor<8x64xf32> 3308 /// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] 3309 /// : tensor<8x64xf32> to tensor<8x?xf32> 3310 /// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ... 3311 /// } : tensor<8x?xf32> to tensor<8x4xf32> 3312 /// ``` 3313 /// 3314 /// folds into: 3315 /// 3316 /// ```mlir 3317 /// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1] 3318 /// : tensor<64x64xf32> to tensor<?x?xf32> 3319 /// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ... 3320 /// } : tensor<?x?xf32> to tensor<8x4xf32> 3321 /// ``` 3322 struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> { 3323 using OpRewritePattern<PadOp>::OpRewritePattern; 3324 3325 LogicalResult matchAndRewrite(PadOp padOp, 3326 PatternRewriter &rewriter) const override { 3327 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>(); 3328 if (!innerSliceOp) 3329 return failure(); 3330 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>(); 3331 if (!outerPadOp || outerPadOp.getNofold()) 3332 return failure(); 3333 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>(); 3334 if (!outerSliceOp) 3335 return failure(); 3336 3337 // 1) Fail if the chain is rank-reducing. 
3338     int64_t rank = padOp.getSourceType().getRank();
3339     if (outerSliceOp.getSourceType().getRank() != rank) {
3340       return rewriter.notifyMatchFailure(padOp,
3341                                          "cannot fold rank-reducing chain");
3342     }
3343 
3344     // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
3345     if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3346       return rewriter.notifyMatchFailure(
3347           padOp, "cannot fold non-unit stride ExtractSliceOps");
3348     }
3349 
3350     // 3) Fail if the tensor::PadOps have non-zero low padding.
3351     if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3352       return rewriter.notifyMatchFailure(padOp,
3353                                          "cannot fold PadOps with low padding");
3354     }
3355 
3356     // 4) Fail if the tensor::PadOps padding values do not match.
3357     Attribute innerAttr, outerAttr;
3358     Value innerValue = padOp.getConstantPaddingValue();
3359     Value outerValue = outerPadOp.getConstantPaddingValue();
3360     if (!innerValue || !outerValue ||
3361         !matchPattern(innerValue, m_Constant(&innerAttr)) ||
3362         !matchPattern(outerValue, m_Constant(&outerAttr)) ||
3363         innerAttr != outerAttr) {
3364       return rewriter.notifyMatchFailure(
3365           padOp, "cannot fold PadOps with different padding values");
3366     }
3367 
3368     // 5) Fail if a dimension is padded by both tensor::PadOps.
3369     llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3370     llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3371     if (innerDims.anyCommon(outerDims)) {
3372       return rewriter.notifyMatchFailure(
3373           padOp, "cannot fold PadOps with common padding dimensions");
3374     }
3375 
3376     // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
3377     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3378     // for every dimension, and use the offset of the other pair. Fail if no
3379     // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
3380     // exists.
3381     SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
3382     for (auto en : enumerate(newOffsets)) {
3383       OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3384       OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3385       if (!innerDims.test(en.index()) &&
3386           (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
3387         en.value() = outerOffset;
3388         continue;
3389       }
3390       if (!outerDims.test(en.index()) &&
3391           (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
3392         en.value() = innerOffset;
3393         continue;
3394       }
3395       return rewriter.notifyMatchFailure(
3396           padOp, "cannot find zero-offset and zero-padding pair");
3397     }
3398 
3399     // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
3400     // of the outer tensor::ExtractSliceOp for the dimensions padded by the
3401     // outer tensor::PadOp and fail if the size of the inner
3402     // tensor::ExtractSliceOp does not match the size of the padded dimension.
3403     // Otherwise, take the size of the inner tensor::ExtractSliceOp.
3404 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes(); 3405 for (auto en : enumerate(newSizes)) { 3406 if (!outerDims.test(en.index())) 3407 continue; 3408 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()]; 3409 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()]; 3410 assert(!ShapedType::isDynamic(sourceSize) && 3411 "expected padded dimension to have a static size"); 3412 if (getConstantIntValue(sliceSize) != sourceSize) { 3413 return rewriter.notifyMatchFailure( 3414 padOp, "cannot fold since the inner ExtractSliceOp size does not " 3415 "match the size of the outer padding"); 3416 } 3417 en.value() = outerSliceOp.getMixedSizes()[en.index()]; 3418 } 3419 3420 // Combine the high paddings of the two tensor::PadOps. 3421 SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0)); 3422 for (auto en : enumerate(newHighPad)) { 3423 if (innerDims.test(en.index())) 3424 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()]; 3425 if (outerDims.test(en.index())) 3426 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()]; 3427 } 3428 3429 // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs 3430 // the two paddings in one step. 3431 auto newSliceOp = rewriter.create<ExtractSliceOp>( 3432 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes, 3433 innerSliceOp.getMixedStrides()); 3434 auto newPadOp = rewriter.create<PadOp>( 3435 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(), 3436 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(), 3437 getPrunedAttributeList(padOp, PadOp::getAttributeNames())); 3438 rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(), 3439 newPadOp.getRegion().begin()); 3440 rewriter.replaceOp(padOp, newPadOp.getResult()); 3441 return success(); 3442 } 3443 }; 3444 3445 struct FoldStaticPadding : public OpRewritePattern<PadOp> { 3446 using OpRewritePattern<PadOp>::OpRewritePattern; 3447 3448 LogicalResult matchAndRewrite(PadOp padTensorOp, 3449 PatternRewriter &rewriter) const override { 3450 Value input = padTensorOp.getSource(); 3451 if (!llvm::isa<RankedTensorType>(input.getType())) 3452 return failure(); 3453 auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape(); 3454 auto inputRank = inputDims.size(); 3455 3456 auto oldResultType = 3457 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType()); 3458 if (!oldResultType) 3459 return failure(); 3460 3461 auto outputDims = oldResultType.getShape(); 3462 3463 // Extract the static info from the high and low operands. 3464 SmallVector<int64_t> constOperandsLow; 3465 SmallVector<Value> newLows; 3466 for (auto operand : padTensorOp.getLow()) { 3467 APSInt intOp; 3468 if (!matchPattern(operand, m_ConstantInt(&intOp))) { 3469 constOperandsLow.push_back(ShapedType::kDynamic); 3470 newLows.push_back(operand); 3471 continue; 3472 } 3473 constOperandsLow.push_back(intOp.getExtValue()); 3474 } 3475 SmallVector<int64_t> constOperandsHigh; 3476 SmallVector<Value> newHighs; 3477 for (auto operand : padTensorOp.getHigh()) { 3478 APSInt intOp; 3479 if (!matchPattern(operand, m_ConstantInt(&intOp))) { 3480 constOperandsHigh.push_back(ShapedType::kDynamic); 3481 newHighs.push_back(operand); 3482 continue; 3483 } 3484 constOperandsHigh.push_back(intOp.getExtValue()); 3485 } 3486 3487 SmallVector<int64_t> constLow(padTensorOp.getStaticLow()); 3488 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh()); 3489 3490 // Verify the op is well-formed. 
3491     if (inputDims.size() != outputDims.size() ||
3492         inputDims.size() != constLow.size() ||
3493         inputDims.size() != constHigh.size())
3494       return failure();
3495 
3496     auto lowCount = 0;
3497     auto highCount = 0;
3498     for (size_t i = 0; i < inputRank; i++) {
3499       if (constLow[i] == ShapedType::kDynamic)
3500         constLow[i] = constOperandsLow[lowCount++];
3501       if (constHigh[i] == ShapedType::kDynamic)
3502         constHigh[i] = constOperandsHigh[highCount++];
3503     }
3504 
3505     auto staticLow = ArrayRef<int64_t>(constLow);
3506     auto staticHigh = ArrayRef<int64_t>(constHigh);
3507 
3508     // Calculate the output sizes with the static information.
3509     SmallVector<int64_t> newOutDims;
3510     for (size_t i = 0; i < inputRank; i++) {
3511       if (outputDims[i] == ShapedType::kDynamic) {
3512         newOutDims.push_back(
3513             (staticLow[i] == ShapedType::kDynamic ||
3514              staticHigh[i] == ShapedType::kDynamic ||
3515              inputDims[i] == ShapedType::kDynamic
3516                  ? ShapedType::kDynamic
3517                  : inputDims[i] + staticLow[i] + staticHigh[i]));
3518       } else {
3519         newOutDims.push_back(outputDims[i]);
3520       }
3521     }
3522 
3523     if (SmallVector<int64_t>(outputDims) == newOutDims ||
3524         llvm::all_of(newOutDims,
3525                      [&](int64_t x) { return x == ShapedType::kDynamic; }))
3526       return failure();
3527 
3528     // Rewrite the op using the new static type.
3529     auto newResultType = RankedTensorType::get(
3530         newOutDims, padTensorOp.getType().getElementType());
3531     auto newOp = rewriter.create<PadOp>(
3532         padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3533         newLows, newHighs, padTensorOp.getNofold(),
3534         getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
3535 
3536     IRMapping mapper;
3537     padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3538     rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
3539                                                 newOp);
3540 
3541     return success();
3542   }
3543 };
3544 
3545 /// Folds a chain of `tensor.pad` ops with the same constant padding value.
3546 ///
3547 /// Example:
3548 ///
3549 /// ```mlir
3550 ///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3551 ///       tensor.yield %val
3552 ///     } : tensor<1x2xf32> to tensor<1x5xf32>
3553 ///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
3554 ///       tensor.yield %val
3555 ///     } : tensor<1x5xf32> to tensor<4x7xf32>
3556 /// ```
3557 ///
3558 /// folds into:
3559 ///
3560 /// ```mlir
3561 ///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
3562 ///       tensor.yield %val
3563 ///     } : tensor<1x2xf32> to tensor<4x7xf32>
3564 /// ```
3565 struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3566   using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3567 
3568   LogicalResult matchAndRewrite(tensor::PadOp padOp,
3569                                 PatternRewriter &rewriter) const override {
3570     if (padOp.getNofold()) {
3571       return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3572     }
3573 
3574     auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3575     if (!producerPad || producerPad.getNofold()) {
3576       return rewriter.notifyMatchFailure(
3577           padOp, "producer is not a foldable tensor.pad op");
3578     }
3579 
3580     // Fail if the tensor::PadOps padding values do not match.
3581     Value consumerPadValue = padOp.getConstantPaddingValue();
3582     Value producerPadValue = producerPad.getConstantPaddingValue();
3583     if (!consumerPadValue || !producerPadValue ||
3584         consumerPadValue != producerPadValue) {
3585       return rewriter.notifyMatchFailure(
3586           padOp,
3587           "cannot fold PadOps with different or non-constant padding values");
3588     }
3589 
3590     Location loc = padOp.getLoc();
3591     AffineExpr d0, d1;
3592     bindDims(rewriter.getContext(), d0, d1);
3593 
3594     // Combine the low/high paddings of the two tensor::PadOps.
3595     auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3596                            ArrayRef<OpFoldResult> producerPaddings) {
3597       SmallVector<OpFoldResult> sumPaddings;
3598       for (auto [consumerIndex, producerIndex] :
3599            llvm::zip_equal(consumerPaddings, producerPaddings)) {
3600         sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3601             rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3602       }
3603       return sumPaddings;
3604     };
3605 
3606     SmallVector<OpFoldResult> newHighPad =
3607         addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3608     SmallVector<OpFoldResult> newLowPad =
3609         addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3610 
3611     auto newPadOp = rewriter.create<tensor::PadOp>(
3612         padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3613         newLowPad, newHighPad, padOp.getNofold(),
3614         getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3615     rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3616                                 newPadOp.getRegion().begin());
3617     rewriter.replaceOp(padOp, newPadOp.getResult());
3618     return success();
3619   }
3620 };
3621 
3622 } // namespace
3623 
3624 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3625                                         MLIRContext *context) {
3626   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3627               FoldOrthogonalPaddings, FoldStaticPadding,
3628               FoldConsecutiveConstantPadding>(context);
3629 }
3630 
3631 /// Return the padding value of the PadOp if it is constant. In this context,
3632 /// "constant" means an actual constant or "defined outside of the block".
3633 ///
3634 /// Values are considered constant in three cases:
3635 ///  - A ConstantLike value.
3636 ///  - A basic block argument from a different block.
3637 ///  - A value defined outside of the block.
3638 ///
3639 /// If the padding value is not constant, an empty Value is returned.
3640 Value PadOp::getConstantPaddingValue() {
3641   auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3642   if (!yieldOp)
3643     return {};
3644   Value padValue = yieldOp.getValue();
3645   // Check if yield value is a constant.
3646   if (matchPattern(padValue, m_Constant()))
3647     return padValue;
3648   // Check if yield value is defined inside the PadOp block.
3649   if (padValue.getParentBlock() == &getRegion().front())
3650     return {};
3651   // Else: Yield value defined outside of the PadOp block.
3652 return padValue; 3653 } 3654 3655 OpFoldResult PadOp::fold(FoldAdaptor) { 3656 if (getResultType().hasStaticShape() && getResultType() == getSourceType() && 3657 !getNofold()) 3658 return getSource(); 3659 return {}; 3660 } 3661 3662 //===----------------------------------------------------------------------===// 3663 // ParallelInsertSliceOp 3664 //===----------------------------------------------------------------------===// 3665 3666 OpResult ParallelInsertSliceOp::getTiedOpResult() { 3667 ParallelCombiningOpInterface parallelCombiningParent = 3668 getParallelCombiningParent(); 3669 for (const auto &it : 3670 llvm::enumerate(parallelCombiningParent.getYieldingOps())) { 3671 Operation &nextOp = it.value(); 3672 if (&nextOp == getOperation()) 3673 return parallelCombiningParent.getParentResult(it.index()); 3674 } 3675 llvm_unreachable("ParallelInsertSliceOp no tied OpResult found"); 3676 } 3677 3678 // Build a ParallelInsertSliceOp with mixed static and dynamic entries. 3679 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, 3680 Value source, Value dest, 3681 ArrayRef<OpFoldResult> offsets, 3682 ArrayRef<OpFoldResult> sizes, 3683 ArrayRef<OpFoldResult> strides, 3684 ArrayRef<NamedAttribute> attrs) { 3685 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides; 3686 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides; 3687 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 3688 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); 3689 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 3690 result.addAttributes(attrs); 3691 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, 3692 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), 3693 b.getDenseI64ArrayAttr(staticSizes), 3694 b.getDenseI64ArrayAttr(staticStrides)); 3695 } 3696 3697 /// Build an ParallelInsertSliceOp with mixed static and dynamic entries 3698 /// packed into a Range vector. 3699 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, 3700 Value source, Value dest, 3701 ArrayRef<Range> ranges, 3702 ArrayRef<NamedAttribute> attrs) { 3703 auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); 3704 build(b, result, source, dest, offsets, sizes, strides, attrs); 3705 } 3706 3707 // Build a ParallelInsertSliceOp with dynamic entries. 
3708 void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, 3709 Value source, Value dest, ValueRange offsets, 3710 ValueRange sizes, ValueRange strides, 3711 ArrayRef<NamedAttribute> attrs) { 3712 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>( 3713 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; })); 3714 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>( 3715 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; })); 3716 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>( 3717 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; })); 3718 build(b, result, source, dest, offsetValues, sizeValues, strideValues); 3719 } 3720 3721 LogicalResult ParallelInsertSliceOp::verify() { 3722 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp())) 3723 return this->emitError("expected ParallelCombiningOpInterface parent, got:") 3724 << *(getOperation()->getParentOp()); 3725 3726 RankedTensorType expectedType; 3727 SliceVerificationResult result = 3728 verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(), 3729 getStaticSizes(), getStaticStrides(), &expectedType); 3730 return produceSliceErrorMsg(result, *this, expectedType); 3731 } 3732 3733 void ParallelInsertSliceOp::getCanonicalizationPatterns( 3734 RewritePatternSet &results, MLIRContext *context) { 3735 results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>, 3736 InsertSliceOpCastFolder<ParallelInsertSliceOp>, 3737 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context); 3738 } 3739 3740 llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() { 3741 return ::getDroppedDims(getSourceType().getShape(), getMixedSizes()); 3742 } 3743 3744 //===----------------------------------------------------------------------===// 3745 // ScatterOp 3746 //===----------------------------------------------------------------------===// 3747 3748 void ScatterOp::getAsmResultNames( 3749 function_ref<void(Value, StringRef)> setNameFn) { 3750 setNameFn(getResult(), "scatter"); 3751 } 3752 3753 LogicalResult ScatterOp::verify() { 3754 int64_t destRank = getDestType().getRank(); 3755 ArrayRef<int64_t> scatterDims = getScatterDims(); 3756 if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, 3757 getIndicesType().getShape(), destRank, 3758 "scatter", "dest"))) 3759 return failure(); 3760 3761 if (!getUnique()) 3762 return emitOpError("requires 'unique' attribute to be set"); 3763 // TODO: we could also check statically that there are fewer leading index 3764 // tensor dims than the dest dims. If this is not the case, the unique 3765 // attribute cannot be true. 3766 3767 // Use the GatherOp::inferResultType on the `dest` type and verify the 3768 // expected type matches the source type. 
3769 RankedTensorType expectedSourceType = GatherOp::inferResultType( 3770 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false); 3771 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType( 3772 getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true); 3773 if (getSourceType() != expectedSourceType && 3774 getSourceType() != expectedRankReducedSourceType) { 3775 return emitOpError("source type " 3776 "mismatch: " 3777 "expected ") 3778 << expectedSourceType << " or its rank-reduced variant " 3779 << expectedRankReducedSourceType << " (got: " << getSourceType() 3780 << ")"; 3781 } 3782 3783 return success(); 3784 } 3785 3786 //===----------------------------------------------------------------------===// 3787 // SplatOp 3788 //===----------------------------------------------------------------------===// 3789 3790 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, 3791 Type aggregateType, ValueRange dynamicSizes) { 3792 build(builder, result, aggregateType, element, dynamicSizes); 3793 } 3794 3795 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, 3796 ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) { 3797 auto aggregateType = RankedTensorType::get(staticShape, element.getType()); 3798 build(builder, result, aggregateType, element, dynamicSizes); 3799 } 3800 3801 void SplatOp::build(OpBuilder &builder, OperationState &result, Value element, 3802 ArrayRef<OpFoldResult> sizes) { 3803 SmallVector<int64_t> staticShape; 3804 SmallVector<Value> dynamicSizes; 3805 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape); 3806 build(builder, result, element, staticShape, dynamicSizes); 3807 } 3808 3809 void SplatOp::getAsmResultNames( 3810 function_ref<void(Value, StringRef)> setNameFn) { 3811 setNameFn(getResult(), "splat"); 3812 } 3813 3814 LogicalResult SplatOp::verify() { 3815 if (getType().getNumDynamicDims() != getDynamicSizes().size()) 3816 return emitOpError("incorrect number of dynamic sizes, has ") 3817 << getDynamicSizes().size() << ", expected " 3818 << getType().getNumDynamicDims(); 3819 return success(); 3820 } 3821 3822 LogicalResult 3823 SplatOp::reifyResultShapes(OpBuilder &builder, 3824 ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 3825 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank())); 3826 unsigned ctr = 0; 3827 for (int64_t i = 0; i < getType().getRank(); ++i) { 3828 if (getType().isDynamicDim(i)) { 3829 reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++]; 3830 } else { 3831 reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i)); 3832 } 3833 } 3834 return success(); 3835 } 3836 3837 OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { 3838 auto constOperand = adaptor.getInput(); 3839 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand)) 3840 return {}; 3841 3842 // Do not fold if the splat is not statically shaped 3843 if (!getType().hasStaticShape()) 3844 return {}; 3845 3846 // SplatElementsAttr::get treats single value for second arg as being a 3847 // splat. 
3848   return SplatElementsAttr::get(getType(), {constOperand});
3849 }
3850 
3851 //===----------------------------------------------------------------------===//
3852 // PackOp/UnPackOp Common
3853 //===----------------------------------------------------------------------===//
3854 
3855 template <typename OpTy>
3856 static LogicalResult
3857 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
3858                       ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3859   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3860                 "applies to only pack or unpack operations");
3861   int64_t destRank = op.getDestRank();
3862   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
3863   reifiedReturnShapes[0] =
3864       tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
3865   return success();
3866 }
3867 
3868 template <typename OpTy>
3869 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
3870   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3871                 "applies to only pack or unpack operations");
3872   DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
3873   ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
3874   SmallVector<OpFoldResult> tiles = op.getMixedTiles();
3875   assert(tiles.size() == dimsToTile.size() &&
3876          "tiles must match indices of dimension to block");
3877   // Bind the dimension `i` with the tile factor.
3878   for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3879     dimAndTileMapping[dimsToTile[i]] = tiles[i];
3880   return dimAndTileMapping;
3881 }
3882 
3883 template <typename OpTy>
3884 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
3885   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3886                 "applies to only pack or unpack operations");
3887   Builder builder(op);
3888   SmallVector<OpFoldResult> mixedInnerTiles;
3889   unsigned dynamicValIndex = 0;
3890   for (int64_t staticTile : op.getStaticInnerTiles()) {
3891     if (!ShapedType::isDynamic(staticTile))
3892       mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
3893     else
3894       mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3895   }
3896   return mixedInnerTiles;
3897 }
3898 
3899 template <typename OpTy>
3900 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
3901   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3902                 "applies to only pack or unpack operations");
3903   SmallVector<Value> dynamicTiles;
3904   SmallVector<int64_t> staticTiles;
3905   dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
3906   return staticTiles;
3907 }
3908 
3909 /// Returns true if `dimsPos` is invalid. It is invalid when:
3910 /// a) It contains duplicates.
3911 /// b) At least one dimension is out of bounds, i.e., outside the range [0, `rank`).
3912 /// c) The number of elements in `dimsPos` is greater than `rank`.
3913 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
3914                                              size_t rank) {
3915   size_t dimsPosSize = dimsPos.size();
3916   if (dimsPosSize > rank)
3917     return true;
3918   DenseSet<int64_t> uniqued;
3919   for (int64_t dim : dimsPos)
3920     uniqued.insert(dim);
3921   if (dimsPosSize != uniqued.size())
3922     return true;
3923   return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3924     return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
3925   });
3926 }
3927 
3928 /// Returns true if each dimension of `sourceShape` is less than or equal to
3929 /// the corresponding dimension of `limitShape` (dynamic dims are in-bound).
3930 static bool areAllInBound(ArrayRef<int64_t> sourceShape, 3931 ArrayRef<int64_t> limitShape) { 3932 assert( 3933 sourceShape.size() == limitShape.size() && 3934 "expected source shape rank, and limit of the shape to have same rank"); 3935 return llvm::all_of( 3936 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { 3937 int64_t sourceExtent = std::get<0>(it); 3938 int64_t limit = std::get<1>(it); 3939 return ShapedType::isDynamic(sourceExtent) || 3940 ShapedType::isDynamic(limit) || sourceExtent <= limit; 3941 }); 3942 } 3943 3944 template <typename OpTy> 3945 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { 3946 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, 3947 "applies to only pack or unpack operations"); 3948 Operation *op = packOrUnPack.getOperation(); 3949 3950 // Return true if we have a zero-value tile. 3951 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) { 3952 return llvm::any_of( 3953 tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); 3954 }; 3955 3956 // Verify tiles. Do not allow zero tiles. 3957 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles(); 3958 if (hasZeros(mixedTiles)) 3959 return op->emitError("invalid zero tile factor"); 3960 3961 // Verify inner_dims_pos and outer_dims_perm. 3962 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value) 3963 ? packOrUnPack.getSourceType() 3964 : packOrUnPack.getDestType(); 3965 size_t unpackedRank = unpackedType.getRank(); 3966 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos(); 3967 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm(); 3968 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank)) 3969 return op->emitError("invalid inner_dims_pos vector"); 3970 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank)) 3971 return op->emitError("invalid outer_dims_perm vector"); 3972 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank) 3973 return op->emitError("outer_dims_perm must be a permutation or empty"); 3974 3975 // Tiling factors must be less than or equal to the input rank for pack (or 3976 // output rank for unpack), and must match the number of `inner_dims_pos`. 3977 if (mixedTiles.size() > unpackedRank) { 3978 return op->emitError("tiling factors must be less than or equal to the " 3979 "input rank for pack or output rank for unpack"); 3980 } 3981 if (mixedTiles.size() != innerDimsPos.size()) { 3982 return op->emitError( 3983 "tiling factors must equal the number of dimensions to tile"); 3984 } 3985 3986 ShapedType packedType = (std::is_same<OpTy, PackOp>::value) 3987 ? packOrUnPack.getDestType() 3988 : packOrUnPack.getSourceType(); 3989 size_t packedRank = packedType.getRank(); 3990 // Require output rank to match input rank + number of blocking factors. 3991 size_t expectedPackedRank = unpackedRank + mixedTiles.size(); 3992 if (expectedPackedRank != packedRank) { 3993 return op->emitError( 3994 "packed rank != (unpacked rank + num tiling factors), got ") 3995 << packedRank << " != " << expectedPackedRank; 3996 } 3997 3998 // Verify result shape is greater than the minimum expected 3999 // by the pack operation, and that the output shape 4000 // represents full tiles. 
4001 RankedTensorType expectedPackedType = PackOp::inferPackedType( 4002 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); 4003 if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { 4004 return op->emitError("the shape of output is not large enough to hold the " 4005 "packed data. Expected at least ") 4006 << expectedPackedType << ", got " << packedType; 4007 } 4008 if (!llvm::all_of( 4009 llvm::zip(packedType.getShape().take_back(mixedTiles.size()), 4010 mixedTiles), 4011 [](std::tuple<int64_t, OpFoldResult> it) { 4012 int64_t shape = std::get<0>(it); 4013 if (Attribute attr = 4014 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) { 4015 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr); 4016 int64_t staticTileSize = intAttr.getValue().getSExtValue(); 4017 return shape == staticTileSize; 4018 } 4019 return ShapedType::isDynamic(shape); 4020 })) { 4021 return op->emitError("mismatch in inner tile sizes specified and shaped of " 4022 "tiled dimension in the packed type"); 4023 } 4024 return success(); 4025 } 4026 4027 namespace { 4028 /// Subset of PackOp/UnPackOp fields used to compute the result of applying 4029 /// various permutations to the op. 4030 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse 4031 // these. These may or may not become true foldings / canonicalizations 4032 // depending on how aggressive we want to be in automatically folding 4033 // transposes. 4034 struct PackOrUnPackTransposeResult { 4035 SmallVector<int64_t> innerDimsPos; 4036 SmallVector<OpFoldResult> innerTiles; 4037 SmallVector<int64_t> outerDimsPerm; 4038 }; 4039 } // namespace 4040 4041 template <typename OpTy> 4042 static PackOrUnPackTransposeResult 4043 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, 4044 ArrayRef<int64_t> innerPermutation, 4045 ArrayRef<int64_t> outerPermutation) { 4046 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, 4047 "applies to only pack or unpack operations"); 4048 assert((!innerPermutation.empty() || !outerPermutation.empty()) && 4049 "some permutation must be non-empty"); 4050 PackOrUnPackTransposeResult metadata; 4051 metadata.innerDimsPos = 4052 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos()); 4053 metadata.innerTiles = 4054 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles()); 4055 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value 4056 ? packOrUnPackOp.getSourceRank() 4057 : packOrUnPackOp.getDestRank(); 4058 metadata.outerDimsPerm = 4059 packOrUnPackOp.getOuterDimsPerm().empty() 4060 ? 
llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims)) 4061 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm()); 4062 if (!innerPermutation.empty()) { 4063 assert(innerPermutation.size() == metadata.innerDimsPos.size() && 4064 isPermutationVector(innerPermutation) && 4065 "invalid inner permutation"); 4066 applyPermutationToVector(metadata.innerDimsPos, innerPermutation); 4067 applyPermutationToVector(metadata.innerTiles, innerPermutation); 4068 } 4069 if (!outerPermutation.empty()) { 4070 assert(outerPermutation.size() == metadata.outerDimsPerm.size() && 4071 isPermutationVector(outerPermutation) && 4072 "invalid outer permutation"); 4073 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation); 4074 } 4075 return metadata; 4076 } 4077 4078 //===----------------------------------------------------------------------===// 4079 // PackOp 4080 //===----------------------------------------------------------------------===// 4081 4082 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { 4083 setNameFn(getResult(), "pack"); 4084 } 4085 4086 void PackOp::build(OpBuilder &builder, OperationState &state, Value source, 4087 Value dest, ArrayRef<int64_t> innerDimsPos, 4088 ArrayRef<OpFoldResult> innerTiles, 4089 std::optional<Value> paddingValue, 4090 ArrayRef<int64_t> outerDimsPerm) { 4091 assert(innerDimsPos.size() == innerTiles.size() && 4092 "number of tile sizes specified must match the specified number of " 4093 "original dimensions to be tiled"); 4094 SmallVector<int64_t> staticTileSizes; 4095 SmallVector<Value> dynamicTileSizes; 4096 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes); 4097 build(builder, state, dest.getType(), source, dest, 4098 paddingValue ? *paddingValue : nullptr, 4099 outerDimsPerm.empty() ? 
nullptr 4100 : builder.getDenseI64ArrayAttr(outerDimsPerm), 4101 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes, 4102 builder.getDenseI64ArrayAttr(staticTileSizes)); 4103 } 4104 4105 LogicalResult 4106 PackOp::reifyResultShapes(OpBuilder &builder, 4107 ReifiedRankedShapedTypeDims &reifiedReturnShapes) { 4108 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes); 4109 } 4110 4111 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() { 4112 return getDimAndTileMappingImpl(*this); 4113 } 4114 4115 SmallVector<OpFoldResult> PackOp::getMixedTiles() { 4116 return getMixedTilesImpl(*this); 4117 } 4118 4119 SmallVector<int64_t> PackOp::getStaticTiles() { 4120 return getStaticTilesImpl(*this); 4121 } 4122 4123 ArrayRef<int64_t> PackOp::getAllOuterDims() { 4124 ShapedType inputType = getSourceType(); 4125 int64_t inputRank = inputType.getRank(); 4126 return getDestType().getShape().take_front(inputRank); 4127 } 4128 4129 SmallVector<int64_t> PackOp::getTiledOuterDims() { 4130 auto innerDimsPos = getInnerDimsPos(); 4131 auto packedShape = getDestType().getShape(); 4132 SmallVector<int64_t> res; 4133 4134 for (auto index : innerDimsPos) 4135 res.push_back(packedShape[index]); 4136 4137 return res; 4138 } 4139 4140 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape, 4141 ArrayRef<int64_t> innerDimsPos, 4142 ArrayRef<int64_t> outputShape, 4143 ArrayRef<int64_t> outerDimsPerm, 4144 ArrayRef<OpFoldResult> innerTiles) { 4145 SmallVector<int64_t> outputTileSizes( 4146 outputShape.take_front(inputShape.size())); 4147 if (!outerDimsPerm.empty()) { 4148 assert(outerDimsPerm.size() == outputTileSizes.size() && 4149 "expected output and outer_dims_perm to have same size"); 4150 applyPermutationToVector(outputTileSizes, 4151 invertPermutationVector(outerDimsPerm)); 4152 } 4153 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) { 4154 if (ShapedType::isDynamic(inputShape[pos])) 4155 continue; 4156 std::optional<int64_t> constantTile = getConstantIntValue(tileSize); 4157 4158 if (!constantTile) { 4159 if (!ShapedType::isDynamic(outputTileSizes[pos]) && 4160 (inputShape[pos] % outputTileSizes[pos] != 0)) 4161 return true; 4162 } else if (inputShape[pos] % (*constantTile) != 0) { 4163 return true; 4164 } 4165 } 4166 return false; 4167 } 4168 4169 LogicalResult PackOp::verify() { 4170 if (failed(commonVerifierPackAndUnPackOp(*this))) 4171 return failure(); 4172 4173 // Verify padding value, and bail out if the tile does not divide the 4174 // dimension fully. In the case of dynamic tile factors or dimensions, having 4175 // a partial tile is undefined behavior. 4176 auto paddingValue = getPaddingValue(); 4177 if (paddingValue && 4178 paddingValue.getType() != getSourceType().getElementType()) { 4179 return emitOpError("expected padding_value has ") 4180 << getSourceType().getElementType() 4181 << " but got: " << paddingValue.getType(); 4182 } 4183 4184 if (!paddingValue && 4185 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(), 4186 getDestType().getShape(), getOuterDimsPerm(), 4187 getMixedTiles())) { 4188 return emitOpError( 4189 "invalid tile factor or output size provided. Only full tiles are " 4190 "supported when padding_value is not set"); 4191 } 4192 return success(); 4193 } 4194 4195 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all 4196 /// Value's to kDynamic, even if they are arith.constant values. 
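/// For example, the mixed sizes [%v, 8, %c4] become [kDynamic, 8, kDynamic]
/// even when %c4 is produced by an arith.constant (SSA names illustrative).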
4197 static SmallVector<int64_t> 4198 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) { 4199 SmallVector<int64_t> result; 4200 for (auto o : ofrs) { 4201 // Have to do this first, as getConstantIntValue special-cases constants. 4202 if (llvm::dyn_cast_if_present<Value>(o)) 4203 result.push_back(ShapedType::kDynamic); 4204 else 4205 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic)); 4206 } 4207 return result; 4208 } 4209 4210 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of 4211 /// the packed type. Having a shared helper helps implement these two methods in 4212 /// a way that ensures that they agree on which dimensions are dynamic. 4213 static SmallVector<int64_t> getPackOpResultTypeShape( 4214 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes, 4215 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) { 4216 SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape); 4217 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) { 4218 if (ShapedType::isDynamic(resultShape[tiledDim.value()])) 4219 continue; 4220 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) { 4221 resultShape[tiledDim.value()] = ShapedType::kDynamic; 4222 continue; 4223 } 4224 resultShape[tiledDim.value()] = divideCeilSigned( 4225 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]); 4226 } 4227 4228 // Swap tile loops if outer_dims_perm is available. 4229 if (!outerDimsPerm.empty()) 4230 applyPermutationToVector(resultShape, outerDimsPerm); 4231 4232 // Append the inner tile dimensions. 4233 resultShape.append(innerTileSizes.begin(), innerTileSizes.end()); 4234 return resultShape; 4235 } 4236 4237 SmallVector<OpFoldResult> PackOp::getResultShape( 4238 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims, 4239 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos, 4240 ArrayRef<int64_t> outerDimsPerm) { 4241 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims); 4242 4243 AffineExpr s0, s1; 4244 bindSymbols(builder.getContext(), s0, s1); 4245 AffineExpr ceilDivExpr = s0.ceilDiv(s1); 4246 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) { 4247 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply( 4248 builder, loc, ceilDivExpr, 4249 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]}); 4250 } 4251 if (!outerDimsPerm.empty()) 4252 applyPermutationToVector(resultDims, outerDimsPerm); 4253 resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); 4254 4255 SmallVector<int64_t> resultTypeShape = 4256 getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims), 4257 asShapeWithAnyValueAsDynamic(innerTileSizes), 4258 innerDimsPos, outerDimsPerm); 4259 4260 // Fix-up `resultDims` to ensure that they are Value's if and only if the 4261 // result type shape says it's a dynamic dim. This is needed as callers may 4262 // use dispatchIndexOpFoldResults on the result, and rely on exact number of 4263 // dynamic dims returned by that. 4264 for (unsigned i = 0; i < resultDims.size(); ++i) { 4265 if (!ShapedType::isDynamic(resultTypeShape[i])) 4266 continue; 4267 resultDims[i] = 4268 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]); 4269 } 4270 4271 return resultDims; 4272 } 4273 4274 /// Get the expected packed type based on source type, tile factors, position of 4275 /// the inner tiles and permutation of the outer tiled loop. 
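/// For example (illustrative): a source of type tensor<128x256xf32> packed with
/// inner_dims_pos = [0, 1] and inner tiles [32, 16] (no outer permutation)
/// yields tensor<4x16x32x16xf32>: the tiled outer dims become ceil(128/32) = 4
/// and ceil(256/16) = 16, and the tile sizes are appended as the innermost dims.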
4276 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType, 4277 ArrayRef<int64_t> innerTileSizes, 4278 ArrayRef<int64_t> innerDimsPos, 4279 ArrayRef<int64_t> outerDimsPerm) { 4280 SmallVector<int64_t> resultShape = getPackOpResultTypeShape( 4281 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); 4282 return RankedTensorType::get(resultShape, sourceType.getElementType()); 4283 } 4284 4285 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, 4286 ArrayRef<OpFoldResult> innerTileSizes, 4287 ArrayRef<int64_t> innerDimsPos, 4288 ArrayRef<int64_t> outerDimsPerm) { 4289 AffineExpr dim0, dim1; 4290 bindDims(b.getContext(), dim0, dim1); 4291 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult { 4292 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), 4293 {v1, v2}); 4294 }; 4295 4296 SmallVector<OpFoldResult> mixedSizes; 4297 for (auto [index, value] : llvm::enumerate( 4298 llvm::cast<RankedTensorType>(source.getType()).getShape())) { 4299 if (ShapedType::isDynamic(value)) 4300 mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult()); 4301 else 4302 mixedSizes.push_back(b.getIndexAttr(value)); 4303 } 4304 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) { 4305 int64_t dimPos = std::get<0>(it); 4306 OpFoldResult tileSize = std::get<1>(it); 4307 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize); 4308 } 4309 if (!outerDimsPerm.empty()) 4310 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm); 4311 4312 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end()); 4313 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType(); 4314 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType); 4315 } 4316 4317 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, 4318 ArrayRef<int64_t> innerPermutation, 4319 ArrayRef<int64_t> outerPermutation) { 4320 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp( 4321 *this, innerPermutation, outerPermutation); 4322 Value transposedDest = 4323 createDestinationTensor(b, loc, getSource(), metadata.innerTiles, 4324 metadata.innerDimsPos, metadata.outerDimsPerm); 4325 return b.create<PackOp>(loc, getSource(), transposedDest, 4326 metadata.innerDimsPos, metadata.innerTiles, 4327 getPaddingValue(), metadata.outerDimsPerm); 4328 } 4329 4330 /// Returns true if the tiles and the tiled dims are constant. 4331 template <typename OpTy> 4332 bool areTilesAndTiledDimsAllConstant(OpTy op) { 4333 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, 4334 "applies to only pack or unpack operations"); 4335 ShapedType packedType = (std::is_same<OpTy, PackOp>::value) 4336 ? op.getDestType() 4337 : op.getSourceType(); 4338 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles(); 4339 for (auto [dimDest, tile] : llvm::zip( 4340 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) { 4341 std::optional<int64_t> constTileSize = getConstantIntValue(tile); 4342 if (!constTileSize || ShapedType::isDynamic(dimDest)) 4343 return false; 4344 } 4345 return true; 4346 } 4347 4348 Speculation::Speculatability PackOp::getSpeculatability() { 4349 if (getPaddingValue()) 4350 return Speculation::Speculatable; 4351 4352 // The verifier rejects already operations if we can statically prove that the 4353 // sizes of the tiles do not divide perfectly the dimension; thus, check only 4354 // to have constant tiles and tiled inner dimensions. 
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
// dimensions for pack and unpack.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // The outer dims permutation is optional. To compare an unbalanced
  // pack-unpack pair, treat a missing permutation as the identity permutation.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

// Return true if the pack and unpack have the same tiles, i.e. the same SSA
// values or the same integer constants.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(packOp.getInnerDimsPos().begin(),
                   packOp.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) to x.
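  // For example (illustrative types; the ops must agree on tiles, dims and
  // permutations, and the pack must have no padding value):
  //
  //   %1 = tensor.unpack %0 inner_dims_pos = [0] inner_tiles = [8]
  //       into %e0 : tensor<4x8xf32> -> tensor<32xf32>
  //   %2 = tensor.pack %1 inner_dims_pos = [0] inner_tiles = [8]
  //       into %e1 : tensor<32xf32> -> tensor<4x8xf32>
  //
  // folds so that all uses of %2 are replaced by %0.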
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}

template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't appear in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // At this point we know that we are taking numPackedDims outer dimensions
  // and pushing them all the way down to become the innermost dimensions.
  // What is left in the outermost positions is, in this order:
  // - the factors of the packed dimensions, then
  // - the untouched dimensions.
  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
  // only if all the dimensions that bubble outward are ones. Therefore check
  // that all dimensions except the numPackedDims innermost ones are ones.
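  // For example (illustrative): with inner_dims_pos = [0, 1], a packed type of
  // tensor<1x1x64x16xf32> passes this check, whereas tensor<2x1x64x16xf32>
  // does not, because a non-unit outer dimension means data is actually
  // reordered rather than merely padded/reshaped.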
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//

void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;

  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);

  return res;
}

LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}

Speculation::Speculatability UnPackOp::getSpeculatability() {
  // See PackOp::getSpeculatability.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
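/// For example (illustrative): with `inner_dims_pos = [1]`, a source of type
/// `tensor<16x8x32xf32>` and a dest of type `tensor<?x?xf32>`, the untiled
/// outer dimension is propagated and the inferred dest shape becomes
/// `[16, ?]`; tiled dimensions are left untouched.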
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
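  // For example (illustrative): with inner_dims_pos = [1], a source of type
  // tensor<?x8x32xf32> and a dest of type tensor<16x256xf32>, the untiled
  // outer dimension lets the source be refined to tensor<16x8x32xf32>, so a
  // tensor.cast is inserted on the source operand and the unpack is rebuilt
  // with the refined operand types.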
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    Value newOp = rewriter.create<tensor::UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}

bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
  // 2. Exclude DPS ops that are also LoopLike from this interface, as they
  //    might need special handling of attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // Fail if no operand comes from a tensor::CastOp that can be folded.
  bool hasTensorCastOperand =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
        if (llvm::isa<BlockArgument>(opOperand.get()))
          return false;
        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
        return castOp && canFoldIntoConsumerOp(castOp);
      });

  return hasTensorCastOperand;
}

static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
                                         SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}

// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only when:
//  * a dim from newPackedTy is static, and
//  * the corresponding size from mixedTiles is still dynamic.
// Otherwise, the original tile size is preserved.
// Note: the packed-type dims and the mixed tile sizes should always match!
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed-size
    // (provided the original value is dynamic).
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
    }
  }

  return newMixedTileSizes;
}

/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    if (!foldTensorCastPrecondition(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes =
        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};

/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
/// `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
/// ```
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp op,
                                PatternRewriter &rewriter) const override {
    if (!foldTensorCastPrecondition(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    UnPackOp newOp = rewriter.create<UnPackOp>(
        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
///
/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
/// can add the pattern to their canonicalizers.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {

    // Reject tensor::PackOp and tensor::UnPackOp - they have dedicated
    // patterns above.
    if (!foldTensorCastPrecondition(op) ||
        isa<tensor::PackOp, tensor::UnPackOp>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

    // Clone op.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// TensorDialect
//===----------------------------------------------------------------------===//

void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastPackOp>(getContext());
  results.add<FoldTensorCastUnPackOp>(getContext());
  results.add<FoldTensorCastProducerOp>(getContext());
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"