//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the sizes are then
    // overwritten with the reified result shape (the `ub`s).
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}
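
// For example (illustrative shapes): for a pack of tensor<128x256xf32> into
// tensor<4x8x32x32xf32> with inner_tiles = [32, 32], the iteration domain is
// over the two outer dimensions of the result, i.e.
// {offset = 0, size = 4, stride = 1} x {offset = 0, size = 8, stride = 1}.
// For the corresponding unpack, the domain would cover the full destination,
// i.e. 128 x 256.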

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
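    // E.g. (hypothetical case): with outer_dims_perm = [1, 0], a tile with
    // offsets [%o0, %o1] over the interchanged domain addresses the source
    // with offsets [%o1, %o0]; the inverse permutation performs that mapping.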
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
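        // Worked example (hypothetical numbers): with tile_i = 32,
        // offset_i = 2 and size_i = 3, the slice of the source starts at
        // index 2 * 32 = 64 and spans 3 * 32 = 96 elements.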
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }
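
  // Illustrative sketch (hypothetical shapes, no padding value) of the IR
  // produced above for one tile, with inner_tiles = [32, 32] and an
  // iteration-domain tile with offsets [0, 0] and sizes [2, 2]:
  //   %in  = tensor.extract_slice %source[0, 0] [64, 64] [1, 1]
  //          : tensor<128x256xf32> to tensor<64x64xf32>
  //   %out = tensor.extract_slice %dest[0, 0, 0, 0] [2, 2, 32, 32]
  //          [1, 1, 1, 1] : tensor<4x8x32x32xf32> to tensor<2x2x32x32xf32>
  //   %tiled = tensor.pack %in inner_dims_pos = [0, 1] inner_tiles = [32, 32]
  //            into %out : tensor<64x64xf32> -> tensor<2x2x32x32xf32>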

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // The tensor.pack op is fusible (as a producer) only if full inner tiles
    // are iterated or the inner dims are not tiled. Otherwise, it would
    // generate a sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : packOp.getOuterDimsPerm()) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Fusing `packOp` as a consumer currently requires a perfect tiling
        // scenario, because even without padding semantics the `packOp` may
        // yield incomplete tiles. E.g., for tensor<30xf32> -> tensor<5x6xf32>,
        // a `tileSize` of 5 on the operand of `packOp` does not evenly divide
        // the `innerTile` (= 6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted at
        //    (0,0)~(0,4) in the first row;
        // 2. the second slice is extracted from (5) to (9) and would have to
        //    be split across two rows of different lengths: (0,5) in the
        //    first row and (1,0)~(1,3) in the second row.
        // It is hard to coordinate such splits, so the constraint below
        // bypasses them temporarily. In other words, tiling with a consumer
        // is only supported at the moment if the tile size for the producer
        // is a multiple of the inner tile size for the packed dimensions.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }

    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                     offsets, sizes, strides));

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the affine op
  // would be too complicated and would trigger an issue in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}
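
// Worked example (hypothetical numbers, matching the Nn_to_N case documented
// below with N=32, n=8): for tileOffset=15, tileSize=15 and innerTileSize=8,
// firstCoord = divmod(15, 8) = (1, 7) and lastCoord = divmod(30 - 1, 8) =
// (3, 5), so sourceOffset = 1, sourceSize = 3 - 1 + 1 = 3, resultOffset = 7,
// and destExpandedSize = 3 * 8 = 24.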

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases when tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we'll need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }
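
  // Illustrative sketch (hypothetical shapes) of the imperfect tiling case
  // above, for the Nn_to_N example (N=32, n=8) with tile offset 15 and tile
  // size 15; the expanded slice covers 3 full rows, then the result is
  // truncated back to the requested tile:
  //   %src = tensor.extract_slice %source[1, 0] [3, 8] [1, 1]
  //          : tensor<4x8xf32> to tensor<3x8xf32>
  //   %dst = tensor.empty() : tensor<24xf32>
  //   %unpacked = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
  //               into %dst : tensor<3x8xf32> -> tensor<24xf32>
  //   %res = tensor.extract_slice %unpacked[7] [15] [1]
  //          : tensor<24xf32> to tensor<15xf32>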

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        resultSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults())};
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
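  // E.g. (illustrative): `sub(%v, 3)` builds an affine.apply of d0 - d1 with
  // the constant composed into the map, and the helpers fold the result to a
  // plain attribute when both operands are constants.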
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`. Except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }
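
  // Worked example (hypothetical numbers) for one dimension with non-zero
  // low and high padding: with low = 3, srcSize = 10, offset = 1 and
  // length = 5, the loop above computes
  //   newLow    = max(0, 3 - 1)              = 2
  //   newOffset = min(max(1 - 3, 0), 10)     = 0
  //   endLoc    = min(max(1 - 3 + 5, 0), 10) = 3
  //   newLength = 3 - 0                      = 3
  //   newHigh   = (5 - 3) - 2                = 0
  // and indeed newLow + newLength + newHigh = 5 = length.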

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}
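
// Illustrative before/after sketch (hypothetical 1-D shapes, following the
// worked numbers above with low = 3, srcSize = 10, offset = 1, length = 5):
//
//   before: %tile = tensor.extract_slice (tensor.pad %src low[3] high[...])
//                   [1] [5] [1]
//   after:  %s    = tensor.extract_slice %src[0] [3] [1]
//                   : tensor<10xf32> to tensor<3xf32>
//           %tile = tensor.pad %s low[2] high[0] { ... }
//                   : tensor<3xf32> to tensor<5xf32>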

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
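
// Typical usage (illustrative): clients register these external models on a
// registry before constructing the context, e.g.:
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);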

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}