//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    Location loc = op->getLoc();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }
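
  // For intuition, an illustrative example of the iteration domain computed
  // above (shapes are hypothetical): for
  //   %0 = tensor.pad %src low[3, 4] high[5, 3]
  //       : tensor<10x20xf32> to tensor<18x27xf32>
  // the loop ranges are {offset = 0, size = 18, stride = 1} and
  // {offset = 0, size = 27, stride = 1}, i.e. the tiling loops span the
  // padded result shape with unit stride.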

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  Location loc = op.getLoc();
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
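
// As a hypothetical example: given permutation = [1, 0],
// offsets = [%i, %j], and sizes = [4, 8], applyPermToRange rewrites the
// vectors in place to offsets = [%j, %i] and sizes = [8, 4].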

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::createDimValues(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of size_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }
};
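
// An illustrative example of the tiled pack above (shapes and values are
// hypothetical): for
//   %0 = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 8]
//       into %dest : tensor<128x128xf32> -> tensor<4x16x32x8xf32>
// a tile of the iteration domain at offsets [%i, %j] with sizes [2, 4] reads
// the source slice at [32 * %i, 8 * %j] with sizes [64, 32] (clamped to the
// source bounds when a padding value is present) and computes the dest slice
// at [%i, %j, 0, 0] with sizes [2, 4, 32, 8].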

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op along `tileDim`
/// with the given `tileOffset` and `tileSize`. For more details, see the
/// comment on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB,
      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op, so we need a ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine op
  // would be too complicated, which could trigger an issue in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}
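
// A worked example of the unaligned case above (values are hypothetical):
// with innerTileSize = 8, tileOffset = 15, and tileSize = 15:
//   firstCoord = 15 divmod 8            = (quotient 1, remainder 7)
//   lastCoord  = (15 + 15 - 1) divmod 8 = (quotient 3, remainder 5)
// so sourceOffset = 1, sourceSize = (3 - 1) + 1 = 3, resultOffset = 7, and
// destExpandedSize = 3 * 8 = 24.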

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases when tiling unpack ops. If the tiling size is
  /// aligned to the inner tile size, the corresponding tiles of the source
  /// are all complete. Otherwise, there are incomplete tiles, and we need to
  /// expand the slice of the source to get complete tiles. The tiled unpack
  /// op unpacks more data from the source, so we need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows. I.e., the input coordinates would start with (1, 0)
  /// and end with (3, 7). In this context, the tiled unpack produces 3 * n
  /// elements because there are 3 rows in total. Followed by a
  /// tensor.extract_slice op, we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples
    // of the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }
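
    // Continuing the Nn_to_N example from the comment above (values are
    // illustrative): the expanded source slice covers rows 1..3, i.e.
    // tensor<3x8xf32>, the tiled unpack produces tensor<24xf32>, and the
    // trailing extract_slice takes the 15 requested elements starting at
    // offset 7.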

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding values are supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);
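
  // Note that these helpers return an OpFoldResult: when the composed
  // expression folds to a compile-time constant, the result is an attribute
  // and no op is created; otherwise an affine.apply/affine.min/affine.max op
  // materializes the value. E.g. (hypothetically),
  // add(zero, b.getIndexAttr(2)) folds to the index attribute 2.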

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition is
  // true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize =
        tensor::createDimValue(b, loc, padOp.getSource(), dim).value();

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone. In
    // that case, the end position of the read should be the end of the source
    // tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride is supported.
    newStrides.push_back(b.getIndexAttr(1));
  }
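
  // A worked example for one dimension (values are hypothetical): with
  // low = 3, high = 5, srcSize = 10 (padded size 18), a tile at offset = 0
  // with length = 6 yields
  //   newLow    = max(3 - 0, 0)              = 3
  //   newOffset = min(max(0 - 3, 0), 10)     = 0
  //   endLoc    = min(max(0 - 3 + 6, 0), 10) = 3, so newLength = 3
  //   newHigh   = (6 - 3) - 3                = 0
  // i.e. the tile reads source elements [0, 3) and pads 3 elements in front.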

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert a cast to ensure that the types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would have a
  // dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where the result
  // shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
                                    /*nofold=*/padOp.getNofold());

    // Copy the region to the new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }
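
  // Continuing the worked example above (hypothetical values): a tile at
  // offset = 13 with length = 4 lies entirely in the high-padding zone (the
  // source occupies [3, 13) of the padded tensor), so newLength folds to
  // min(max(13 - 3 + 4, 0), 10) - 10 = 0 and the statically-known zero-length
  // path above emits a tensor.generate yielding the padding value.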

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
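
// A minimal sketch of how a client might register these external models
// before creating a context (assuming the standard MLIR registration flow):
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);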