//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    Location loc = op->getLoc();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

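/// Returns the iteration domain for tiling a pack or unpack op: for a PackOp
/// the domain ranges over the source dims, for an UnPackOp over the dest
/// dims, with sizes taken from the reified result shape. E.g. (illustrative):
/// packing tensor<128x256xf32> into tensor<4x8x32x32xf32> with
/// inner_tiles = [32, 32] gives a rank-2 domain with sizes [4, 8]; the
/// inverse unpack gives a rank-2 domain with sizes [128, 256].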
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  Location loc = op.getLoc();
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

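/// Applies `permutation` to `offsets` and `sizes` in place; a no-op when
/// `permutation` is empty.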
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied to the interchanged dimensions. We have to undo
    // the interchange to map the sizes and offsets back to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::createDimValues(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
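        // E.g. (illustrative): with tile_i = 32, offset_i = 2, and
        // size_i = 3, the slice of the source starts at index 2 * 32 = 64
        // and spans 3 * 32 = 96 elements along dimension i.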
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
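      // E.g. (illustrative): if the source dim has size 100 and the slice
      // computed above would read [64, 160), the size is clamped to
      // min(96, 100 - 64) = 36.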
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over the outer dimensions of the packed layout.
    // In this context, the outer dimensions of `resultOffsets` are `offsets`.
    // The inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
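    // E.g. (illustrative): for a pack producing tensor<4x8x32x32xf32> from
    // tensor<128x256xf32>, offsets [o0, o1] and sizes [s0, s1] map to
    // resultOffsets = [o0, o1, 0, 0] and resultSizes = [s0, s1, 32, 32].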
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB,
      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (succeeded(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
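    // E.g. (illustrative): with tileSize = innerTileSize = 8 and
    // tileOffset = 16, the slice covers exactly source row
    // 16 floordiv 8 = 2, i.e. sourceOffset = 2 and sourceSize = 1.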
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op, so we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

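  // Non-aligned case, e.g. (illustrative, Nn_to_N with n = 8): for
  // tileOffset = 15 and tileSize = 7 the tile covers result[15..21], i.e.
  // source coords (1, 7) .. (2, 5). Then firstCoord = (1, 7) and
  // lastCoord = (2, 5), so sourceOffset = 1, sourceSize = 2 - 1 + 1 = 2,
  // resultOffset = 7, and destExpandedSize = 2 * 8 = 16.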
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create affine ops for the expanded size because the resulting
  // affine expression is too complicated and would trigger issues in affine
  // op simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. In that case, we need to
  /// expand the slice of the source to get complete tiles. The tiled unpack op
  /// then unpacks more data from the source, so an extract_slice op is needed
  /// to shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size. In this case, no extra data is needed when
    // representing the tiled unpack op.
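    // E.g. (illustrative): unpacking with inner tile size 8 and tiling size
    // 16 is perfect; every tile maps to exactly 16 / 8 = 2 complete source
    // rows.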
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied to the destination dimensions. We have to apply
    // the interchange to the source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace

FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize =
        tensor::createDimValue(b, loc, padOp.getSource(), dim).value();

    // The new amount of low padding is `low - offset`, except when none of
    // the low padding is read. In that case, the new amount of low padding
    // is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
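    // E.g. (illustrative): low = 3, offset = 1 leaves 3 - 1 = 2 elements of
    // low padding in the tile, so newLow = max(0, 3 - 1) = 2; with
    // offset = 5 the tile starts past the low padding and newLow = 0.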
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    // max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
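    // E.g. (illustrative): low = 3, srcSize = 10: offset = 1 gives
    // newOffset = min(max(1 - 3, 0), 10) = 0, while offset = 14 (entirely in
    // the high padding) gives newOffset = min(max(14 - 3, 0), 10) = 10.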
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    // offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of
    // the source tensor. (Similar to newOffset.)
    //
    // endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
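    // E.g. (illustrative): low = 3, srcSize = 10, offset = 1, length = 6:
    // endLoc = min(max(1 - 3 + 6, 0), 10) = 4 and newLength = 4 - 0 = 4,
    // since the first two of the six read elements come from the low padding.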
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
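    // E.g. (illustrative): continuing the example above with length = 6,
    // newLength = 4, and newLow = 2, newHigh = (6 - 4) - 2 = 0; the tile
    // ends inside the source, before any high padding.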
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which are unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
                                    /*nofold=*/padOp.getNofold());

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp when it is statically
  // known that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}