//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies only to pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
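        // As a worked illustration (hypothetical values, not taken from any
        // particular test): with an inner tile size of 8, an outer offset of 2
        // and an outer size of 3 map to input index 2 * 8 = 16 and input size
        // 3 * 8 = 24.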
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
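      // For instance (hypothetical values): if the source dim is 30, the
      // computed input index is 24, and the unclamped input size is 8, the
      // min below clamps the slice size to 30 - 24 = 6 for the last,
      // incomplete tile.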
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
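    // E.g. (hypothetical shapes): packing a 2-D source into
    // tensor<?x?x8x16xf32> yields `resultOffsets` [off0, off1, 0, 0] and
    // `resultSizes` [size0, size1, 8, 16].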
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context,
  /// `resultOffsets` and `resultSizes` only cover the outer dimensions.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    // It is not trivial to infer the dest tile from the source tile if
    // `packOp` has padding semantics.
    if (packOp.getPaddingValue())
      return failure();

    Location loc = packOp.getLoc();

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // Currently, fusing `packOp` as a consumer expects a perfect tiling
        // scenario, because even without padding semantics the `packOp` may
        // yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
        // where the `tileSize` from the operand of `packOp` is 5, which is
        // not evenly divisible by the `innerTile` (= 6) of `packOp`. As a
        // result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        //    (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and would have to
        //    be inserted into two rows of different lengths: (0,5) in the
        //    first row and (1,0)~(1,3) in the second row. It is hard to
        //    coordinate these, so the constraint below bypasses such cases
        //    for now. In other words, we can only support tiling with a
        //    consumer if the tile size for the producer is a multiple of the
        //    inner tile size for the packed dimensions.
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }

        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
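        // E.g. (hypothetical values satisfying the divisibility check above):
        // offset = 6, size = 12, inner tile = 6 give outer offset
        // 6 floorDiv 6 = 1 and outer size 12 ceilDiv 6 = 2.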
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }

  /// Method to return the tiled implementation of tensor.pack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    if (operandNumber != 0)
      return failure();

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling the unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
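    // E.g. (hypothetical values): innerTileSize = 8 and tileSize = 8 with
    // tileOffset = 16 yield sourceOffset = 16 floorDiv 8 = 2 and
    // sourceSize = 1.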
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
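  // Worked example (hypothetical values): with tileOffset = 15, tileSize = 7,
  // and innerTileSize = 8, firstCoord = (1, 7) and lastCoord (for index 21)
  // = (2, 5), so sourceOffset = 1, sourceSize = 2 - 1 + 1 = 2, and
  // resultOffset = 7.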
  // Do not create affine ops for the expanded size because the affine op
  // would be too complicated, which would trigger an issue in affine ops
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We need to expand the
  /// slice of the source to get complete tiles. The tiled unpack op then
  /// unpacks more data from the source, so we need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows. I.e., the input coordinates would start with (1, 0)
  /// and end with (3, 7). In this context, the tiled unpack produces (3 * n)
  /// elements because there are 3 rows in total. Followed by a
  /// tensor.extract_slice op, we can get the actual result.
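  /// A rough sketch of the IR for the example above (value names and exact
  /// types are illustrative, not produced verbatim):
  ///
  ///   %src = tensor.extract_slice %in[1, 0] [3, 8] [1, 1]
  ///       : tensor<4x8xf32> to tensor<3x8xf32>
  ///   %empty = tensor.empty() : tensor<24xf32>
  ///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
  ///       into %empty : tensor<3x8xf32> -> tensor<24xf32>
  ///   %res = tensor.extract_slice %unpack[7] [15] [1]
  ///       : tensor<24xf32> to tensor<15xf32>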
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples
    // of inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }

  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation.
  LogicalResult getIterationDomainTileFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // If the operand tile is the dest, then no adjustment is needed.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    Location loc = unPackOp.getLoc();

    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();

    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i. The sizes must be clamped to the sizes of the unpack result.
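        // E.g. (hypothetical values): offset = 1, size = 2, tile = 8, and a
        // result size of 20 give resultOffsets 1 * 8 = 8 and resultSizes
        // min(2 * 8, 20 - 8) = 12.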
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }

  /// Method to return the tiled implementation of tensor.unpack as a consumer.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      Operation *op, OpBuilder &b, unsigned operandNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
    auto unPackOp = cast<UnPackOp>(op);
    // tensor.unpack op is fusible (as a consumer) only if inner dims are not
    // tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offset/size for creating the slice of the dest operand of
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getI64IntegerAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    // Create slice of the dest operand.
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);

    SmallVector<OpFoldResult> inputOffsets, inputSizes;
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // Create slice of the source operand.
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // Create tiled unpack op.
    Operation *tiledUnPackOp =
        b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
                           tiledOperands, op->getAttrs());

    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only a constant padding value is supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

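  // A hedged sketch of the overall rewrite this function performs (value
  // names and types are illustrative, not produced verbatim):
  //
  //   %p = tensor.pad %src low[%l] high[%h] { ... }
  //   %t = tensor.extract_slice %p[%off] [%sz] [1]
  //
  // becomes, roughly,
  //
  //   %s = tensor.extract_slice %src[%newOff] [%newLen] [1]
  //   %t = tensor.pad %s low[%newLow] high[%newHigh] { ... }
  //
  // with the new offsets, lengths, and paddings computed below.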
  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);
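    // E.g. (hypothetical values): low = 3 and offset = 1 give
    // newLow = max(0, 3 - 1) = 2, while reading from offset 5 gives
    // newLow = max(0, 3 - 5) = 0.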

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    // srcSize - newOffset represents how much length we have available
    // and length - newLow represents how much length we want at most.
    // Note that there are many ways to order this indexing math to compute
    // newLength, but we want to make sure that the final affine.min ops in the
    // sequence are bounding the index to as small a value as possible. If
    // ValueBoundsOpInterface is used, this calculation will get upper bounds
    // from the affine.min ops, so we want to use the smallest known value to
    // set the bound at the end of the computation sequence. In this case, the
    // index will be upper bounded by length - newLow.
    OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow));
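    // For instance (hypothetical values): srcSize = 10, newOffset = 4,
    // length = 8, and newLow = 0 give newLength = min(10 - 4, 8 - 0) = 6.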
    // Optimization: If low = 0, then newLow = 0, and newLength >= 0 assuming
    // length >= 0.
    if (hasLowPad)
      newLength = max(newLength, zero);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // created.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}