//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
93 "applies to only pack or unpack operations"); 94 OpBuilder::InsertionGuard g(builder); 95 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() 96 : op.getDestRank(); 97 OpFoldResult zero = builder.getIndexAttr(0); 98 OpFoldResult one = builder.getIndexAttr(1); 99 ReifiedRankedShapedTypeDims resultShape; 100 (void)reifyResultShapes(builder, op, resultShape); 101 SmallVector<Range> loopBounds(rank); 102 for (auto dim : llvm::seq<int64_t>(0, rank)) { 103 loopBounds[dim].offset = zero; 104 loopBounds[dim].stride = one; 105 loopBounds[dim].size = resultShape[0][dim]; 106 } 107 return loopBounds; 108 } 109 110 static void applyPermToRange(SmallVector<OpFoldResult> &offsets, 111 SmallVector<OpFoldResult> &sizes, 112 ArrayRef<int64_t> permutation) { 113 if (permutation.empty()) 114 return; 115 applyPermutationToVector<OpFoldResult>(offsets, permutation); 116 applyPermutationToVector<OpFoldResult>(sizes, permutation); 117 } 118 119 struct PackOpTiling 120 : public TilingInterface::ExternalModel<PackOpTiling, PackOp> { 121 122 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 123 // Note that here we only consider untiled dimensions and outer tiled data 124 // dimensions, the inner tiled data dimensions are materialized when 125 // building the body of the operation. 126 auto packOp = cast<PackOp>(op); 127 SmallVector<utils::IteratorType> iteratorTypes( 128 packOp.getSourceRank(), utils::IteratorType::parallel); 129 return iteratorTypes; 130 } 131 132 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 133 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); 134 } 135 136 FailureOr<TilingResult> 137 getTiledImplementation(Operation *op, OpBuilder &b, 138 ArrayRef<OpFoldResult> offsets, 139 ArrayRef<OpFoldResult> sizes) const { 140 auto packOp = cast<PackOp>(op); 141 Location loc = packOp.getLoc(); 142 143 // The tiling is applied on interchanged dimensions. We have to undo the 144 // interchange to map sizes and offsets to the original input. 145 int64_t inputRank = packOp.getSourceRank(); 146 SmallVector<OpFoldResult> origOffsets(offsets); 147 SmallVector<OpFoldResult> origSizes(sizes); 148 applyPermToRange(origOffsets, origSizes, 149 invertPermutationVector(packOp.getOuterDimsPerm())); 150 151 DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 152 packOp.getDimAndTileMapping(); 153 SmallVector<OpFoldResult> srcDimValues = 154 tensor::getMixedSizes(b, loc, packOp.getSource()); 155 SmallVector<OpFoldResult> inputIndices, inputSizes; 156 for (auto dim : llvm::seq<int64_t>(0, inputRank)) { 157 using AV = affine::AffineValueExpr; 158 affine::AffineBuilder ab(b, loc); 159 AffineExpr dim0, dim1, sym; 160 bindDims(b.getContext(), dim0, dim1); 161 bindSymbols(b.getContext(), sym); 162 if (dimAndTileMapping.count(dim)) { 163 // If the data dimension is tiled, the i-th index is the product of 164 // offset_i and tile_i, and the i-th size is the product of sizes_i and 165 // tile_i. 166 auto avOffset = AV(dim0).bind(origOffsets[dim]); 167 auto avSize = AV(dim0).bind(origSizes[dim]); 168 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 169 inputIndices.push_back(ab.mul(avOffset, avTileSize)); 170 inputSizes.push_back(ab.mul(avSize, avTileSize)); 171 } else { 172 inputIndices.push_back(origOffsets[dim]); 173 inputSizes.push_back(origSizes[dim]); 174 } 175 176 // Limit the size of the input operand for incomplete tiles. 
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
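    // For example (illustrative): with inner_tiles = [16], the trailing
    // offset/size pair must be (0, 16), i.e. the full tile; a partial tile
    // such as (0, 8) makes the checks below bail out with failure.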
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer only expects the perfect
        // tiling scenario because, even without padding semantics, the
        // `packOp` may still yield incomplete tiles. E.g.,
        // tensor<30xf32> -> tensor<5x6xf32>, where the `tileSize` from the
        // operand of `packOp` is 5, which is not evenly divisible by the
        // `innerTile` (=6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        //    inserted into two rows with different lengths: first row (0,5)
        //    and second row (1,0)~(1,3). It is hard to coordinate them, so the
        //    constraint below bypasses this case for now. In other words, at
        //    this moment we can only support tiling with a consumer if the
        //    tile size for the producer is a multiple of the inner tile size
        //    for the packed dimensions.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment on
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
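  // In that case the tile maps 1:1 onto the packed source along this
  // dimension: same offset, same size, and no intra-tile offset (e.g., an
  // untiled batch dimension in a hypothetical `BMn_to_BM` layout).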
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine op is
  // too complicated, which would trigger an issue in affine op simplification.
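  // Worked example (illustrative numbers): unpacking tensor<4x8xf32> into
  // tensor<32xf32> with tileOffset = 5, tileSize = 15, and innerTileSize = 8
  // touches rows floorDiv(5, 8) = 0 through floorDiv(19, 8) = 2 of the packed
  // source, so sourceOffset = 0, sourceSize = 3, resultOffset = 5 % 8 = 5, and
  // destExpandedSize = 3 * 8 = 24.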
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we'll need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows, i.e., the input coordinates start at (1, 0) and end at (3, 7).
  /// In this context, the tiled unpack produces 3 * n elements because there
  /// are 3 rows in total. Followed by a tensor.extract_slice op, we can get
  /// the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
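    // E.g. (illustrative): with outer_dims_perm = [1, 0], the per-dimension
    // offsets/sizes computed above in destination order are swapped before
    // they are used to slice the packed source.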
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    SmallVector<Operation *> generatedSlices;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    generatedSlices.push_back(extractSlice);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        resultSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch the offset/size for creating the slice of the dest operand of the
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create the slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create the slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create the tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except when none of
    // the low padding is read; in that case, the new amount of low padding
    // is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert a cast to ensure that the types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy the region to the new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}