//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

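/// Returns the iteration domain of a pack/unpack op: one zero-based,
/// unit-stride range per outer dimension, sized by the reified result shape
/// (the first `getSourceRank()` result dims for tensor.pack, all result dims
/// for tensor.unpack).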
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied to the interchanged dimensions. We have to undo
    // the interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::createDimValues(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i
        // and tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

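      // For illustration (hypothetical values, not derived from the code
      // above): with an inner tile size of 16, an outer offset of 2, and an
      // outer size of 3, the slice of the source starts at index
      // 2 * 16 = 32 and has size 3 * 16 = 48.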

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }
};

/// Per-dimension information needed for tiling an unpack op; computed by
/// `getUnpackTileDimInfo` below.
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op along `tileDim`
/// with the given `tileOffset` and `tileSize`. For more details, see the
/// comment on `getTiledImplementation` below.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();

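  // For intuition in the comments below, picture unpacking tensor<4x8xf32>
  // into tensor<32xf32>, i.e. Nn_to_N with N = 32 and inner tile size n = 8
  // (hypothetical shapes, matching the example used for
  // UnPackOpTiling::getTiledImplementation further down).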
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB,
      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (succeeded(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op, so we need the ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;

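  // For illustration with the Nn_to_N picture above (N = 32, n = 8): a tile
  // with tileOffset = 15 and tileSize = 15 (i.e. result[15..29]) yields
  // firstCoord = (1, 7) and lastCoord = divmod(29, 8) = (3, 5), hence
  // sourceOffset = 1, sourceSize = 3, resultOffset = 7, and an expanded
  // destination size of 3 * 8 = 24 below.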
  // Do not create affine ops for the expanded size because the affine
  // expression is complicated enough to trigger issues in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases when tiling an unpack op. If the tiling size is
  /// aligned to the inner tile size, the corresponding tiles of the source
  /// are all complete. Otherwise, there are incomplete tiles, and the slice
  /// of the source must be expanded to cover complete tiles. The tiled
  /// unpack op then unpacks more data from the source than requested, so an
  /// extract_slice op is needed to shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows, i.e., the input coordinates start at (1, 0) and end
  /// at (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. A subsequent tensor.extract_slice op
  /// then extracts the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples
    // of the inner tile size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

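    // Rough shape of the IR built below for the imperfect tiling case
    // (illustrative sketch only, not verbatim IR):
    //   %src  = tensor.extract_slice %source ...  // expanded source tile
    //   %dest = tensor.empty(...)                 // expanded destination
    //   %t    = tensor.unpack %src ... into %dest
    //   %res  = tensor.extract_slice %t ...       // shift and truncate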

    // The tiling is applied to the destination dimensions. We have to apply
    // the interchange to the source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace

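/// Bubbles an extract_slice (described by `offsets` and `sizes`) up through
/// a tensor.pad: the sliced pad is rewritten as pad(extract_slice(source)),
/// or as a tensor.generate (possibly guarded by scf.if) when the source is
/// known, statically or at runtime, to not be read at all.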
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only a constant padding value is supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize =
        tensor::createDimValue(b, loc, padOp.getSource(), dim).value();

    // The new amount of low padding is `low - offset`, except when none of
    // the low padding is read, in which case the new amount of low padding
    // is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

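    // For illustration (hypothetical values, not derived from the code
    // above): with low = 2, srcSize = 10, offset = 1, and length = 4, the
    // formulas above give newLow = 1, newOffset = 0, and newLength = 3; the
    // tile covers one padded element followed by src[0..2].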

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // created.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride is supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert a cast to ensure that the types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would have a
  // dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where the result
  // shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy the region to the new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Return the new PadOp; the caller casts the result where needed.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

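  // Rough shape of the guarded IR built below (illustrative sketch only):
  //   %r = scf.if %dynHasZeroLenCond -> (tensor<...>) {
  //     scf.yield <generate result>
  //   } else {
  //     scf.yield <pad-of-slice result>
  //   }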

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}