//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVectorImpl<OpFoldResult> &resultOffsets,
                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                      : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVectorImpl<OpFoldResult> &resultOffsets,
                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
    // The iteration domain is over outer dimensions of packed layout. In this
    // context, the outer dimensions of `resultOffsets` are `offsets`. The
    // inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the needed information for tiling unpack op on `tileDim` with given
/// `tileOffset` and `tileSize`. For more details, see the comment of the
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create an affine op for the expanded size because the affine op is
  // too complicated, which would trigger an issue in affine op simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We need to expand the
  /// slice of the source to get complete tiles. The tiled unpack op unpacks
  /// more data from the source, so we need an extract_slice op to shift and
  /// truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows, i.e., the input coordinates start at (1, 0) and end at (3, 7).
  /// In this context, the tiled unpack produces (3 * n) elements because there
  /// are 3 rows in total. Followed by a tensor.extract_slice op, we can get
  /// the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVectorImpl<OpFoldResult> &resultOffsets,
                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
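// Usage sketch (illustration only, not part of the upstream file): a client
// that wants the external models above attached would typically register them
// on its DialectRegistry before constructing the MLIRContext, for example:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);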