//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the sizes are then
    // overwritten with the reified result shape.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies only to pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

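// A small sketch of the convention assumed below (illustrative, not taken
// from this file): applyPermutationToVector is assumed to produce
// result[i] = input[permutation[i]], so with permutation = [2, 0, 1],
// offsets [%a, %b, %c] become [%c, %a, %b].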
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

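  // To make the offset/size mapping below concrete, here is a hand-written
  // sketch (shapes and tile sizes are illustrative, not taken from this
  // file). Given
  //
  //   %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
  //       into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
  //
  // the iteration domain is the outer 16x8 grid. Tiling it with offsets
  // [2, 0] and sizes [4, 8] extracts a tensor<32x256xf32> slice of %src
  // (indices scaled by the inner tiles: 2*8 = 16, sizes 4*8 = 32 and
  // 8*32 = 256) and a tensor<4x8x8x32xf32> slice of %dest, and rebuilds a
  // smaller tensor.pack between them.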
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
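    // Continuing the sketch above the tiled implementation (illustrative
    // values only): offsets [2, 0] and sizes [4, 8] become resultOffsets
    // [2, 0, 0, 0] and resultSizes [4, 8, 8, 32], where the trailing 8 and 32
    // are the full inner tile sizes read from the reified result shape.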
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context,
  /// `resultOffsets` and `resultSizes` only cover the outer dimensions.
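  /// As a hand-written illustration (values not taken from this file): for a
  /// source dimension packed with inner tile size 8, an operand tile with
  /// offset 16 and size 32 maps to the outer offset floor(16/8) = 2 and the
  /// outer size ceil(32/8) = 4, matching the floor/ceil computation below.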
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer only supports the perfect
        // tiling scenario, because even without padding semantics the `packOp`
        // may yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is not
        // a multiple of the `innerTile` (=6) of `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        // (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        // inserted into two rows of different lengths: first row (0,5) and
        // second row (1,0)~(1,3). It is hard to coordinate them, so the
        // constraint below bypasses such cases for now. In other words, tiling
        // with a consumer is only supported if the tile size of the producer
        // is a multiple of the inner tile size for the packed dimensions.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantics");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there can be an incomplete tile even
    // in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

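  // A worked example of the unaligned case below (numbers chosen to match the
  // Nn_to_N illustration on `getTiledImplementation`; they are not computed
  // by this file): with tileOffset=15, tileSize=15, and innerTileSize=8,
  // firstCoord is (15 floordiv 8, 15 mod 8) = (1, 7), and lastCoord, computed
  // on the last element 15 + 15 - 1 = 29, is (3, 5). Hence sourceOffset=1,
  // sourceSize=3, resultOffset=7, and destExpandedSize = 3 * 8 = 24.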
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size, because the resulting
  // affine op would be too complicated and would trigger an issue in affine
  // op simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases when tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. In that case we need to
  /// expand the slice of the source to get complete tiles. The tiled unpack op
  /// then unpacks more data from the source, so we need a tensor.extract_slice
  /// op to shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows, i.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
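  /// As a hand-written sketch of the perfect-tiling variant of the same
  /// Nn_to_N example (tiling_size=16 instead of 15; the IR is illustrative,
  /// not produced verbatim by this file), the second tile becomes roughly:
  ///
  ///   %s = tensor.extract_slice %src[2, 0] [2, 8] [1, 1]
  ///       : tensor<4x8xf32> to tensor<2x8xf32>
  ///   %d = tensor.extract_slice %dest[16] [16] [1]
  ///       : tensor<32xf32> to tensor<16xf32>
  ///   %u = tensor.unpack %s inner_dims_pos = [0] inner_tiles = [8]
  ///       into %d : tensor<2x8xf32> -> tensor<16xf32>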
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
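  /// As a hand-written illustration (values not taken from this file): for a
  /// source dimension with inner tile size 8, a source tile with outer offset
  /// 1 and outer size 3 maps to the destination offset 1 * 8 = 8 and size
  /// 3 * 8 = 24, matching the multiplications below.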
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        resultSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

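  // A worked example of the per-dimension computation in the loop below
  // (numbers chosen for illustration; they are not computed by this file):
  // say the source dim is 10 with low = 2 and high = 3 (padded size 15), and
  // the requested tile is offset = 8, length = 7. Then newLow = max(0, 2 - 8)
  // = 0, newOffset = min(max(8 - 2, 0), 10) = 6, endLoc = min(max(8 - 2 + 7,
  // 0), 10) = 10, newLength = 10 - 6 = 4, and newHigh = (7 - 4) - 0 = 3: the
  // tile reads src[6..10) and pads 3 elements at the end.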
  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except when none of the
    // low padding is read; in that case, the new amount of low padding is
    // zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // created.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
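  // For illustration, the rewrite built here turns a tile of a padded tensor
  // into pad(extract_slice(x)). A hand-written sketch (shapes and values are
  // illustrative, not produced verbatim by this function), reusing the worked
  // example above (src dim 10, low 2, high 3, tile offset 8, length 7):
  //
  //   %s = tensor.extract_slice %x[6] [4] [1]
  //       : tensor<10xf32> to tensor<4xf32>
  //   %r = tensor.pad %s low[0] high[3] {
  //     ^bb0(%i: index):
  //       tensor.yield %cst : f32
  //   } : tensor<4xf32> to tensor<7xf32>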
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

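// A typical usage sketch (not part of this file): clients populate a
// DialectRegistry with these external models before creating operations, e.g.
//
//   mlir::DialectRegistry registry;
//   mlir::tensor::registerTilingInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);
//
// after which tensor.pad, tensor.pack, and tensor.unpack implement
// TilingInterface.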
void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}