xref: /llvm-project/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (revision 33927744db2a910fe1cdeecf9e074d488de2e787)
//===- TensorTilingInterfaceImpl.cpp - Tiling Interface models --*- C++ -*-===//
2fd0c6f53SAlexander Belyaev //
3fd0c6f53SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd0c6f53SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5fd0c6f53SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd0c6f53SAlexander Belyaev //
7fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
8fd0c6f53SAlexander Belyaev 
9fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
10fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h"
110d03ba62SHanhan Wang #include "mlir/Dialect/Affine/Utils.h"
12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
13fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Linalg/IR/Linalg.h"
1483396d85SHanhan Wang #include "mlir/Dialect/Linalg/Utils/Utils.h"
158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
16fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/Tensor.h"
170d03ba62SHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h"
180d03ba62SHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h"
198cc616bcSMax191 #include "mlir/Interfaces/InferTypeOpInterface.h"
20fd0c6f53SAlexander Belyaev #include "mlir/Interfaces/TilingInterface.h"
21eabb6ccdSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h"
22fd0c6f53SAlexander Belyaev 
23fd0c6f53SAlexander Belyaev using namespace mlir;
24fd0c6f53SAlexander Belyaev using namespace mlir::tensor;
25fd0c6f53SAlexander Belyaev 
26fd0c6f53SAlexander Belyaev namespace {
27fd0c6f53SAlexander Belyaev 
28fd0c6f53SAlexander Belyaev struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
29fd0c6f53SAlexander Belyaev 
304f1c1242SOleg Shyshkov   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
31fd0c6f53SAlexander Belyaev     auto padOp = cast<PadOp>(op);
324f1c1242SOleg Shyshkov     SmallVector<utils::IteratorType> iteratorTypes(
334f1c1242SOleg Shyshkov         padOp.getResultType().getRank(), utils::IteratorType::parallel);
34fd0c6f53SAlexander Belyaev     return iteratorTypes;
35fd0c6f53SAlexander Belyaev   }
36fd0c6f53SAlexander Belyaev 
37fd0c6f53SAlexander Belyaev   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
38fd0c6f53SAlexander Belyaev     ReifiedRankedShapedTypeDims reifiedShapes;
39758329dcSMatthias Springer     (void)reifyResultShapes(b, op, reifiedShapes);
403e12cf2aSMatthias Springer     OpFoldResult zero = b.getIndexAttr(0);
413e12cf2aSMatthias Springer     OpFoldResult one = b.getIndexAttr(1);
42fd0c6f53SAlexander Belyaev     // Initialize all the ranges to {zero, one, one}. All the `ub`s are
43fd0c6f53SAlexander Belyaev     // overwritten.
44fd0c6f53SAlexander Belyaev     SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
45fd0c6f53SAlexander Belyaev     for (const auto &ub : enumerate(reifiedShapes[0]))
46fd0c6f53SAlexander Belyaev       loopRanges[ub.index()].size = ub.value();
47fd0c6f53SAlexander Belyaev     return loopRanges;
48fd0c6f53SAlexander Belyaev   }
49fd0c6f53SAlexander Belyaev 
50809e3d8cSMahesh Ravishankar   FailureOr<TilingResult>
5154794284SMatthias Springer   getTiledImplementation(Operation *op, OpBuilder &b,
52fd0c6f53SAlexander Belyaev                          ArrayRef<OpFoldResult> offsets,
5354794284SMatthias Springer                          ArrayRef<OpFoldResult> sizes) const {
54809e3d8cSMahesh Ravishankar     FailureOr<TilingResult> result =
550edb4127SLei Zhang         tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
56809e3d8cSMahesh Ravishankar     if (failed(result))
57809e3d8cSMahesh Ravishankar       return failure();
58809e3d8cSMahesh Ravishankar     return result.value();
590edb4127SLei Zhang   }
60a235562cSMahesh Ravishankar 
61a235562cSMahesh Ravishankar   LogicalResult
62a235562cSMahesh Ravishankar   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
63a235562cSMahesh Ravishankar                         ArrayRef<OpFoldResult> offsets,
64a235562cSMahesh Ravishankar                         ArrayRef<OpFoldResult> sizes,
65f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultOffsets,
66f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultSizes) const {
67a235562cSMahesh Ravishankar     resultOffsets.assign(offsets.begin(), offsets.end());
68a235562cSMahesh Ravishankar     resultSizes.assign(sizes.begin(), sizes.end());
69a235562cSMahesh Ravishankar     return success();
70a235562cSMahesh Ravishankar   }
7191e57c6fSQuinn Dawkins 
7291e57c6fSQuinn Dawkins   LogicalResult getIterationDomainTileFromResultTile(
7391e57c6fSQuinn Dawkins       Operation *op, OpBuilder &b, unsigned resultNumber,
7491e57c6fSQuinn Dawkins       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
7591e57c6fSQuinn Dawkins       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
7691e57c6fSQuinn Dawkins       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
7791e57c6fSQuinn Dawkins     iterDomainOffsets.assign(offsets.begin(), offsets.end());
7891e57c6fSQuinn Dawkins     iterDomainSizes.assign(sizes.begin(), sizes.end());
7991e57c6fSQuinn Dawkins     return success();
8091e57c6fSQuinn Dawkins   }
8191e57c6fSQuinn Dawkins 
8291e57c6fSQuinn Dawkins   FailureOr<TilingResult>
8391e57c6fSQuinn Dawkins   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
8491e57c6fSQuinn Dawkins                           ArrayRef<OpFoldResult> offsets,
8591e57c6fSQuinn Dawkins                           ArrayRef<OpFoldResult> sizes) const {
8691e57c6fSQuinn Dawkins     return getTiledImplementation(op, b, offsets, sizes);
8791e57c6fSQuinn Dawkins   }
880edb4127SLei Zhang };
890edb4127SLei Zhang 
9083396d85SHanhan Wang template <typename OpTy>
9183396d85SHanhan Wang static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
9283396d85SHanhan Wang                                                        OpBuilder &builder) {
9383396d85SHanhan Wang   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
9483396d85SHanhan Wang                 "applies to only pack or unpack operations");
9583396d85SHanhan Wang   OpBuilder::InsertionGuard g(builder);
9683396d85SHanhan Wang   int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
9783396d85SHanhan Wang                                                      : op.getDestRank();
983e12cf2aSMatthias Springer   OpFoldResult zero = builder.getIndexAttr(0);
993e12cf2aSMatthias Springer   OpFoldResult one = builder.getIndexAttr(1);
10083396d85SHanhan Wang   ReifiedRankedShapedTypeDims resultShape;
101758329dcSMatthias Springer   (void)reifyResultShapes(builder, op, resultShape);
10283396d85SHanhan Wang   SmallVector<Range> loopBounds(rank);
10383396d85SHanhan Wang   for (auto dim : llvm::seq<int64_t>(0, rank)) {
10483396d85SHanhan Wang     loopBounds[dim].offset = zero;
10583396d85SHanhan Wang     loopBounds[dim].stride = one;
10683396d85SHanhan Wang     loopBounds[dim].size = resultShape[0][dim];
10783396d85SHanhan Wang   }
10883396d85SHanhan Wang   return loopBounds;
10983396d85SHanhan Wang }
11083396d85SHanhan Wang 
111d5a9fc13SLorenzo Chelini static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
11283396d85SHanhan Wang                              SmallVector<OpFoldResult> &sizes,
11383396d85SHanhan Wang                              ArrayRef<int64_t> permutation) {
11483396d85SHanhan Wang   if (permutation.empty())
11583396d85SHanhan Wang     return;
116d5a9fc13SLorenzo Chelini   applyPermutationToVector<OpFoldResult>(offsets, permutation);
117d5a9fc13SLorenzo Chelini   applyPermutationToVector<OpFoldResult>(sizes, permutation);
11883396d85SHanhan Wang }
11983396d85SHanhan Wang 
1200d03ba62SHanhan Wang struct PackOpTiling
1210d03ba62SHanhan Wang     : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
1220d03ba62SHanhan Wang 
1230d03ba62SHanhan Wang   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
1240d03ba62SHanhan Wang     // Note that here we only consider untiled dimensions and outer tiled data
1250d03ba62SHanhan Wang     // dimensions, the inner tiled data dimensions are materialized when
1260d03ba62SHanhan Wang     // building the body of the operation.
1270d03ba62SHanhan Wang     auto packOp = cast<PackOp>(op);
1280d03ba62SHanhan Wang     SmallVector<utils::IteratorType> iteratorTypes(
1290d03ba62SHanhan Wang         packOp.getSourceRank(), utils::IteratorType::parallel);
1300d03ba62SHanhan Wang     return iteratorTypes;
1310d03ba62SHanhan Wang   }
1320d03ba62SHanhan Wang 
1330d03ba62SHanhan Wang   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
13483396d85SHanhan Wang     return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
1350d03ba62SHanhan Wang   }
1360d03ba62SHanhan Wang 
137809e3d8cSMahesh Ravishankar   FailureOr<TilingResult>
1380d03ba62SHanhan Wang   getTiledImplementation(Operation *op, OpBuilder &b,
1390d03ba62SHanhan Wang                          ArrayRef<OpFoldResult> offsets,
1400d03ba62SHanhan Wang                          ArrayRef<OpFoldResult> sizes) const {
1410d03ba62SHanhan Wang     auto packOp = cast<PackOp>(op);
1420d03ba62SHanhan Wang     Location loc = packOp.getLoc();
1430d03ba62SHanhan Wang 
1440d03ba62SHanhan Wang     // The tiling is applied on interchanged dimensions. We have to undo the
1450d03ba62SHanhan Wang     // interchange to map sizes and offsets to the original input.
1460d03ba62SHanhan Wang     int64_t inputRank = packOp.getSourceRank();
1475262865aSKazu Hirata     SmallVector<OpFoldResult> origOffsets(offsets);
1485262865aSKazu Hirata     SmallVector<OpFoldResult> origSizes(sizes);
149d5a9fc13SLorenzo Chelini     applyPermToRange(origOffsets, origSizes,
150d5a9fc13SLorenzo Chelini                      invertPermutationVector(packOp.getOuterDimsPerm()));
1510d03ba62SHanhan Wang 
1520d03ba62SHanhan Wang     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1530d03ba62SHanhan Wang         packOp.getDimAndTileMapping();
1540d03ba62SHanhan Wang     SmallVector<OpFoldResult> srcDimValues =
1556596b0ddSMatthias Springer         tensor::getMixedSizes(b, loc, packOp.getSource());
1560d03ba62SHanhan Wang     SmallVector<OpFoldResult> inputIndices, inputSizes;
1570d03ba62SHanhan Wang     for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
1584c48f016SMatthias Springer       using AV = affine::AffineValueExpr;
1594c48f016SMatthias Springer       affine::AffineBuilder ab(b, loc);
1600d03ba62SHanhan Wang       AffineExpr dim0, dim1, sym;
1610d03ba62SHanhan Wang       bindDims(b.getContext(), dim0, dim1);
1620d03ba62SHanhan Wang       bindSymbols(b.getContext(), sym);
1630d03ba62SHanhan Wang       if (dimAndTileMapping.count(dim)) {
1640d03ba62SHanhan Wang         // If the data dimension is tiled, the i-th index is the product of
1650d03ba62SHanhan Wang         // offset_i and tile_i, and the i-th size is the product of sizes_i and
1660d03ba62SHanhan Wang         // tile_i.
1670d03ba62SHanhan Wang         auto avOffset = AV(dim0).bind(origOffsets[dim]);
1680d03ba62SHanhan Wang         auto avSize = AV(dim0).bind(origSizes[dim]);
1690d03ba62SHanhan Wang         auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
1700d03ba62SHanhan Wang         inputIndices.push_back(ab.mul(avOffset, avTileSize));
1710d03ba62SHanhan Wang         inputSizes.push_back(ab.mul(avSize, avTileSize));
1720d03ba62SHanhan Wang       } else {
1730d03ba62SHanhan Wang         inputIndices.push_back(origOffsets[dim]);
1740d03ba62SHanhan Wang         inputSizes.push_back(origSizes[dim]);
1750d03ba62SHanhan Wang       }
1760d03ba62SHanhan Wang 
1770d03ba62SHanhan Wang       // Limit the size of the input operand for incomplete tiles.
1784d036510SHanhan Wang       if (packOp.getPaddingValue()) {
1790d03ba62SHanhan Wang         OpFoldResult dimSize = srcDimValues[dim];
1800d03ba62SHanhan Wang         auto avDimSize = AV(dim0).bind(dimSize);
1810d03ba62SHanhan Wang         auto avInputIdx = AV(dim1).bind(inputIndices.back());
1820d03ba62SHanhan Wang         inputSizes.back() =
1830d03ba62SHanhan Wang             ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
1840d03ba62SHanhan Wang       }
1854d036510SHanhan Wang     }
1860d03ba62SHanhan Wang 
1870d03ba62SHanhan Wang     auto oneAttr = b.getI64IntegerAttr(1);
1880d03ba62SHanhan Wang     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
1890d03ba62SHanhan Wang 
1900d03ba62SHanhan Wang     SmallVector<Value> tiledOperands;
191d5f0969cSMaheshRavishankar     auto sourceSlice = b.create<ExtractSliceOp>(
192d5f0969cSMaheshRavishankar         loc, packOp.getSource(), inputIndices, inputSizes, strides);
193d5f0969cSMaheshRavishankar     tiledOperands.push_back(sourceSlice);
1940d03ba62SHanhan Wang 
1950d03ba62SHanhan Wang     SmallVector<OpFoldResult> outputOffsets, outputSizes;
1960d03ba62SHanhan Wang     if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
1970d03ba62SHanhan Wang                                      outputSizes)))
1980d03ba62SHanhan Wang       return {};
1990d03ba62SHanhan Wang 
2000d03ba62SHanhan Wang     strides.append(packOp.getDestRank() - inputRank, oneAttr);
201d5f0969cSMaheshRavishankar     auto outSlice = b.create<ExtractSliceOp>(
2020d03ba62SHanhan Wang         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
203d5f0969cSMaheshRavishankar     tiledOperands.push_back(outSlice);
2040d03ba62SHanhan Wang 
2050d03ba62SHanhan Wang     if (auto val = packOp.getPaddingValue())
2060d03ba62SHanhan Wang       tiledOperands.push_back(val);
2070d03ba62SHanhan Wang     for (auto tile : packOp.getInnerTiles())
2080d03ba62SHanhan Wang       tiledOperands.push_back(tile);
2090d03ba62SHanhan Wang 
2100d03ba62SHanhan Wang     Operation *tiledPackOp = b.create<PackOp>(
211d5f0969cSMaheshRavishankar         loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
2120d03ba62SHanhan Wang 
213d5f0969cSMaheshRavishankar     return TilingResult{
214d5f0969cSMaheshRavishankar         {tiledPackOp},
215d5f0969cSMaheshRavishankar         SmallVector<Value>(tiledPackOp->getResults()),
216d5f0969cSMaheshRavishankar         llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
2170d03ba62SHanhan Wang   }
2180d03ba62SHanhan Wang 
2190d03ba62SHanhan Wang   LogicalResult
2200d03ba62SHanhan Wang   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
2210d03ba62SHanhan Wang                         ArrayRef<OpFoldResult> offsets,
2220d03ba62SHanhan Wang                         ArrayRef<OpFoldResult> sizes,
223f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultOffsets,
224f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultSizes) const {
2250d03ba62SHanhan Wang     // The iteration domain is over outer dimensions of packed layout. In this
2260d03ba62SHanhan Wang     // context, the outer dimensions of `resultOffsets` are `offsets`. The
2270d03ba62SHanhan Wang     // inner dimensions of `resultOffsets` are zeros because tiling is not
2280d03ba62SHanhan Wang     // applied to them.
2290d03ba62SHanhan Wang     auto packOp = cast<PackOp>(op);
2300d03ba62SHanhan Wang     int64_t inputRank = packOp.getSourceRank();
2310d03ba62SHanhan Wang     int64_t outputRank = packOp.getDestRank();
2320d03ba62SHanhan Wang     auto zeroAttr = b.getI64IntegerAttr(0);
2330d03ba62SHanhan Wang     resultOffsets.assign(offsets.begin(), offsets.end());
2340d03ba62SHanhan Wang     resultOffsets.append(outputRank - inputRank, zeroAttr);
2350d03ba62SHanhan Wang 
2360d03ba62SHanhan Wang     ReifiedRankedShapedTypeDims outputShape;
237758329dcSMatthias Springer     (void)reifyResultShapes(b, packOp, outputShape);
2380d03ba62SHanhan Wang     resultSizes.assign(sizes.begin(), sizes.end());
2390d03ba62SHanhan Wang     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
2402a5b13e7SMatthias Springer       resultSizes.push_back(outputShape[0][dataTileDim]);
2410d03ba62SHanhan Wang 
2420d03ba62SHanhan Wang     return success();
2430d03ba62SHanhan Wang   }
2448b68cec9SHanhan Wang 
2458b68cec9SHanhan Wang   FailureOr<TilingResult>
2468b68cec9SHanhan Wang   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
2478b68cec9SHanhan Wang                           ArrayRef<OpFoldResult> offsets,
2488b68cec9SHanhan Wang                           ArrayRef<OpFoldResult> sizes) const {
2498b68cec9SHanhan Wang     auto packOp = cast<PackOp>(op);
2508b68cec9SHanhan Wang     int64_t numTiles = packOp.getInnerDimsPos().size();
2518b68cec9SHanhan Wang 
2528b68cec9SHanhan Wang     // tensor.pack op is fusible (as a producer) only if full inner tiles are
2538b68cec9SHanhan Wang     // iterated or inner dims are not tiled. Otherwise, it will generate a
2548b68cec9SHanhan Wang     // sequence of non-trivial ops (for partial tiles).
2558b68cec9SHanhan Wang     for (auto offset : offsets.take_back(numTiles))
2568b68cec9SHanhan Wang       if (!isConstantIntValue(offset, 0))
2578b68cec9SHanhan Wang         return failure();
2588b68cec9SHanhan Wang 
2598b68cec9SHanhan Wang     for (auto iter :
2608b68cec9SHanhan Wang          llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
2618b68cec9SHanhan Wang       if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
2628b68cec9SHanhan Wang         return failure();
2638b68cec9SHanhan Wang 
2648b68cec9SHanhan Wang     FailureOr<TilingResult> tilingResult = getTiledImplementation(
2658b68cec9SHanhan Wang         op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
2668b68cec9SHanhan Wang     if (failed(tilingResult))
2678b68cec9SHanhan Wang       return failure();
2688b68cec9SHanhan Wang     return tilingResult.value();
2698b68cec9SHanhan Wang   }
270f06563a5SYun-Fly 
271f06563a5SYun-Fly   /// Method to return the position of iteration domain tile computed by the
272f06563a5SYun-Fly   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
273f06563a5SYun-Fly   /// `resultSizes` only cover outer dimensions.
274f06563a5SYun-Fly   LogicalResult getIterationDomainTileFromOperandTile(
275f06563a5SYun-Fly       Operation *op, OpBuilder &b, unsigned operandNumber,
276f06563a5SYun-Fly       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
277f06563a5SYun-Fly       SmallVectorImpl<OpFoldResult> &resultOffsets,
278f06563a5SYun-Fly       SmallVectorImpl<OpFoldResult> &resultSizes) const {
279f06563a5SYun-Fly     if (operandNumber != 0)
280f06563a5SYun-Fly       return failure();
281f06563a5SYun-Fly 
282f06563a5SYun-Fly     auto packOp = cast<PackOp>(op);
283f06563a5SYun-Fly     // It is not trivial to infer dest tile from source tile if `packOp` has
284f06563a5SYun-Fly     // padding semantic.
285f06563a5SYun-Fly     if (packOp.getPaddingValue())
286f06563a5SYun-Fly       return failure();
287f06563a5SYun-Fly 
288f06563a5SYun-Fly     Location loc = packOp.getLoc();
289f06563a5SYun-Fly 
290f06563a5SYun-Fly     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
291f06563a5SYun-Fly     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
292f06563a5SYun-Fly         packOp.getDimAndTileMapping();
293c8763f04SYun-Fly     for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
294f06563a5SYun-Fly       if (dimAndTileMapping.count(dim)) {
295f06563a5SYun-Fly         FailureOr<int64_t> cstSize =
296f06563a5SYun-Fly             ValueBoundsConstraintSet::computeConstantBound(
297f06563a5SYun-Fly                 presburger::BoundType::UB, sizes[dim],
298f06563a5SYun-Fly                 /*stopCondition=*/nullptr, /*closedUB=*/true);
299f06563a5SYun-Fly         std::optional<int64_t> cstInnerSize =
300f06563a5SYun-Fly             getConstantIntValue(dimAndTileMapping[dim]);
301f06563a5SYun-Fly         // Currently fusing `packOp` as consumer only expects perfect tiling
302f06563a5SYun-Fly         // scenario because even if without padding semantic, the `packOp` may
303f06563a5SYun-Fly         // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
304f06563a5SYun-Fly         // where the `tileSize` from operand of `packOp` is 5, which is not
305f06563a5SYun-Fly         // exactly divided by `innerTile`(=6) of `packOp`. As the result:
306f06563a5SYun-Fly         // 1. the first slice is extracted from (0) to (4) and inserted into
307f06563a5SYun-Fly         // (0,0)~(0,4) at first row.
308f06563a5SYun-Fly         // 2. the second slice is extracted from (5) to (9) and SHOULD BE
309f06563a5SYun-Fly         // respectively inserted into two rows with different length, including
310f06563a5SYun-Fly         // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
311f06563a5SYun-Fly         // them, thus adding below constraint to bypass them temporarily. In
312f06563a5SYun-Fly         // another word, we can only support tiling with consumer if the tile
313f06563a5SYun-Fly         // size for the producer is a multiple of the inner tile size for the
314f06563a5SYun-Fly         // packed dimensions at this moment.
315f06563a5SYun-Fly         if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
316f06563a5SYun-Fly           return failure();
317f06563a5SYun-Fly         }
318f06563a5SYun-Fly 
319f06563a5SYun-Fly         using AV = affine::AffineValueExpr;
320f06563a5SYun-Fly         affine::AffineBuilder ab(b, loc);
321f06563a5SYun-Fly         AffineExpr dim0, sym;
322f06563a5SYun-Fly         bindDims(b.getContext(), dim0);
323f06563a5SYun-Fly         bindSymbols(b.getContext(), sym);
324f06563a5SYun-Fly         auto avOffset = AV(dim0).bind(offsets[dim]);
325f06563a5SYun-Fly         auto avSize = AV(dim0).bind(sizes[dim]);
326f06563a5SYun-Fly         auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
327f06563a5SYun-Fly         outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
328f06563a5SYun-Fly         outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
329f06563a5SYun-Fly       } else {
330f06563a5SYun-Fly         outerDimOffsets.push_back(offsets[dim]);
331f06563a5SYun-Fly         outerDimSizes.push_back(sizes[dim]);
332f06563a5SYun-Fly       }
333f06563a5SYun-Fly     }
334c8763f04SYun-Fly     applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
335f06563a5SYun-Fly     resultOffsets = outerDimOffsets;
336f06563a5SYun-Fly     resultSizes = outerDimSizes;
337f06563a5SYun-Fly     return success();
338f06563a5SYun-Fly   }
339f06563a5SYun-Fly 
340f06563a5SYun-Fly   /// Method to return the tiled implementation of tensor.pack as a consumer.
341f06563a5SYun-Fly   FailureOr<TilingResult> getTiledImplementationFromOperandTile(
342f06563a5SYun-Fly       Operation *op, OpBuilder &b, unsigned operandNumber,
343f06563a5SYun-Fly       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
344f06563a5SYun-Fly     if (operandNumber != 0)
345f06563a5SYun-Fly       return failure();
346f06563a5SYun-Fly 
347f06563a5SYun-Fly     auto packOp = cast<PackOp>(op);
348f06563a5SYun-Fly     Location loc = packOp.getLoc();
349f06563a5SYun-Fly 
350f06563a5SYun-Fly     int64_t inputRank = packOp.getSourceRank();
351f06563a5SYun-Fly     auto oneAttr = b.getI64IntegerAttr(1);
352f06563a5SYun-Fly     SmallVector<OpFoldResult> strides(inputRank, oneAttr);
353f06563a5SYun-Fly 
354f06563a5SYun-Fly     SmallVector<Value> tiledOperands;
355d5f0969cSMaheshRavishankar     auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
356d5f0969cSMaheshRavishankar                                                 offsets, sizes, strides);
357d5f0969cSMaheshRavishankar     tiledOperands.push_back(sourceSlice);
358f06563a5SYun-Fly 
359f06563a5SYun-Fly     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
360f06563a5SYun-Fly     if (failed(getIterationDomainTileFromOperandTile(
361f06563a5SYun-Fly             op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
362f06563a5SYun-Fly             outerDimSizes)))
363f06563a5SYun-Fly       return failure();
364f06563a5SYun-Fly 
365f06563a5SYun-Fly     SmallVector<OpFoldResult> outputOffsets, outputSizes;
366f06563a5SYun-Fly     if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
367f06563a5SYun-Fly                                      outputOffsets, outputSizes)))
368f06563a5SYun-Fly       return failure();
369f06563a5SYun-Fly 
370f06563a5SYun-Fly     strides.append(packOp.getDestRank() - inputRank, oneAttr);
371d5f0969cSMaheshRavishankar     auto outSlice = b.create<ExtractSliceOp>(
372f06563a5SYun-Fly         loc, packOp.getDest(), outputOffsets, outputSizes, strides);
373d5f0969cSMaheshRavishankar     tiledOperands.push_back(outSlice);
374f06563a5SYun-Fly 
375f06563a5SYun-Fly     assert(!packOp.getPaddingValue() && "Expect no padding semantic");
376f06563a5SYun-Fly     for (auto tile : packOp.getInnerTiles())
377f06563a5SYun-Fly       tiledOperands.push_back(tile);
378f06563a5SYun-Fly 
379f06563a5SYun-Fly     Operation *tiledPackOp = b.create<PackOp>(
380d5f0969cSMaheshRavishankar         loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());
381f06563a5SYun-Fly 
382d5f0969cSMaheshRavishankar     return TilingResult{
383d5f0969cSMaheshRavishankar         {tiledPackOp},
384d5f0969cSMaheshRavishankar         SmallVector<Value>(tiledPackOp->getResults()),
385d5f0969cSMaheshRavishankar         llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
386f06563a5SYun-Fly   }
3870d03ba62SHanhan Wang };
3880d03ba62SHanhan Wang 
38983396d85SHanhan Wang struct UnpackTileDimInfo {
39083396d85SHanhan Wang   bool isAlignedToInnerTileSize;
39183396d85SHanhan Wang   OpFoldResult sourceOffset;
39283396d85SHanhan Wang   OpFoldResult sourceSize;
39383396d85SHanhan Wang   OpFoldResult resultOffset;
39483396d85SHanhan Wang   OpFoldResult destExpandedSize;
39583396d85SHanhan Wang };
39683396d85SHanhan Wang 
39783396d85SHanhan Wang /// Returns the needed information for tiling unpack op on `tileDim` with given
39883396d85SHanhan Wang /// `tileOffset` and `tileSize`. For more details, see the comment of the
39983396d85SHanhan Wang /// `getTiledImplementation`.
40083396d85SHanhan Wang static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
40183396d85SHanhan Wang                                               int64_t tileDim,
40283396d85SHanhan Wang                                               OpFoldResult tileOffset,
40383396d85SHanhan Wang                                               OpFoldResult tileSize) {
40483396d85SHanhan Wang   UnpackTileDimInfo info;
40583396d85SHanhan Wang   Attribute zeroAttr = b.getIndexAttr(0);
40683396d85SHanhan Wang   Attribute oneAttr = b.getIndexAttr(1);
40783396d85SHanhan Wang   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
40883396d85SHanhan Wang       unpackOp.getDimAndTileMapping();
40983396d85SHanhan Wang   // The dimension is not one of packed data dimension.
41083396d85SHanhan Wang   if (!dimAndTileMapping.count(tileDim)) {
41183396d85SHanhan Wang     info.isAlignedToInnerTileSize = true;
41283396d85SHanhan Wang     info.sourceOffset = tileOffset;
41383396d85SHanhan Wang     info.sourceSize = tileSize;
41483396d85SHanhan Wang     info.resultOffset = zeroAttr;
41583396d85SHanhan Wang     info.destExpandedSize = tileSize;
41683396d85SHanhan Wang     return info;
41783396d85SHanhan Wang   }
41883396d85SHanhan Wang 
41983396d85SHanhan Wang   Location loc = unpackOp.getLoc();
4204c48f016SMatthias Springer   using AV = affine::AffineValueExpr;
4214c48f016SMatthias Springer   affine::AffineBuilder ab(b, loc);
42283396d85SHanhan Wang   AffineExpr dim0, dim1, sym0;
42383396d85SHanhan Wang   bindDims(b.getContext(), dim0, dim1);
42483396d85SHanhan Wang   bindSymbols(b.getContext(), sym0);
42583396d85SHanhan Wang 
42683396d85SHanhan Wang   OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
42783396d85SHanhan Wang 
42883396d85SHanhan Wang   info.isAlignedToInnerTileSize = false;
429eabb6ccdSMatthias Springer   FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
43040dd3aa9SMatthias Springer       presburger::BoundType::UB, tileSize,
431eabb6ccdSMatthias Springer       /*stopCondition=*/nullptr, /*closedUB=*/true);
43222426110SRamkumar Ramachandra   std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
43383396d85SHanhan Wang   if (!failed(cstSize) && cstInnerSize) {
434cbb09813SFangrui Song     if (*cstSize % *cstInnerSize == 0)
43583396d85SHanhan Wang       info.isAlignedToInnerTileSize = true;
43683396d85SHanhan Wang 
43783396d85SHanhan Wang     // If the tiling size equals to the inner tiling size, the outer dims are
43883396d85SHanhan Wang     // always 1.
439cbb09813SFangrui Song     if (*cstInnerSize == *cstSize) {
44083396d85SHanhan Wang       auto lhs = AV(dim0).bind(tileOffset);
44183396d85SHanhan Wang       auto rhs = AV(dim1).bind(innerTileSize);
44283396d85SHanhan Wang       info.sourceOffset = ab.floor(lhs, rhs);
44383396d85SHanhan Wang       info.sourceSize = oneAttr;
44483396d85SHanhan Wang       info.resultOffset = zeroAttr;
44583396d85SHanhan Wang       info.destExpandedSize = tileSize;
44683396d85SHanhan Wang       return info;
44783396d85SHanhan Wang     }
44883396d85SHanhan Wang   }
44983396d85SHanhan Wang 
45083396d85SHanhan Wang   if (info.isAlignedToInnerTileSize) {
45183396d85SHanhan Wang     info.sourceOffset =
45283396d85SHanhan Wang         ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
45383396d85SHanhan Wang     info.resultOffset = zeroAttr;
45483396d85SHanhan Wang     info.destExpandedSize = tileSize;
45583396d85SHanhan Wang 
45683396d85SHanhan Wang     // The ceilDiv is needed here because there could be incomplete tile even
45783396d85SHanhan Wang     // it is perfect tiling cases. E.g.,
45883396d85SHanhan Wang     //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
45983396d85SHanhan Wang     // If the tiling size is 32, there will be 3 tiles. Two of them have
46083396d85SHanhan Wang     // size=32; one of them have size=2. The size is represented using
46183396d85SHanhan Wang     // affine_min op; we need ceilDiv.
46283396d85SHanhan Wang     info.sourceSize =
46383396d85SHanhan Wang         ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
46483396d85SHanhan Wang     return info;
46583396d85SHanhan Wang   }
46683396d85SHanhan Wang 
4674c48f016SMatthias Springer   affine::DivModValue firstCoord = affine::getDivMod(
4684c48f016SMatthias Springer       b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
46983396d85SHanhan Wang       getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
47083396d85SHanhan Wang   OpFoldResult tileExclusiveBound =
47183396d85SHanhan Wang       ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
4724c48f016SMatthias Springer   affine::DivModValue lastCoord = affine::getDivMod(
47383396d85SHanhan Wang       b, loc,
47483396d85SHanhan Wang       getValueOrCreateConstantIndexOp(
47583396d85SHanhan Wang           b, loc,
47683396d85SHanhan Wang           ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
47783396d85SHanhan Wang       getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
47883396d85SHanhan Wang 
47983396d85SHanhan Wang   OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
48083396d85SHanhan Wang                                        AV(dim1).bind(firstCoord.quotient));
48183396d85SHanhan Wang   info.sourceSize =
48283396d85SHanhan Wang       ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
48383396d85SHanhan Wang   info.sourceOffset = firstCoord.quotient;
48483396d85SHanhan Wang   info.resultOffset = firstCoord.remainder;
4855fa9933cSHanhan Wang   // Do not create an Affine ops for expanded size because the affine op is too
4865fa9933cSHanhan Wang   // complicated which would trigger an issue in affine ops simplification.
4875fa9933cSHanhan Wang   info.destExpandedSize = b.createOrFold<arith::MulIOp>(
4885fa9933cSHanhan Wang       loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
4895fa9933cSHanhan Wang       getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
49083396d85SHanhan Wang   return info;
49183396d85SHanhan Wang }
49283396d85SHanhan Wang 
49383396d85SHanhan Wang struct UnPackOpTiling
49483396d85SHanhan Wang     : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
49583396d85SHanhan Wang 
49683396d85SHanhan Wang   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
49783396d85SHanhan Wang     auto unpackOp = cast<UnPackOp>(op);
49883396d85SHanhan Wang     SmallVector<utils::IteratorType> iteratorTypes(
49983396d85SHanhan Wang         unpackOp.getDestRank(), utils::IteratorType::parallel);
50083396d85SHanhan Wang     return iteratorTypes;
50183396d85SHanhan Wang   }
50283396d85SHanhan Wang 
50383396d85SHanhan Wang   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
50483396d85SHanhan Wang     return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
50583396d85SHanhan Wang   }
50683396d85SHanhan Wang 
50783396d85SHanhan Wang   /// There are two cases in tiling unpack ops. If the tiling size is aligned to
50883396d85SHanhan Wang   /// the inner tile size, the corresponding tiles of source are all complete.
50983396d85SHanhan Wang   /// Otherwise, there are in-complete tiles. We will need to expand the slice
51083396d85SHanhan Wang   /// of source for getting complete tiles. The tiled unpack op unpacks more
51183396d85SHanhan Wang   /// data from source, so We'll need an extract_slice op to shift and truncate
51283396d85SHanhan Wang   /// the output.
51383396d85SHanhan Wang   /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
51483396d85SHanhan Wang   /// coordinates of second tile (i.e., result[15..31]) are
51583396d85SHanhan Wang   /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
51683396d85SHanhan Wang   /// row are incomplete tiles. To represent the unpack op, we have to complete
51783396d85SHanhan Wang   /// the rows. I.e., the input coordinates would start with (1, 0); end with
51883396d85SHanhan Wang   /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements
51983396d85SHanhan Wang   /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we
52083396d85SHanhan Wang   /// can get the actual result.
521809e3d8cSMahesh Ravishankar   FailureOr<TilingResult>
52283396d85SHanhan Wang   getTiledImplementation(Operation *op, OpBuilder &b,
52383396d85SHanhan Wang                          ArrayRef<OpFoldResult> offsets,
52483396d85SHanhan Wang                          ArrayRef<OpFoldResult> sizes) const {
52583396d85SHanhan Wang     auto unpackOp = cast<UnPackOp>(op);
52683396d85SHanhan Wang     int64_t srcRank = unpackOp.getSourceRank();
52783396d85SHanhan Wang     int64_t destRank = unpackOp.getDestRank();
52883396d85SHanhan Wang     int64_t numInnerTiles = srcRank - destRank;
52983396d85SHanhan Wang     Location loc = unpackOp.getLoc();
53083396d85SHanhan Wang 
53183396d85SHanhan Wang     // The perfect tiling case indicates that the tiling sizes are multiple of
53283396d85SHanhan Wang     // inner_tile_size. In this context, no extra data is needed when
53383396d85SHanhan Wang     // representing the tiled unpack op.
53483396d85SHanhan Wang     bool isPerfectTilingCase = true;
53583396d85SHanhan Wang     Attribute oneAttr = b.getIndexAttr(1);
53683396d85SHanhan Wang     SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
53783396d85SHanhan Wang     SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
53883396d85SHanhan Wang     SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
53983396d85SHanhan Wang     for (auto dim : llvm::seq<int64_t>(0, destRank)) {
54083396d85SHanhan Wang       UnpackTileDimInfo info =
54183396d85SHanhan Wang           getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
54283396d85SHanhan Wang       if (!info.isAlignedToInnerTileSize)
54383396d85SHanhan Wang         isPerfectTilingCase = false;
54483396d85SHanhan Wang       sliceSrcIndices.push_back(info.sourceOffset);
54583396d85SHanhan Wang       sliceSrcSizes.push_back(info.sourceSize);
54683396d85SHanhan Wang       destExpandedSizes.push_back(info.destExpandedSize);
54783396d85SHanhan Wang       resultOffsetsFromDest.push_back(info.resultOffset);
54883396d85SHanhan Wang     }
54983396d85SHanhan Wang 
55083396d85SHanhan Wang     // The tiling is applied on destination dimensions. We have to apply the
55183396d85SHanhan Wang     // interchange on source dimensions if outer_dims_perm is set.
552d5a9fc13SLorenzo Chelini     applyPermToRange(sliceSrcIndices, sliceSrcSizes,
55383396d85SHanhan Wang                      unpackOp.getOuterDimsPerm());
55483396d85SHanhan Wang     Attribute zeroAttr = b.getIndexAttr(0);
55583396d85SHanhan Wang     sliceSrcIndices.append(numInnerTiles, zeroAttr);
55683396d85SHanhan Wang     sliceSrcSizes.append(unpackOp.getMixedTiles());
55783396d85SHanhan Wang     sliceSrcStrides.append(numInnerTiles, oneAttr);
558f1595ecfSMax191     SmallVector<Operation *> generatedSlices;
559f1595ecfSMax191     ExtractSliceOp sliceSource =
56083396d85SHanhan Wang         b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
56183396d85SHanhan Wang                                  sliceSrcSizes, sliceSrcStrides);
562f1595ecfSMax191     generatedSlices.push_back(sliceSource);
56383396d85SHanhan Wang 
56483396d85SHanhan Wang     SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
56583396d85SHanhan Wang     Value sliceDest;
56683396d85SHanhan Wang     if (isPerfectTilingCase) {
567d5f0969cSMaheshRavishankar       auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
568d5f0969cSMaheshRavishankar                                                   offsets, sizes, destStrides);
569d5f0969cSMaheshRavishankar       sliceDest = destSliceOp;
570d5f0969cSMaheshRavishankar       generatedSlices.push_back(destSliceOp);
57183396d85SHanhan Wang     } else {
57283396d85SHanhan Wang       sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
57383396d85SHanhan Wang                                     unpackOp.getDestType().getElementType());
57483396d85SHanhan Wang     }
57583396d85SHanhan Wang 
576f1595ecfSMax191     SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
577be75cf93SHanhan Wang     for (auto tile : unpackOp.getInnerTiles())
578be75cf93SHanhan Wang       tiledOperands.push_back(tile);
579be75cf93SHanhan Wang 
580be75cf93SHanhan Wang     Operation *tiledUnpackOp = b.create<UnPackOp>(
581be75cf93SHanhan Wang         loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());
58283396d85SHanhan Wang 
58383396d85SHanhan Wang     if (isPerfectTilingCase)
584809e3d8cSMahesh Ravishankar       return TilingResult{{tiledUnpackOp},
585d5f0969cSMaheshRavishankar                           SmallVector<Value>(tiledUnpackOp->getResults()),
586d5f0969cSMaheshRavishankar                           generatedSlices};
58783396d85SHanhan Wang 
588809e3d8cSMahesh Ravishankar     auto extractSlice =
58983396d85SHanhan Wang         b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
59083396d85SHanhan Wang                                  resultOffsetsFromDest, sizes, destStrides);
591d5f0969cSMaheshRavishankar     return TilingResult{
592d5f0969cSMaheshRavishankar         {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
59383396d85SHanhan Wang   }
59483396d85SHanhan Wang 
59583396d85SHanhan Wang   LogicalResult
59683396d85SHanhan Wang   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
59783396d85SHanhan Wang                         ArrayRef<OpFoldResult> offsets,
59883396d85SHanhan Wang                         ArrayRef<OpFoldResult> sizes,
599f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultOffsets,
600f220c359SOleksandr "Alex" Zinenko                         SmallVector<OpFoldResult> &resultSizes) const {
60183396d85SHanhan Wang     resultOffsets = llvm::to_vector(offsets);
60283396d85SHanhan Wang     resultSizes = llvm::to_vector(sizes);
60383396d85SHanhan Wang     return success();
60483396d85SHanhan Wang   }
605ead535b2SHanhan Wang 
606809e3d8cSMahesh Ravishankar   FailureOr<TilingResult>
607809e3d8cSMahesh Ravishankar   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
608ead535b2SHanhan Wang                           ArrayRef<OpFoldResult> offsets,
609ead535b2SHanhan Wang                           ArrayRef<OpFoldResult> sizes) const {
610809e3d8cSMahesh Ravishankar     FailureOr<TilingResult> tilingResult =
611809e3d8cSMahesh Ravishankar         getTiledImplementation(op, b, offsets, sizes);
612809e3d8cSMahesh Ravishankar     if (failed(tilingResult))
613809e3d8cSMahesh Ravishankar       return failure();
614809e3d8cSMahesh Ravishankar     return tilingResult.value();
615ead535b2SHanhan Wang   }
6162b2ce50fSAbhishek Varma 
6172b2ce50fSAbhishek Varma   /// Method to return the position of iteration domain tile computed by the
6182b2ce50fSAbhishek Varma   /// tiled operation.
6192b2ce50fSAbhishek Varma   LogicalResult getIterationDomainTileFromOperandTile(
6202b2ce50fSAbhishek Varma       Operation *op, OpBuilder &b, unsigned operandNumber,
6212b2ce50fSAbhishek Varma       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
6222b2ce50fSAbhishek Varma       SmallVectorImpl<OpFoldResult> &resultOffsets,
6232b2ce50fSAbhishek Varma       SmallVectorImpl<OpFoldResult> &resultSizes) const {
6242b2ce50fSAbhishek Varma     auto unPackOp = cast<UnPackOp>(op);
6258cc616bcSMax191     // If the operand tile is the dest, then no adjustment is needed.
6268cc616bcSMax191     if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
6278cc616bcSMax191       resultOffsets = llvm::to_vector(offsets);
6288cc616bcSMax191       resultSizes = llvm::to_vector(sizes);
6298cc616bcSMax191       return success();
6308cc616bcSMax191     }
6312b2ce50fSAbhishek Varma     Location loc = unPackOp.getLoc();
6322b2ce50fSAbhishek Varma 
6332b2ce50fSAbhishek Varma     int64_t numTiles = unPackOp.getInnerDimsPos().size();
6342b2ce50fSAbhishek Varma     auto destOffsets = offsets.drop_back(numTiles);
6352b2ce50fSAbhishek Varma     auto destSizes = sizes.drop_back(numTiles);
6362b2ce50fSAbhishek Varma     // The tiling is applied on interchanged dimensions. We have to undo the
6372b2ce50fSAbhishek Varma     // interchange to map sizes and offsets to the original input.
6382b2ce50fSAbhishek Varma     int64_t outputRank = unPackOp.getDestRank();
6398cc616bcSMax191     ReifiedRankedShapedTypeDims reifiedReturnShapes;
6408cc616bcSMax191     if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
6418cc616bcSMax191       return failure();
6428cc616bcSMax191     SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
6435262865aSKazu Hirata     SmallVector<OpFoldResult> origOffsets(destOffsets);
6445262865aSKazu Hirata     SmallVector<OpFoldResult> origSizes(destSizes);
6452b2ce50fSAbhishek Varma     applyPermToRange(origOffsets, origSizes,
6462b2ce50fSAbhishek Varma                      invertPermutationVector(unPackOp.getOuterDimsPerm()));
6472b2ce50fSAbhishek Varma 
6482b2ce50fSAbhishek Varma     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
6492b2ce50fSAbhishek Varma         unPackOp.getDimAndTileMapping();
6502b2ce50fSAbhishek Varma 
6512b2ce50fSAbhishek Varma     for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
6522b2ce50fSAbhishek Varma       using AV = affine::AffineValueExpr;
6532b2ce50fSAbhishek Varma       affine::AffineBuilder ab(b, loc);
6548cc616bcSMax191       AffineExpr dim0, dim1, sym0;
6552b2ce50fSAbhishek Varma       bindDims(b.getContext(), dim0, dim1);
6568cc616bcSMax191       bindSymbols(b.getContext(), sym0);
6572b2ce50fSAbhishek Varma       if (dimAndTileMapping.count(dim)) {
6582b2ce50fSAbhishek Varma         // If the data dimension is tiled, the i-th index is the product of
6592b2ce50fSAbhishek Varma         // offset_i and tile_i, and the i-th size is the product of sizes_i and
6608cc616bcSMax191         // tile_i. The sizes must be clamped to the sizes of the unpack result.
6612b2ce50fSAbhishek Varma         auto avOffset = AV(dim0).bind(origOffsets[dim]);
6622b2ce50fSAbhishek Varma         auto avSize = AV(dim0).bind(origSizes[dim]);
6638cc616bcSMax191         auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
6648cc616bcSMax191         auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
6652b2ce50fSAbhishek Varma         resultOffsets.push_back(ab.mul(avOffset, avTileSize));
6668cc616bcSMax191         auto avResultOffset = AV(dim1).bind(resultOffsets.back());
6678cc616bcSMax191         resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
6688cc616bcSMax191                                       ab.sub(avResultSize, avResultOffset)}));
6692b2ce50fSAbhishek Varma       } else {
6702b2ce50fSAbhishek Varma         resultOffsets.push_back(origOffsets[dim]);
6712b2ce50fSAbhishek Varma         resultSizes.push_back(origSizes[dim]);
6722b2ce50fSAbhishek Varma       }
6732b2ce50fSAbhishek Varma     }
6742b2ce50fSAbhishek Varma     return success();
6752b2ce50fSAbhishek Varma   }
6762b2ce50fSAbhishek Varma 
6772b2ce50fSAbhishek Varma   /// Method to return the tiled implementation of tensor.unpack as a consumer.
6782b2ce50fSAbhishek Varma   FailureOr<TilingResult> getTiledImplementationFromOperandTile(
6792b2ce50fSAbhishek Varma       Operation *op, OpBuilder &b, unsigned operandNumber,
6802b2ce50fSAbhishek Varma       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
6812b2ce50fSAbhishek Varma     auto unPackOp = cast<UnPackOp>(op);
6822b2ce50fSAbhishek Varma     // tensor.unpack op is fusible (as a consumer) only if inner dims are not
6832b2ce50fSAbhishek Varma     // tiled.
6842b2ce50fSAbhishek Varma     int64_t numTiles = unPackOp.getInnerDimsPos().size();
6852b2ce50fSAbhishek Varma     for (auto iter :
6862b2ce50fSAbhishek Varma          llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
6872b2ce50fSAbhishek Varma       if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
6882b2ce50fSAbhishek Varma         return failure();
6892b2ce50fSAbhishek Varma     }
6902b2ce50fSAbhishek Varma 
6912b2ce50fSAbhishek Varma     Location loc = unPackOp.getLoc();
6922b2ce50fSAbhishek Varma 
6932b2ce50fSAbhishek Varma     // Fetch offset/size for creating the slice of the dest operand of
6942b2ce50fSAbhishek Varma     // unpack op.
6952b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> outputOffsets, outputSizes;
6962b2ce50fSAbhishek Varma     if (failed(getIterationDomainTileFromOperandTile(
6972b2ce50fSAbhishek Varma             op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
6982b2ce50fSAbhishek Varma             outputSizes)))
6992b2ce50fSAbhishek Varma       return failure();
7002b2ce50fSAbhishek Varma 
7012b2ce50fSAbhishek Varma     auto oneAttr = b.getI64IntegerAttr(1);
7022b2ce50fSAbhishek Varma     int64_t outputRank = unPackOp.getDestRank();
7032b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> strides(outputRank, oneAttr);
7042b2ce50fSAbhishek Varma 
7052b2ce50fSAbhishek Varma     SmallVector<Value> tiledOperands;
7062b2ce50fSAbhishek Varma     // Create slice of the dest operand.
7072b2ce50fSAbhishek Varma     auto extractDestSlice = b.create<ExtractSliceOp>(
7082b2ce50fSAbhishek Varma         loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
7092b2ce50fSAbhishek Varma     tiledOperands.push_back(extractDestSlice);
7102b2ce50fSAbhishek Varma 
7112b2ce50fSAbhishek Varma     SmallVector<OpFoldResult> inputOffsets, inputSizes;
7122b2ce50fSAbhishek Varma     strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
7132b2ce50fSAbhishek Varma     // Create slice of the source operand.
7142b2ce50fSAbhishek Varma     auto extractSourceSlice = b.create<ExtractSliceOp>(
7152b2ce50fSAbhishek Varma         loc, unPackOp.getSource(), offsets, sizes, strides);
7162b2ce50fSAbhishek Varma     tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
7172b2ce50fSAbhishek Varma     for (auto tile : unPackOp.getInnerTiles())
7182b2ce50fSAbhishek Varma       tiledOperands.push_back(tile);
7192b2ce50fSAbhishek Varma 
7202b2ce50fSAbhishek Varma     // Create tiled unpack op.
7212b2ce50fSAbhishek Varma     Operation *tiledUnPackOp =
7222b2ce50fSAbhishek Varma         b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()},
7232b2ce50fSAbhishek Varma                            tiledOperands, op->getAttrs());
7242b2ce50fSAbhishek Varma 
7252b2ce50fSAbhishek Varma     return TilingResult{{tiledUnPackOp},
726d5f0969cSMaheshRavishankar                         SmallVector<Value>(tiledUnPackOp->getResults()),
727d5f0969cSMaheshRavishankar                         llvm::to_vector(ArrayRef<Operation *>{
728d5f0969cSMaheshRavishankar                             extractSourceSlice, extractDestSlice})};
7292b2ce50fSAbhishek Varma   }
73083396d85SHanhan Wang };
73183396d85SHanhan Wang 
7320edb4127SLei Zhang } // namespace
7330edb4127SLei Zhang 
734809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
735809e3d8cSMahesh Ravishankar                                                  tensor::PadOp padOp,
7360edb4127SLei Zhang                                                  ArrayRef<OpFoldResult> offsets,
7370edb4127SLei Zhang                                                  ArrayRef<OpFoldResult> sizes,
7380edb4127SLei Zhang                                                  bool generateZeroSliceGuard) {
739fd0c6f53SAlexander Belyaev   // Only constant padding value supported.
740fd0c6f53SAlexander Belyaev   Value padValue = padOp.getConstantPaddingValue();
741fd0c6f53SAlexander Belyaev   if (!padValue)
742809e3d8cSMahesh Ravishankar     return failure();
743fd0c6f53SAlexander Belyaev 
744fd0c6f53SAlexander Belyaev   // Helper variables and functions for various arithmetic operations. These
745fd0c6f53SAlexander Belyaev   // are used extensively for computing new offset/length and padding values.
7460edb4127SLei Zhang   Location loc = padOp->getLoc();
747fd0c6f53SAlexander Belyaev   AffineExpr dim0, dim1;
748fd0c6f53SAlexander Belyaev   bindDims(b.getContext(), dim0, dim1);
749fd0c6f53SAlexander Belyaev   // Subtract two integers.
750fd0c6f53SAlexander Belyaev   auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
751d3fa067eSMahesh Ravishankar   auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
7524c48f016SMatthias Springer     return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
753fd0c6f53SAlexander Belyaev   };
754fd0c6f53SAlexander Belyaev   // Take the minimum of two integers.
755fd0c6f53SAlexander Belyaev   auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
756d3fa067eSMahesh Ravishankar   auto min = [&](OpFoldResult v1, OpFoldResult v2) {
7574c48f016SMatthias Springer     return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
758fd0c6f53SAlexander Belyaev   };
759fd0c6f53SAlexander Belyaev   // Take the maximum of two integers.
760d3fa067eSMahesh Ravishankar   auto max = [&](OpFoldResult v1, OpFoldResult v2) {
7614c48f016SMatthias Springer     return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
762fd0c6f53SAlexander Belyaev   };
763fd0c6f53SAlexander Belyaev   // Zero index-typed integer.
764d3fa067eSMahesh Ravishankar   OpFoldResult zero = b.getIndexAttr(0);
765fd0c6f53SAlexander Belyaev 
766fd0c6f53SAlexander Belyaev   // Compute new offsets, lengths, low padding, high padding.
767fd0c6f53SAlexander Belyaev   SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
768d3fa067eSMahesh Ravishankar   SmallVector<OpFoldResult> newLows, newHighs;
769fd0c6f53SAlexander Belyaev   // Set to true if the original data source is not read at all.
770fd0c6f53SAlexander Belyaev   bool hasZeroLen = false;
771fd0c6f53SAlexander Belyaev   // Same as hasZeroLen, but for dynamic dimension sizes. This condition
772fd0c6f53SAlexander Belyaev   // is true if the original data source turns out to be unused at runtime.
773fd0c6f53SAlexander Belyaev   Value dynHasZeroLenCond;
774fd0c6f53SAlexander Belyaev 
775fd0c6f53SAlexander Belyaev   int64_t rank = padOp.getSourceType().getRank();
776fd0c6f53SAlexander Belyaev   for (unsigned dim = 0; dim < rank; ++dim) {
777d3fa067eSMahesh Ravishankar     auto low = padOp.getMixedLowPad()[dim];
778d3fa067eSMahesh Ravishankar     bool hasLowPad = !isConstantIntValue(low, 0);
779d3fa067eSMahesh Ravishankar     auto high = padOp.getMixedHighPad()[dim];
780d3fa067eSMahesh Ravishankar     bool hasHighPad = !isConstantIntValue(high, 0);
781d3fa067eSMahesh Ravishankar     auto offset = offsets[dim];
782d3fa067eSMahesh Ravishankar     auto length = sizes[dim];
7836596b0ddSMatthias Springer     auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);
784fd0c6f53SAlexander Belyaev 
785fd0c6f53SAlexander Belyaev     // The new amount of low padding is `low - offset`. Except for the case
786fd0c6f53SAlexander Belyaev     // where none of the low padding is read. In that case, the new amount of
787fd0c6f53SAlexander Belyaev     // low padding is zero.
788fd0c6f53SAlexander Belyaev     //
789fd0c6f53SAlexander Belyaev     // Optimization: If low = 0, then newLow = 0.
790d3fa067eSMahesh Ravishankar     OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
791d3fa067eSMahesh Ravishankar     newLows.push_back(newLow);
792fd0c6f53SAlexander Belyaev 
793fd0c6f53SAlexander Belyaev     // Start reading the data from position `offset - low`. Since the original
794fd0c6f53SAlexander Belyaev     // read may have started in the low padding zone, this value could be
795fd0c6f53SAlexander Belyaev     // negative. Therefore, start reading from:
796fd0c6f53SAlexander Belyaev     //
797fd0c6f53SAlexander Belyaev     // max(offset - low, 0)
798fd0c6f53SAlexander Belyaev     //
799fd0c6f53SAlexander Belyaev     // The original read could also have started in the high padding zone.
800fd0c6f53SAlexander Belyaev     // In that case, set the offset to the end of source tensor. The new
801fd0c6f53SAlexander Belyaev     // ExtractSliceOp length will be zero in that case. (Effectively reading
802fd0c6f53SAlexander Belyaev     // no data from the source.)
803fd0c6f53SAlexander Belyaev     //
804fd0c6f53SAlexander Belyaev     // Optimization: If low = 0, then the formula can be simplified.
805d3fa067eSMahesh Ravishankar     OpFoldResult newOffset = hasLowPad
806d3fa067eSMahesh Ravishankar                                  ? min(max(sub(offset, low), zero), srcSize)
807fd0c6f53SAlexander Belyaev                                  : min(offset, srcSize);
808d3fa067eSMahesh Ravishankar     newOffsets.push_back(newOffset);
809fd0c6f53SAlexander Belyaev 
810fd0c6f53SAlexander Belyaev     // The original ExtractSliceOp was reading until position `offset +
811fd0c6f53SAlexander Belyaev     // length`. Therefore, the corresponding position within the source tensor
812fd0c6f53SAlexander Belyaev     // is:
813fd0c6f53SAlexander Belyaev     //
814fd0c6f53SAlexander Belyaev     // offset + length - low
815fd0c6f53SAlexander Belyaev     //
816fd0c6f53SAlexander Belyaev     // In case the original ExtractSliceOp stopped reading within the low
817fd0c6f53SAlexander Belyaev     // padding zone, this value can be negative. In that case, the end
818fd0c6f53SAlexander Belyaev     // position of the read should be zero. (Similar to newOffset.)
819fd0c6f53SAlexander Belyaev     //
820fd0c6f53SAlexander Belyaev     // The original read could also have stopped in the high padding zone.
821fd0c6f53SAlexander Belyaev     // In that case, set the end positition of the read should be the end of
822fd0c6f53SAlexander Belyaev     // the source tensor. (Similar to newOffset.)
8233f136f7dSNirvedh Meshram     // srcSize - newOffset represents how much length we have available
8243f136f7dSNirvedh Meshram     // and length - newLow represents how much length we want at most.
82577400103SNirvedh Meshram     // Note that there are many ways to order this indexing math to compute
82677400103SNirvedh Meshram     // newLength, but we want to make sure that the final affine.min ops in the
82777400103SNirvedh Meshram     // sequence are bounding the index to as small a value as possible. If
828*33927744SNirvedh     // ValueBoundsOpInterface is used, this calculation will get upper bounds
82977400103SNirvedh Meshram     // from the affine.min ops, so we want to use the smallest known value to
83077400103SNirvedh Meshram     // set the bound at the end of the computation sequence. In this case, the
83177400103SNirvedh Meshram     // index will be upper bounded by length - newLow.
8323f136f7dSNirvedh Meshram     OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow));
8333f136f7dSNirvedh Meshram     // Optimization: If low = 0, then newLow = 0. then newLength >= 0 assuming
8343f136f7dSNirvedh Meshram     // length >= 0.
8353f136f7dSNirvedh Meshram     if (hasLowPad)
8363f136f7dSNirvedh Meshram       newLength = max(newLength, zero);
837d3fa067eSMahesh Ravishankar     newLengths.push_back(newLength);
838fd0c6f53SAlexander Belyaev 
839fd0c6f53SAlexander Belyaev     // Check if newLength is zero. In that case, no SubTensorOp should be
840fd0c6f53SAlexander Belyaev     // executed.
841d3fa067eSMahesh Ravishankar     if (isConstantIntValue(newLength, 0)) {
842d3fa067eSMahesh Ravishankar       hasZeroLen = true;
843d3fa067eSMahesh Ravishankar     } else if (!hasZeroLen) {
844d3fa067eSMahesh Ravishankar       Value check = b.create<arith::CmpIOp>(
845d3fa067eSMahesh Ravishankar           loc, arith::CmpIPredicate::eq,
846d3fa067eSMahesh Ravishankar           getValueOrCreateConstantIndexOp(b, loc, newLength),
847d3fa067eSMahesh Ravishankar           getValueOrCreateConstantIndexOp(b, loc, zero));
848fd0c6f53SAlexander Belyaev       dynHasZeroLenCond =
849fd0c6f53SAlexander Belyaev           dynHasZeroLenCond
850fd0c6f53SAlexander Belyaev               ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
851fd0c6f53SAlexander Belyaev               : check;
852fd0c6f53SAlexander Belyaev     }
853fd0c6f53SAlexander Belyaev 
854fd0c6f53SAlexander Belyaev     // The amount of high padding is simply the number of elements remaining,
855fd0c6f53SAlexander Belyaev     // so that the result has the same length as the original ExtractSliceOp.
856fd0c6f53SAlexander Belyaev     // As an optimization, if the original high padding is zero, then the new
857fd0c6f53SAlexander Belyaev     // high padding must also be zero.
858d3fa067eSMahesh Ravishankar     OpFoldResult newHigh =
859d3fa067eSMahesh Ravishankar         hasHighPad ? sub(sub(length, newLength), newLow) : zero;
860d3fa067eSMahesh Ravishankar     newHighs.push_back(newHigh);
861fd0c6f53SAlexander Belyaev 
862fd0c6f53SAlexander Belyaev     // Only unit stride supported.
863fd0c6f53SAlexander Belyaev     newStrides.push_back(b.getIndexAttr(1));
864fd0c6f53SAlexander Belyaev   }
865fd0c6f53SAlexander Belyaev 
866fd0c6f53SAlexander Belyaev   // The shape of the result can be obtained from the sizes passed in.
867fd0c6f53SAlexander Belyaev   SmallVector<Value> dynDims;
868fd0c6f53SAlexander Belyaev   SmallVector<int64_t> shape;
869ded75a28SAliia Khasanova   dispatchIndexOpFoldResults(sizes, dynDims, shape);
870fd0c6f53SAlexander Belyaev   RankedTensorType resultType =
871fd0c6f53SAlexander Belyaev       RankedTensorType::get(shape, padOp.getResultType().getElementType());
872fd0c6f53SAlexander Belyaev 
873fd0c6f53SAlexander Belyaev   // Insert cast to ensure that types match. (May be folded away.)
874809e3d8cSMahesh Ravishankar   auto castResult = [&](Value val) -> Value {
875d3fa067eSMahesh Ravishankar     if (resultType == val.getType())
876809e3d8cSMahesh Ravishankar       return val;
8770edb4127SLei Zhang     return b.create<tensor::CastOp>(loc, resultType, val);
878fd0c6f53SAlexander Belyaev   };
879fd0c6f53SAlexander Belyaev 
880fd0c6f53SAlexander Belyaev   // In cases where the original data source is unused: Emit a GenerateOp and
881fd0c6f53SAlexander Belyaev   // do not generate a SliceOp. (The result shape of the SliceOp would
882fd0c6f53SAlexander Belyaev   // have a dimension of size 0, the semantics of which is unclear.)
883fd0c6f53SAlexander Belyaev   auto createGenerateOp = [&]() {
884fd0c6f53SAlexander Belyaev     // Create GenerateOp.
885fd0c6f53SAlexander Belyaev     auto generateOp = b.create<tensor::GenerateOp>(
886fd0c6f53SAlexander Belyaev         loc, resultType, dynDims,
887fd0c6f53SAlexander Belyaev         [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
888fd0c6f53SAlexander Belyaev           builder.create<tensor::YieldOp>(gLoc, padValue);
889fd0c6f53SAlexander Belyaev         });
890809e3d8cSMahesh Ravishankar     return generateOp;
891fd0c6f53SAlexander Belyaev   };
892fd0c6f53SAlexander Belyaev 
893fd0c6f53SAlexander Belyaev   // Emit a SliceOp and a PadOp. Should not be used in cases where
894fd0c6f53SAlexander Belyaev   // the result shape of the new SliceOp has a zero dimension.
8950edb4127SLei Zhang   auto createPadOfExtractSlice = [&]() {
8960edb4127SLei Zhang     // Create pad(extract_slice(x)).
897d5f0969cSMaheshRavishankar     auto newSliceOp = b.create<tensor::ExtractSliceOp>(
89804235d07SJacques Pienaar         loc, padOp.getSource(), newOffsets, newLengths, newStrides);
899b0674405SMahesh Ravishankar     auto newPadOp = b.create<PadOp>(
900b0674405SMahesh Ravishankar         loc, Type(), newSliceOp, newLows, newHighs,
901b0674405SMahesh Ravishankar         /*nofold=*/padOp.getNofold(),
902b0674405SMahesh Ravishankar         getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
903fd0c6f53SAlexander Belyaev 
904fd0c6f53SAlexander Belyaev     // Copy region to new PadOp.
9054d67b278SJeff Niu     IRMapping bvm;
90604235d07SJacques Pienaar     padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
907fd0c6f53SAlexander Belyaev 
908fd0c6f53SAlexander Belyaev     // Cast result and return.
909d5f0969cSMaheshRavishankar     return std::make_tuple(newPadOp, newSliceOp);
910fd0c6f53SAlexander Belyaev   };
911fd0c6f53SAlexander Belyaev 
9120edb4127SLei Zhang   // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
9130edb4127SLei Zhang   // the original data source x is not used.
914809e3d8cSMahesh Ravishankar   if (hasZeroLen) {
915809e3d8cSMahesh Ravishankar     Operation *generateOp = createGenerateOp();
916d5f0969cSMaheshRavishankar     return TilingResult{{generateOp},
917d5f0969cSMaheshRavishankar                         {castResult(generateOp->getResult(0))},
918d5f0969cSMaheshRavishankar                         /*generatedSlices=*/{}};
919809e3d8cSMahesh Ravishankar   }
920fd0c6f53SAlexander Belyaev 
921fd0c6f53SAlexander Belyaev   // If there are dynamic dimensions: Generate an scf.if check to avoid
922fd0c6f53SAlexander Belyaev   // creating SliceOps with result dimensions of size 0 at runtime.
9230edb4127SLei Zhang   if (generateZeroSliceGuard && dynHasZeroLenCond) {
924809e3d8cSMahesh Ravishankar     Operation *thenOp;
925809e3d8cSMahesh Ravishankar     Operation *elseOp;
926d5f0969cSMaheshRavishankar     Operation *sliceOp;
927fd0c6f53SAlexander Belyaev     auto result = b.create<scf::IfOp>(
9281125c5c0SFrederik Gossen         loc, dynHasZeroLenCond,
929fd0c6f53SAlexander Belyaev         /*thenBuilder=*/
930fd0c6f53SAlexander Belyaev         [&](OpBuilder &b, Location loc) {
931809e3d8cSMahesh Ravishankar           thenOp = createGenerateOp();
932809e3d8cSMahesh Ravishankar           b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
933fd0c6f53SAlexander Belyaev         },
934fd0c6f53SAlexander Belyaev         /*elseBuilder=*/
935fd0c6f53SAlexander Belyaev         [&](OpBuilder &b, Location loc) {
936d5f0969cSMaheshRavishankar           std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
937809e3d8cSMahesh Ravishankar           b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
938fd0c6f53SAlexander Belyaev         });
939d5f0969cSMaheshRavishankar     return TilingResult{
940d5f0969cSMaheshRavishankar         {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
941fd0c6f53SAlexander Belyaev   }
942809e3d8cSMahesh Ravishankar 
943d5f0969cSMaheshRavishankar   auto [newPadOp, sliceOp] = createPadOfExtractSlice();
944d5f0969cSMaheshRavishankar   return TilingResult{
945d5f0969cSMaheshRavishankar       {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
946fd0c6f53SAlexander Belyaev }
947fd0c6f53SAlexander Belyaev 
948a235562cSMahesh Ravishankar void mlir::tensor::registerTilingInterfaceExternalModels(
949fd0c6f53SAlexander Belyaev     DialectRegistry &registry) {
95077eee579SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
95177eee579SRiver Riddle     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
9520d03ba62SHanhan Wang     tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
95383396d85SHanhan Wang     tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
95477eee579SRiver Riddle   });
955fd0c6f53SAlexander Belyaev }
956adcb9888SHanhan Wang 
957adcb9888SHanhan Wang void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
958adcb9888SHanhan Wang     DialectRegistry &registry) {
959adcb9888SHanhan Wang   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
960adcb9888SHanhan Wang     tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
961adcb9888SHanhan Wang     tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
962adcb9888SHanhan Wang   });
963adcb9888SHanhan Wang }
964