//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
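    // E.g. (illustrative), for a pad producing tensor<10x?xf32>, this yields
    // two ranges of the form {offset = 0, size = <reified result dim>,
    // stride = 1}.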
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
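    // E.g. (illustrative), for a pack tensor<?x?xf32> -> tensor<?x?x8x32xf32>,
    // only the two outer dimensions form the iteration domain; the 8x32 inner
    // tile is produced inside the tiled op itself.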
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
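      // E.g. (illustrative), with an inner tile of 8, a source dim of size 30,
      // an outer offset of 3, and an outer size of 1, the slice starts at 24
      // with size 8, which is clamped below to min(8, 30 - 24) = 6 when the
      // pack has a padding value.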
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
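    // E.g. (illustrative), with an inner tile size of 32, a request for
    // offset 0 and size 32 on an inner dimension is accepted, whereas a
    // partial tile such as offset 8 or size 16 makes the checks below fail.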
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context,
  /// `resultOffsets` and `resultSizes` cover only the outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer only supports the perfect
        // tiling scenario because, even without padding semantics, `packOp`
        // may still yield incomplete tiles. E.g., tensor<30xf32> ->
        // tensor<5x6xf32>, where the `tileSize` from the operand of `packOp`
        // is 5, which does not evenly divide the `innerTile` (=6) of `packOp`.
        // As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        //    inserted into two rows of different lengths: the end of the first
        //    row at (0,5), and the second row at (1,0)~(1,3). It is hard to
        //    coordinate them, so the constraint below bypasses such cases for
        //    now. In other words, tiling with a consumer is only supported
        //    when the tile size of the producer is a multiple of the inner
        //    tile size for the packed dimensions.
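        // E.g. (illustrative), with innerTile = 6, a constant upper bound of
        // 12 or 18 for the operand tile size passes the check below, while 5
        // does not.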
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the needed information for tiling the unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment of
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
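  // In that case the mapping is the identity: the source slice uses the
  // requested tile offset and size directly, and no shift of the result is
  // needed.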
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine
  // expression would be too complicated, which would trigger an issue in
  // affine op simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows, i.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
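    // E.g. (illustrative), with outer_dims_perm = [1, 0], the indices and
    // sizes computed above in destination order are permuted below so that
    // they match the order of the source's outer dimensions.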
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
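    // E.g. (illustrative), if outer_dims_perm = [2, 0, 1], its inverse
    // [1, 2, 0] is applied below so that the offsets/sizes line up with the
    // unpermuted result dimensions.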
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch the offset/size for creating the slice of the dest operand of the
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create the slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create the slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create the tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // srcSize - newOffset represents how much length we have available,
    // and length - newLow represents how much length we want at most.
    // Note that there are many ways to order this indexing math to compute
    // newLength, but we want to make sure that the final affine.min ops in the
    // sequence are bounding the index to as small a value as possible. If
    // ValueBoundsOpInterface is used, this calculation will get upper bounds
    // from the affine.min ops, so we want to use the smallest known value to
    // set the bound at the end of the computation sequence. In this case, the
    // index will be upper bounded by length - newLow.
    OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow));
    // Optimization: If low = 0, then newLow = 0, and then newLength >= 0
    // assuming length >= 0.
    if (hasLowPad)
      newLength = max(newLength, zero);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no SubTensorOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert a cast to ensure that the types match. (May be folded away.)
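  // E.g. (illustrative), the generated pad/generate op may yield
  // tensor<?x?xf32> while the expected result type is tensor<4x?xf32>; a
  // tensor.cast reconciles the two.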
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy the region to the new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast the result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
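// Illustrative usage (not part of this file): clients typically register these
// external models on a DialectRegistry before creating the MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);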