1 //===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" 10 #include "mlir/Dialect/Affine/IR/AffineOps.h" 11 #include "mlir/Dialect/Affine/Utils.h" 12 #include "mlir/Dialect/Arith/Utils/Utils.h" 13 #include "mlir/Dialect/Linalg/IR/Linalg.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/SCF/IR/SCF.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Tensor/Utils/Utils.h" 18 #include "mlir/Dialect/Utils/IndexingUtils.h" 19 #include "mlir/Interfaces/TilingInterface.h" 20 #include "mlir/Interfaces/ValueBoundsOpInterface.h" 21 22 using namespace mlir; 23 using namespace mlir::tensor; 24 25 namespace { 26 27 struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { 28 29 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 30 auto padOp = cast<PadOp>(op); 31 SmallVector<utils::IteratorType> iteratorTypes( 32 padOp.getResultType().getRank(), utils::IteratorType::parallel); 33 return iteratorTypes; 34 } 35 36 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 37 ReifiedRankedShapedTypeDims reifiedShapes; 38 (void)reifyResultShapes(b, op, reifiedShapes); 39 OpFoldResult zero = b.getIndexAttr(0); 40 OpFoldResult one = b.getIndexAttr(1); 41 // Initialize all the ranges to {zero, one, one}. All the `ub`s are 42 // overwritten. 43 SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one}); 44 for (const auto &ub : enumerate(reifiedShapes[0])) 45 loopRanges[ub.index()].size = ub.value(); 46 return loopRanges; 47 } 48 49 FailureOr<TilingResult> 50 getTiledImplementation(Operation *op, OpBuilder &b, 51 ArrayRef<OpFoldResult> offsets, 52 ArrayRef<OpFoldResult> sizes) const { 53 FailureOr<TilingResult> result = 54 tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes); 55 if (failed(result)) 56 return failure(); 57 return result.value(); 58 } 59 60 LogicalResult 61 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 62 ArrayRef<OpFoldResult> offsets, 63 ArrayRef<OpFoldResult> sizes, 64 SmallVector<OpFoldResult> &resultOffsets, 65 SmallVector<OpFoldResult> &resultSizes) const { 66 resultOffsets.assign(offsets.begin(), offsets.end()); 67 resultSizes.assign(sizes.begin(), sizes.end()); 68 return success(); 69 } 70 71 LogicalResult getIterationDomainTileFromResultTile( 72 Operation *op, OpBuilder &b, unsigned resultNumber, 73 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 74 SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 75 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 76 iterDomainOffsets.assign(offsets.begin(), offsets.end()); 77 iterDomainSizes.assign(sizes.begin(), sizes.end()); 78 return success(); 79 } 80 81 FailureOr<TilingResult> 82 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 83 ArrayRef<OpFoldResult> offsets, 84 ArrayRef<OpFoldResult> sizes) const { 85 return getTiledImplementation(op, b, offsets, sizes); 86 } 87 }; 88 89 template <typename OpTy> 90 static SmallVector<Range> getPackUnPackIterationDomain(OpTy op, 91 OpBuilder &builder) { 92 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, 93 "applies to only pack or unpack operations"); 94 OpBuilder::InsertionGuard g(builder); 95 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() 96 : op.getDestRank(); 97 OpFoldResult zero = builder.getIndexAttr(0); 98 OpFoldResult one = builder.getIndexAttr(1); 99 ReifiedRankedShapedTypeDims resultShape; 100 (void)reifyResultShapes(builder, op, resultShape); 101 SmallVector<Range> loopBounds(rank); 102 for (auto dim : llvm::seq<int64_t>(0, rank)) { 103 loopBounds[dim].offset = zero; 104 loopBounds[dim].stride = one; 105 loopBounds[dim].size = resultShape[0][dim]; 106 } 107 return loopBounds; 108 } 109 110 static void applyPermToRange(SmallVector<OpFoldResult> &offsets, 111 SmallVector<OpFoldResult> &sizes, 112 ArrayRef<int64_t> permutation) { 113 if (permutation.empty()) 114 return; 115 applyPermutationToVector<OpFoldResult>(offsets, permutation); 116 applyPermutationToVector<OpFoldResult>(sizes, permutation); 117 } 118 119 struct PackOpTiling 120 : public TilingInterface::ExternalModel<PackOpTiling, PackOp> { 121 122 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 123 // Note that here we only consider untiled dimensions and outer tiled data 124 // dimensions, the inner tiled data dimensions are materialized when 125 // building the body of the operation. 126 auto packOp = cast<PackOp>(op); 127 SmallVector<utils::IteratorType> iteratorTypes( 128 packOp.getSourceRank(), utils::IteratorType::parallel); 129 return iteratorTypes; 130 } 131 132 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 133 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); 134 } 135 136 FailureOr<TilingResult> 137 getTiledImplementation(Operation *op, OpBuilder &b, 138 ArrayRef<OpFoldResult> offsets, 139 ArrayRef<OpFoldResult> sizes) const { 140 auto packOp = cast<PackOp>(op); 141 Location loc = packOp.getLoc(); 142 143 // The tiling is applied on interchanged dimensions. We have to undo the 144 // interchange to map sizes and offsets to the original input. 145 int64_t inputRank = packOp.getSourceRank(); 146 SmallVector<OpFoldResult> origOffsets(offsets); 147 SmallVector<OpFoldResult> origSizes(sizes); 148 applyPermToRange(origOffsets, origSizes, 149 invertPermutationVector(packOp.getOuterDimsPerm())); 150 151 DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 152 packOp.getDimAndTileMapping(); 153 SmallVector<OpFoldResult> srcDimValues = 154 tensor::getMixedSizes(b, loc, packOp.getSource()); 155 SmallVector<OpFoldResult> inputIndices, inputSizes; 156 for (auto dim : llvm::seq<int64_t>(0, inputRank)) { 157 using AV = affine::AffineValueExpr; 158 affine::AffineBuilder ab(b, loc); 159 AffineExpr dim0, dim1, sym; 160 bindDims(b.getContext(), dim0, dim1); 161 bindSymbols(b.getContext(), sym); 162 if (dimAndTileMapping.count(dim)) { 163 // If the data dimension is tiled, the i-th index is the product of 164 // offset_i and tile_i, and the i-th size is the product of sizes_i and 165 // tile_i. 166 auto avOffset = AV(dim0).bind(origOffsets[dim]); 167 auto avSize = AV(dim0).bind(origSizes[dim]); 168 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 169 inputIndices.push_back(ab.mul(avOffset, avTileSize)); 170 inputSizes.push_back(ab.mul(avSize, avTileSize)); 171 } else { 172 inputIndices.push_back(origOffsets[dim]); 173 inputSizes.push_back(origSizes[dim]); 174 } 175 176 // Limit the size of the input operand for incomplete tiles. 177 if (packOp.getPaddingValue()) { 178 OpFoldResult dimSize = srcDimValues[dim]; 179 auto avDimSize = AV(dim0).bind(dimSize); 180 auto avInputIdx = AV(dim1).bind(inputIndices.back()); 181 inputSizes.back() = 182 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)}); 183 } 184 } 185 186 auto oneAttr = b.getI64IntegerAttr(1); 187 SmallVector<OpFoldResult> strides(inputRank, oneAttr); 188 189 SmallVector<Value> tiledOperands; 190 tiledOperands.push_back(b.create<ExtractSliceOp>( 191 loc, packOp.getSource(), inputIndices, inputSizes, strides)); 192 193 SmallVector<OpFoldResult> outputOffsets, outputSizes; 194 if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets, 195 outputSizes))) 196 return {}; 197 198 strides.append(packOp.getDestRank() - inputRank, oneAttr); 199 auto extractSlice = b.create<ExtractSliceOp>( 200 loc, packOp.getDest(), outputOffsets, outputSizes, strides); 201 tiledOperands.push_back(extractSlice); 202 203 if (auto val = packOp.getPaddingValue()) 204 tiledOperands.push_back(val); 205 for (auto tile : packOp.getInnerTiles()) 206 tiledOperands.push_back(tile); 207 208 Operation *tiledPackOp = b.create<PackOp>( 209 loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); 210 211 return TilingResult{{tiledPackOp}, 212 SmallVector<Value>(tiledPackOp->getResults())}; 213 } 214 215 LogicalResult 216 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 217 ArrayRef<OpFoldResult> offsets, 218 ArrayRef<OpFoldResult> sizes, 219 SmallVector<OpFoldResult> &resultOffsets, 220 SmallVector<OpFoldResult> &resultSizes) const { 221 // The iteration domain is over outer dimensions of packed layout. In this 222 // context, the outer dimensions of `resultOffsets` are `offsets`. The 223 // inner dimensions of `resultOffsets` are zeros because tiling is not 224 // applied to them. 225 auto packOp = cast<PackOp>(op); 226 int64_t inputRank = packOp.getSourceRank(); 227 int64_t outputRank = packOp.getDestRank(); 228 auto zeroAttr = b.getI64IntegerAttr(0); 229 resultOffsets.assign(offsets.begin(), offsets.end()); 230 resultOffsets.append(outputRank - inputRank, zeroAttr); 231 232 ReifiedRankedShapedTypeDims outputShape; 233 (void)reifyResultShapes(b, packOp, outputShape); 234 resultSizes.assign(sizes.begin(), sizes.end()); 235 for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank)) 236 resultSizes.push_back(outputShape[0][dataTileDim]); 237 238 return success(); 239 } 240 241 FailureOr<TilingResult> 242 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 243 ArrayRef<OpFoldResult> offsets, 244 ArrayRef<OpFoldResult> sizes) const { 245 auto packOp = cast<PackOp>(op); 246 int64_t numTiles = packOp.getInnerDimsPos().size(); 247 248 // tensor.pack op is fusible (as a producer) only if full inner tiles are 249 // iterated or inner dims are not tiled. Otherwise, it will generate a 250 // sequence of non-trivial ops (for partial tiles). 251 for (auto offset : offsets.take_back(numTiles)) 252 if (!isConstantIntValue(offset, 0)) 253 return failure(); 254 255 for (auto iter : 256 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles))) 257 if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) 258 return failure(); 259 260 FailureOr<TilingResult> tilingResult = getTiledImplementation( 261 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles)); 262 if (failed(tilingResult)) 263 return failure(); 264 return tilingResult.value(); 265 } 266 267 /// Method to return the position of iteration domain tile computed by the 268 /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and 269 /// `resultSizes` only cover outer dimensions. 270 LogicalResult getIterationDomainTileFromOperandTile( 271 Operation *op, OpBuilder &b, unsigned operandNumber, 272 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 273 SmallVectorImpl<OpFoldResult> &resultOffsets, 274 SmallVectorImpl<OpFoldResult> &resultSizes) const { 275 if (operandNumber != 0) 276 return failure(); 277 278 auto packOp = cast<PackOp>(op); 279 // It is not trivial to infer dest tile from source tile if `packOp` has 280 // padding semantic. 281 if (packOp.getPaddingValue()) 282 return failure(); 283 284 Location loc = packOp.getLoc(); 285 286 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; 287 DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 288 packOp.getDimAndTileMapping(); 289 for (auto dim : packOp.getOuterDimsPerm()) { 290 if (dimAndTileMapping.count(dim)) { 291 FailureOr<int64_t> cstSize = 292 ValueBoundsConstraintSet::computeConstantBound( 293 presburger::BoundType::UB, sizes[dim], 294 /*stopCondition=*/nullptr, /*closedUB=*/true); 295 std::optional<int64_t> cstInnerSize = 296 getConstantIntValue(dimAndTileMapping[dim]); 297 // Currently fusing `packOp` as consumer only expects perfect tiling 298 // scenario because even if without padding semantic, the `packOp` may 299 // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, 300 // where the `tileSize` from operand of `packOp` is 5, which is not 301 // exactly divided by `innerTile`(=6) of `packOp`. As the result: 302 // 1. the first slice is extracted from (0) to (4) and inserted into 303 // (0,0)~(0,4) at first row. 304 // 2. the second slice is extracted from (5) to (9) and SHOULD BE 305 // respectively inserted into two rows with different length, including 306 // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate 307 // them, thus adding below constraint to bypass them temporarily. In 308 // another word, we can only support tiling with consumer if the tile 309 // size for the producer is a multiple of the inner tile size for the 310 // packed dimensions at this moment. 311 if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { 312 return failure(); 313 } 314 315 using AV = affine::AffineValueExpr; 316 affine::AffineBuilder ab(b, loc); 317 AffineExpr dim0, sym; 318 bindDims(b.getContext(), dim0); 319 bindSymbols(b.getContext(), sym); 320 auto avOffset = AV(dim0).bind(offsets[dim]); 321 auto avSize = AV(dim0).bind(sizes[dim]); 322 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 323 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize)); 324 outerDimSizes.push_back(ab.ceil(avSize, avTileSize)); 325 } else { 326 outerDimOffsets.push_back(offsets[dim]); 327 outerDimSizes.push_back(sizes[dim]); 328 } 329 } 330 331 resultOffsets = outerDimOffsets; 332 resultSizes = outerDimSizes; 333 return success(); 334 } 335 336 /// Method to return the tiled implementation of tensor.pack as a consumer. 337 FailureOr<TilingResult> getTiledImplementationFromOperandTile( 338 Operation *op, OpBuilder &b, unsigned operandNumber, 339 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 340 if (operandNumber != 0) 341 return failure(); 342 343 auto packOp = cast<PackOp>(op); 344 Location loc = packOp.getLoc(); 345 346 int64_t inputRank = packOp.getSourceRank(); 347 auto oneAttr = b.getI64IntegerAttr(1); 348 SmallVector<OpFoldResult> strides(inputRank, oneAttr); 349 350 SmallVector<Value> tiledOperands; 351 tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(), 352 offsets, sizes, strides)); 353 354 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; 355 if (failed(getIterationDomainTileFromOperandTile( 356 op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets, 357 outerDimSizes))) 358 return failure(); 359 360 SmallVector<OpFoldResult> outputOffsets, outputSizes; 361 if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes, 362 outputOffsets, outputSizes))) 363 return failure(); 364 365 strides.append(packOp.getDestRank() - inputRank, oneAttr); 366 auto extractSlice = b.create<ExtractSliceOp>( 367 loc, packOp.getDest(), outputOffsets, outputSizes, strides); 368 tiledOperands.push_back(extractSlice); 369 370 assert(!packOp.getPaddingValue() && "Expect no padding semantic"); 371 for (auto tile : packOp.getInnerTiles()) 372 tiledOperands.push_back(tile); 373 374 Operation *tiledPackOp = b.create<PackOp>( 375 loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); 376 377 return TilingResult{{tiledPackOp}, 378 SmallVector<Value>(tiledPackOp->getResults())}; 379 } 380 }; 381 382 struct UnpackTileDimInfo { 383 bool isAlignedToInnerTileSize; 384 OpFoldResult sourceOffset; 385 OpFoldResult sourceSize; 386 OpFoldResult resultOffset; 387 OpFoldResult destExpandedSize; 388 }; 389 390 /// Returns the needed information for tiling unpack op on `tileDim` with given 391 /// `tileOffset` and `tileSize`. For more details, see the comment of the 392 /// `getTiledImplementation`. 393 static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, 394 int64_t tileDim, 395 OpFoldResult tileOffset, 396 OpFoldResult tileSize) { 397 UnpackTileDimInfo info; 398 Attribute zeroAttr = b.getIndexAttr(0); 399 Attribute oneAttr = b.getIndexAttr(1); 400 DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 401 unpackOp.getDimAndTileMapping(); 402 // The dimension is not one of packed data dimension. 403 if (!dimAndTileMapping.count(tileDim)) { 404 info.isAlignedToInnerTileSize = true; 405 info.sourceOffset = tileOffset; 406 info.sourceSize = tileSize; 407 info.resultOffset = zeroAttr; 408 info.destExpandedSize = tileSize; 409 return info; 410 } 411 412 Location loc = unpackOp.getLoc(); 413 using AV = affine::AffineValueExpr; 414 affine::AffineBuilder ab(b, loc); 415 AffineExpr dim0, dim1, sym0; 416 bindDims(b.getContext(), dim0, dim1); 417 bindSymbols(b.getContext(), sym0); 418 419 OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; 420 421 info.isAlignedToInnerTileSize = false; 422 FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( 423 presburger::BoundType::UB, tileSize, 424 /*stopCondition=*/nullptr, /*closedUB=*/true); 425 std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize); 426 if (!failed(cstSize) && cstInnerSize) { 427 if (*cstSize % *cstInnerSize == 0) 428 info.isAlignedToInnerTileSize = true; 429 430 // If the tiling size equals to the inner tiling size, the outer dims are 431 // always 1. 432 if (*cstInnerSize == *cstSize) { 433 auto lhs = AV(dim0).bind(tileOffset); 434 auto rhs = AV(dim1).bind(innerTileSize); 435 info.sourceOffset = ab.floor(lhs, rhs); 436 info.sourceSize = oneAttr; 437 info.resultOffset = zeroAttr; 438 info.destExpandedSize = tileSize; 439 return info; 440 } 441 } 442 443 if (info.isAlignedToInnerTileSize) { 444 info.sourceOffset = 445 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize)); 446 info.resultOffset = zeroAttr; 447 info.destExpandedSize = tileSize; 448 449 // The ceilDiv is needed here because there could be incomplete tile even 450 // it is perfect tiling cases. E.g., 451 // %0 = unpack tensor<33x2xf32> into tensor<64xf32> 452 // If the tiling size is 32, there will be 3 tiles. Two of them have 453 // size=32; one of them have size=2. The size is represented using 454 // affine_min op; we need ceilDiv. 455 info.sourceSize = 456 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize)); 457 return info; 458 } 459 460 affine::DivModValue firstCoord = affine::getDivMod( 461 b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset), 462 getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 463 OpFoldResult tileExclusiveBound = 464 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize)); 465 affine::DivModValue lastCoord = affine::getDivMod( 466 b, loc, 467 getValueOrCreateConstantIndexOp( 468 b, loc, 469 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))), 470 getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 471 472 OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient), 473 AV(dim1).bind(firstCoord.quotient)); 474 info.sourceSize = 475 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr)); 476 info.sourceOffset = firstCoord.quotient; 477 info.resultOffset = firstCoord.remainder; 478 // Do not create an Affine ops for expanded size because the affine op is too 479 // complicated which would trigger an issue in affine ops simplification. 480 info.destExpandedSize = b.createOrFold<arith::MulIOp>( 481 loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize), 482 getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 483 return info; 484 } 485 486 struct UnPackOpTiling 487 : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> { 488 489 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 490 auto unpackOp = cast<UnPackOp>(op); 491 SmallVector<utils::IteratorType> iteratorTypes( 492 unpackOp.getDestRank(), utils::IteratorType::parallel); 493 return iteratorTypes; 494 } 495 496 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 497 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b); 498 } 499 500 /// There are two cases in tiling unpack ops. If the tiling size is aligned to 501 /// the inner tile size, the corresponding tiles of source are all complete. 502 /// Otherwise, there are in-complete tiles. We will need to expand the slice 503 /// of source for getting complete tiles. The tiled unpack op unpacks more 504 /// data from source, so We'll need an extract_slice op to shift and truncate 505 /// the output. 506 /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The 507 /// coordinates of second tile (i.e., result[15..31]) are 508 /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last 509 /// row are incomplete tiles. To represent the unpack op, we have to complete 510 /// the rows. I.e., the input coordinates would start with (1, 0); end with 511 /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements 512 /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we 513 /// can get the actual result. 514 FailureOr<TilingResult> 515 getTiledImplementation(Operation *op, OpBuilder &b, 516 ArrayRef<OpFoldResult> offsets, 517 ArrayRef<OpFoldResult> sizes) const { 518 auto unpackOp = cast<UnPackOp>(op); 519 int64_t srcRank = unpackOp.getSourceRank(); 520 int64_t destRank = unpackOp.getDestRank(); 521 int64_t numInnerTiles = srcRank - destRank; 522 Location loc = unpackOp.getLoc(); 523 524 // The perfect tiling case indicates that the tiling sizes are multiple of 525 // inner_tile_size. In this context, no extra data is needed when 526 // representing the tiled unpack op. 527 bool isPerfectTilingCase = true; 528 Attribute oneAttr = b.getIndexAttr(1); 529 SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr); 530 SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes; 531 SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest; 532 for (auto dim : llvm::seq<int64_t>(0, destRank)) { 533 UnpackTileDimInfo info = 534 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]); 535 if (!info.isAlignedToInnerTileSize) 536 isPerfectTilingCase = false; 537 sliceSrcIndices.push_back(info.sourceOffset); 538 sliceSrcSizes.push_back(info.sourceSize); 539 destExpandedSizes.push_back(info.destExpandedSize); 540 resultOffsetsFromDest.push_back(info.resultOffset); 541 } 542 543 // The tiling is applied on destination dimensions. We have to apply the 544 // interchange on source dimensions if outer_dims_perm is set. 545 applyPermToRange(sliceSrcIndices, sliceSrcSizes, 546 unpackOp.getOuterDimsPerm()); 547 Attribute zeroAttr = b.getIndexAttr(0); 548 sliceSrcIndices.append(numInnerTiles, zeroAttr); 549 sliceSrcSizes.append(unpackOp.getMixedTiles()); 550 sliceSrcStrides.append(numInnerTiles, oneAttr); 551 Value sliceSource = 552 b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices, 553 sliceSrcSizes, sliceSrcStrides); 554 555 SmallVector<OpFoldResult> destStrides(destRank, oneAttr); 556 Value sliceDest; 557 if (isPerfectTilingCase) { 558 sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets, 559 sizes, destStrides); 560 } else { 561 sliceDest = b.create<EmptyOp>(loc, destExpandedSizes, 562 unpackOp.getDestType().getElementType()); 563 } 564 565 SmallVector<Value> tiledOperands = {sliceSource, sliceDest}; 566 for (auto tile : unpackOp.getInnerTiles()) 567 tiledOperands.push_back(tile); 568 569 Operation *tiledUnpackOp = b.create<UnPackOp>( 570 loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); 571 572 if (isPerfectTilingCase) 573 return TilingResult{{tiledUnpackOp}, 574 SmallVector<Value>(tiledUnpackOp->getResults())}; 575 576 auto extractSlice = 577 b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0), 578 resultOffsetsFromDest, sizes, destStrides); 579 return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; 580 } 581 582 LogicalResult 583 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 584 ArrayRef<OpFoldResult> offsets, 585 ArrayRef<OpFoldResult> sizes, 586 SmallVector<OpFoldResult> &resultOffsets, 587 SmallVector<OpFoldResult> &resultSizes) const { 588 resultOffsets = llvm::to_vector(offsets); 589 resultSizes = llvm::to_vector(sizes); 590 return success(); 591 } 592 593 FailureOr<TilingResult> 594 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 595 ArrayRef<OpFoldResult> offsets, 596 ArrayRef<OpFoldResult> sizes) const { 597 FailureOr<TilingResult> tilingResult = 598 getTiledImplementation(op, b, offsets, sizes); 599 if (failed(tilingResult)) 600 return failure(); 601 return tilingResult.value(); 602 } 603 604 /// Method to return the position of iteration domain tile computed by the 605 /// tiled operation. 606 LogicalResult getIterationDomainTileFromOperandTile( 607 Operation *op, OpBuilder &b, unsigned operandNumber, 608 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 609 SmallVectorImpl<OpFoldResult> &resultOffsets, 610 SmallVectorImpl<OpFoldResult> &resultSizes) const { 611 auto unPackOp = cast<UnPackOp>(op); 612 Location loc = unPackOp.getLoc(); 613 614 int64_t numTiles = unPackOp.getInnerDimsPos().size(); 615 auto destOffsets = offsets.drop_back(numTiles); 616 auto destSizes = sizes.drop_back(numTiles); 617 // The tiling is applied on interchanged dimensions. We have to undo the 618 // interchange to map sizes and offsets to the original input. 619 int64_t outputRank = unPackOp.getDestRank(); 620 SmallVector<OpFoldResult> origOffsets(destOffsets); 621 SmallVector<OpFoldResult> origSizes(destSizes); 622 applyPermToRange(origOffsets, origSizes, 623 invertPermutationVector(unPackOp.getOuterDimsPerm())); 624 625 DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 626 unPackOp.getDimAndTileMapping(); 627 628 for (auto dim : llvm::seq<int64_t>(0, outputRank)) { 629 using AV = affine::AffineValueExpr; 630 affine::AffineBuilder ab(b, loc); 631 AffineExpr dim0, dim1, sym; 632 bindDims(b.getContext(), dim0, dim1); 633 bindSymbols(b.getContext(), sym); 634 if (dimAndTileMapping.count(dim)) { 635 // If the data dimension is tiled, the i-th index is the product of 636 // offset_i and tile_i, and the i-th size is the product of sizes_i and 637 // tile_i. 638 auto avOffset = AV(dim0).bind(origOffsets[dim]); 639 auto avSize = AV(dim0).bind(origSizes[dim]); 640 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 641 resultOffsets.push_back(ab.mul(avOffset, avTileSize)); 642 resultSizes.push_back(ab.mul(avSize, avTileSize)); 643 } else { 644 resultOffsets.push_back(origOffsets[dim]); 645 resultSizes.push_back(origSizes[dim]); 646 } 647 } 648 return success(); 649 } 650 651 /// Method to return the tiled implementation of tensor.unpack as a consumer. 652 FailureOr<TilingResult> getTiledImplementationFromOperandTile( 653 Operation *op, OpBuilder &b, unsigned operandNumber, 654 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 655 auto unPackOp = cast<UnPackOp>(op); 656 // tensor.unpack op is fusible (as a consumer) only if inner dims are not 657 // tiled. 658 int64_t numTiles = unPackOp.getInnerDimsPos().size(); 659 for (auto iter : 660 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) { 661 if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) 662 return failure(); 663 } 664 665 Location loc = unPackOp.getLoc(); 666 667 // Fetch offset/size for creating the slice of the dest operand of 668 // unpack op. 669 SmallVector<OpFoldResult> outputOffsets, outputSizes; 670 if (failed(getIterationDomainTileFromOperandTile( 671 op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets, 672 outputSizes))) 673 return failure(); 674 675 auto oneAttr = b.getI64IntegerAttr(1); 676 int64_t outputRank = unPackOp.getDestRank(); 677 SmallVector<OpFoldResult> strides(outputRank, oneAttr); 678 679 SmallVector<Value> tiledOperands; 680 // Create slice of the dest operand. 681 auto extractDestSlice = b.create<ExtractSliceOp>( 682 loc, unPackOp.getDest(), outputOffsets, outputSizes, strides); 683 tiledOperands.push_back(extractDestSlice); 684 685 SmallVector<OpFoldResult> inputOffsets, inputSizes; 686 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr); 687 // Create slice of the source operand. 688 auto extractSourceSlice = b.create<ExtractSliceOp>( 689 loc, unPackOp.getSource(), offsets, sizes, strides); 690 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice); 691 for (auto tile : unPackOp.getInnerTiles()) 692 tiledOperands.push_back(tile); 693 694 // Create tiled unpack op. 695 Operation *tiledUnPackOp = 696 b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()}, 697 tiledOperands, op->getAttrs()); 698 699 return TilingResult{{tiledUnPackOp}, 700 SmallVector<Value>(tiledUnPackOp->getResults())}; 701 } 702 }; 703 704 } // namespace 705 706 FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, 707 tensor::PadOp padOp, 708 ArrayRef<OpFoldResult> offsets, 709 ArrayRef<OpFoldResult> sizes, 710 bool generateZeroSliceGuard) { 711 // Only constant padding value supported. 712 Value padValue = padOp.getConstantPaddingValue(); 713 if (!padValue) 714 return failure(); 715 716 // Helper variables and functions for various arithmetic operations. These 717 // are used extensively for computing new offset/length and padding values. 718 Location loc = padOp->getLoc(); 719 AffineExpr dim0, dim1; 720 bindDims(b.getContext(), dim0, dim1); 721 // Add two integers. 722 auto addMap = AffineMap::get(2, 0, {dim0 + dim1}); 723 auto add = [&](OpFoldResult v1, OpFoldResult v2) { 724 return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2}); 725 }; 726 // Subtract two integers. 727 auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); 728 auto sub = [&](OpFoldResult v1, OpFoldResult v2) { 729 return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2}); 730 }; 731 // Take the minimum of two integers. 732 auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext()); 733 auto min = [&](OpFoldResult v1, OpFoldResult v2) { 734 return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2}); 735 }; 736 // Take the maximum of two integers. 737 auto max = [&](OpFoldResult v1, OpFoldResult v2) { 738 return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2}); 739 }; 740 // Zero index-typed integer. 741 OpFoldResult zero = b.getIndexAttr(0); 742 743 // Compute new offsets, lengths, low padding, high padding. 744 SmallVector<OpFoldResult> newOffsets, newLengths, newStrides; 745 SmallVector<OpFoldResult> newLows, newHighs; 746 // Set to true if the original data source is not read at all. 747 bool hasZeroLen = false; 748 // Same as hasZeroLen, but for dynamic dimension sizes. This condition 749 // is true if the original data source turns out to be unused at runtime. 750 Value dynHasZeroLenCond; 751 752 int64_t rank = padOp.getSourceType().getRank(); 753 for (unsigned dim = 0; dim < rank; ++dim) { 754 auto low = padOp.getMixedLowPad()[dim]; 755 bool hasLowPad = !isConstantIntValue(low, 0); 756 auto high = padOp.getMixedHighPad()[dim]; 757 bool hasHighPad = !isConstantIntValue(high, 0); 758 auto offset = offsets[dim]; 759 auto length = sizes[dim]; 760 auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim); 761 762 // The new amount of low padding is `low - offset`. Except for the case 763 // where none of the low padding is read. In that case, the new amount of 764 // low padding is zero. 765 // 766 // Optimization: If low = 0, then newLow = 0. 767 OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; 768 newLows.push_back(newLow); 769 770 // Start reading the data from position `offset - low`. Since the original 771 // read may have started in the low padding zone, this value could be 772 // negative. Therefore, start reading from: 773 // 774 // max(offset - low, 0) 775 // 776 // The original read could also have started in the high padding zone. 777 // In that case, set the offset to the end of source tensor. The new 778 // ExtractSliceOp length will be zero in that case. (Effectively reading 779 // no data from the source.) 780 // 781 // Optimization: If low = 0, then the formula can be simplified. 782 OpFoldResult newOffset = hasLowPad 783 ? min(max(sub(offset, low), zero), srcSize) 784 : min(offset, srcSize); 785 newOffsets.push_back(newOffset); 786 787 // The original ExtractSliceOp was reading until position `offset + 788 // length`. Therefore, the corresponding position within the source tensor 789 // is: 790 // 791 // offset + length - low 792 // 793 // In case the original ExtractSliceOp stopped reading within the low 794 // padding zone, this value can be negative. In that case, the end 795 // position of the read should be zero. (Similar to newOffset.) 796 // 797 // The original read could also have stopped in the high padding zone. 798 // In that case, set the end positition of the read should be the end of 799 // the source tensor. (Similar to newOffset.) 800 // 801 // endLoc = min(max(offset - low + length, 0), srcSize) 802 // 803 // The new ExtractSliceOp length is `endLoc - newOffset`. 804 // 805 // Optimization: If low = 0, then the formula can be simplified. 806 OpFoldResult endLoc = 807 hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize) 808 : min(add(offset, length), srcSize); 809 OpFoldResult newLength = sub(endLoc, newOffset); 810 newLengths.push_back(newLength); 811 812 // Check if newLength is zero. In that case, no SubTensorOp should be 813 // executed. 814 if (isConstantIntValue(newLength, 0)) { 815 hasZeroLen = true; 816 } else if (!hasZeroLen) { 817 Value check = b.create<arith::CmpIOp>( 818 loc, arith::CmpIPredicate::eq, 819 getValueOrCreateConstantIndexOp(b, loc, newLength), 820 getValueOrCreateConstantIndexOp(b, loc, zero)); 821 dynHasZeroLenCond = 822 dynHasZeroLenCond 823 ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond) 824 : check; 825 } 826 827 // The amount of high padding is simply the number of elements remaining, 828 // so that the result has the same length as the original ExtractSliceOp. 829 // As an optimization, if the original high padding is zero, then the new 830 // high padding must also be zero. 831 OpFoldResult newHigh = 832 hasHighPad ? sub(sub(length, newLength), newLow) : zero; 833 newHighs.push_back(newHigh); 834 835 // Only unit stride supported. 836 newStrides.push_back(b.getIndexAttr(1)); 837 } 838 839 // The shape of the result can be obtained from the sizes passed in. 840 SmallVector<Value> dynDims; 841 SmallVector<int64_t> shape; 842 dispatchIndexOpFoldResults(sizes, dynDims, shape); 843 RankedTensorType resultType = 844 RankedTensorType::get(shape, padOp.getResultType().getElementType()); 845 846 // Insert cast to ensure that types match. (May be folded away.) 847 auto castResult = [&](Value val) -> Value { 848 if (resultType == val.getType()) 849 return val; 850 return b.create<tensor::CastOp>(loc, resultType, val); 851 }; 852 853 // In cases where the original data source is unused: Emit a GenerateOp and 854 // do not generate a SliceOp. (The result shape of the SliceOp would 855 // have a dimension of size 0, the semantics of which is unclear.) 856 auto createGenerateOp = [&]() { 857 // Create GenerateOp. 858 auto generateOp = b.create<tensor::GenerateOp>( 859 loc, resultType, dynDims, 860 [&](OpBuilder &builder, Location gLoc, ValueRange indices) { 861 builder.create<tensor::YieldOp>(gLoc, padValue); 862 }); 863 return generateOp; 864 }; 865 866 // Emit a SliceOp and a PadOp. Should not be used in cases where 867 // the result shape of the new SliceOp has a zero dimension. 868 auto createPadOfExtractSlice = [&]() { 869 // Create pad(extract_slice(x)). 870 Value newSliceOp = b.create<tensor::ExtractSliceOp>( 871 loc, padOp.getSource(), newOffsets, newLengths, newStrides); 872 auto newPadOp = b.create<PadOp>( 873 loc, Type(), newSliceOp, newLows, newHighs, 874 /*nofold=*/padOp.getNofold(), 875 getPrunedAttributeList(padOp, PadOp::getAttributeNames())); 876 877 // Copy region to new PadOp. 878 IRMapping bvm; 879 padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); 880 881 // Cast result and return. 882 return newPadOp; 883 }; 884 885 // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that 886 // the original data source x is not used. 887 if (hasZeroLen) { 888 Operation *generateOp = createGenerateOp(); 889 return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; 890 } 891 892 // If there are dynamic dimensions: Generate an scf.if check to avoid 893 // creating SliceOps with result dimensions of size 0 at runtime. 894 if (generateZeroSliceGuard && dynHasZeroLenCond) { 895 Operation *thenOp; 896 Operation *elseOp; 897 auto result = b.create<scf::IfOp>( 898 loc, dynHasZeroLenCond, 899 /*thenBuilder=*/ 900 [&](OpBuilder &b, Location loc) { 901 thenOp = createGenerateOp(); 902 b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0))); 903 }, 904 /*elseBuilder=*/ 905 [&](OpBuilder &b, Location loc) { 906 elseOp = createPadOfExtractSlice(); 907 b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0))); 908 }); 909 return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())}; 910 } 911 912 Operation *newPadOp = createPadOfExtractSlice(); 913 return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; 914 } 915 916 void mlir::tensor::registerTilingInterfaceExternalModels( 917 DialectRegistry ®istry) { 918 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 919 tensor::PadOp::attachInterface<PadOpTiling>(*ctx); 920 tensor::PackOp::attachInterface<PackOpTiling>(*ctx); 921 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); 922 }); 923 } 924 925 void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps( 926 DialectRegistry ®istry) { 927 registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 928 tensor::PackOp::attachInterface<PackOpTiling>(*ctx); 929 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); 930 }); 931 } 932