//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `size`s are
    // overwritten below.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
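    // E.g., with a hypothetical outer_dims_perm = [1, 0], `offsets` and
    // `sizes` arrive in the permuted order; applying the inverse permutation
    // below maps them back to the source dimension order.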
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
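        // E.g., with hypothetical values tile_i = 8, offset_i = 2, and
        // sizes_i = 4, the source slice starts at index 2 * 8 = 16 and spans
        // 4 * 8 = 32 elements along dimension i.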
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
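    // E.g., for a hypothetical rank-2 source packed into a rank-4 destination
    // with static 8x16 inner tiles, offsets [%i, %j] map to result offsets
    // [%i, %j, 0, 0] and sizes [%si, %sj] map to [%si, %sj, 8, 16].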
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // The tensor.pack op is fusible (as a producer) only if full inner tiles
    // are iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
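    // E.g., for a hypothetical pack with inner_tiles = [8], an offset of 0
    // and a size of 8 on the innermost iteration dimension are fusible; a
    // non-zero offset or a smaller size would require a partial tile and is
    // rejected below.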
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB,
      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
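    // E.g., with a hypothetical inner tile size of 8 and a tile size of 8,
    // each tile maps to exactly one outer row of the source: the source
    // offset is tileOffset / 8 and the source size is 1.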
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there can be an incomplete tile even
    // in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

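  // Non-aligned case: compute the (outer, inner) source coordinates of the
  // first and the last element covered by the tile. The source slice spans
  // the outer coordinates [firstCoord.quotient, lastCoord.quotient], and the
  // result is later shifted by firstCoord.remainder.
  // E.g., with a hypothetical inner tile size of 8, a tile at offset 15 with
  // size 15 covers elements 15..29, i.e. coordinates (1, 7) through (3, 5);
  // three outer rows of the source are read.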
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the resulting
  // affine op would be too complicated and would trigger issues in affine op
  // simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we'll need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

393 
394     // The perfect tiling case indicates that the tiling sizes are multiple of
395     // inner_tile_size. In this context, no extra data is needed when
396     // representing the tiled unpack op.
397     bool isPerfectTilingCase = true;
398     Attribute oneAttr = b.getIndexAttr(1);
399     SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
400     SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
401     SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
402     for (auto dim : llvm::seq<int64_t>(0, destRank)) {
403       UnpackTileDimInfo info =
404           getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
405       if (!info.isAlignedToInnerTileSize)
406         isPerfectTilingCase = false;
407       sliceSrcIndices.push_back(info.sourceOffset);
408       sliceSrcSizes.push_back(info.sourceSize);
409       destExpandedSizes.push_back(info.destExpandedSize);
410       resultOffsetsFromDest.push_back(info.resultOffset);
411     }
412 
    // The tiling is applied on the destination dimensions. We have to apply
    // the interchange to the source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except when none of
    // the low padding is read; in that case, the new amount of low padding
    // is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);
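    // Worked example with hypothetical values low = 2, offset = 1,
    // length = 4, srcSize = 10: newLow = max(0, 2 - 1) = 1,
    // newOffset = min(max(1 - 2, 0), 10) = 0, and
    // endLoc = min(max(1 - 2 + 4, 0), 10) = 3, so newLength = 3. Together
    // with the high padding computed below, the result keeps length 4.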

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which are unclear.)
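  // A sketch of the generated IR for a hypothetical tensor<4x?xf32> result
  // with one dynamic extent %d0 and padding value %pad:
  //   %g = tensor.generate %d0 {
  //   ^bb0(%i: index, %j: index):
  //     tensor.yield %pad : f32
  //   } : tensor<4x?xf32>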
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
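  // A sketch of the guard (hypothetical result type tensor<?xf32>):
  //   %r = scf.if %dynHasZeroLenCond -> (tensor<?xf32>) {
  //     scf.yield %generated : tensor<?xf32>   // pad value only
  //   } else {
  //     scf.yield %padOfSlice : tensor<?xf32>  // pad(extract_slice(x))
  //   }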
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
703