1fd0c6f53SAlexander Belyaev //===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===// 2fd0c6f53SAlexander Belyaev // 3fd0c6f53SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fd0c6f53SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 5fd0c6f53SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fd0c6f53SAlexander Belyaev // 7fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===// 8fd0c6f53SAlexander Belyaev 9fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" 10fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h" 110d03ba62SHanhan Wang #include "mlir/Dialect/Affine/Utils.h" 12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 13fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Linalg/IR/Linalg.h" 1483396d85SHanhan Wang #include "mlir/Dialect/Linalg/Utils/Utils.h" 158b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 16fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/IR/Tensor.h" 170d03ba62SHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h" 180d03ba62SHanhan Wang #include "mlir/Dialect/Utils/IndexingUtils.h" 198cc616bcSMax191 #include "mlir/Interfaces/InferTypeOpInterface.h" 20fd0c6f53SAlexander Belyaev #include "mlir/Interfaces/TilingInterface.h" 21eabb6ccdSMatthias Springer #include "mlir/Interfaces/ValueBoundsOpInterface.h" 22fd0c6f53SAlexander Belyaev 23fd0c6f53SAlexander Belyaev using namespace mlir; 24fd0c6f53SAlexander Belyaev using namespace mlir::tensor; 25fd0c6f53SAlexander Belyaev 26fd0c6f53SAlexander Belyaev namespace { 27fd0c6f53SAlexander Belyaev 28fd0c6f53SAlexander Belyaev struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { 29fd0c6f53SAlexander Belyaev 304f1c1242SOleg Shyshkov SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 31fd0c6f53SAlexander Belyaev auto padOp = cast<PadOp>(op); 324f1c1242SOleg Shyshkov SmallVector<utils::IteratorType> iteratorTypes( 334f1c1242SOleg Shyshkov padOp.getResultType().getRank(), utils::IteratorType::parallel); 34fd0c6f53SAlexander Belyaev return iteratorTypes; 35fd0c6f53SAlexander Belyaev } 36fd0c6f53SAlexander Belyaev 37fd0c6f53SAlexander Belyaev SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 38fd0c6f53SAlexander Belyaev ReifiedRankedShapedTypeDims reifiedShapes; 39758329dcSMatthias Springer (void)reifyResultShapes(b, op, reifiedShapes); 403e12cf2aSMatthias Springer OpFoldResult zero = b.getIndexAttr(0); 413e12cf2aSMatthias Springer OpFoldResult one = b.getIndexAttr(1); 42fd0c6f53SAlexander Belyaev // Initialize all the ranges to {zero, one, one}. All the `ub`s are 43fd0c6f53SAlexander Belyaev // overwritten. 44fd0c6f53SAlexander Belyaev SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one}); 45fd0c6f53SAlexander Belyaev for (const auto &ub : enumerate(reifiedShapes[0])) 46fd0c6f53SAlexander Belyaev loopRanges[ub.index()].size = ub.value(); 47fd0c6f53SAlexander Belyaev return loopRanges; 48fd0c6f53SAlexander Belyaev } 49fd0c6f53SAlexander Belyaev 50809e3d8cSMahesh Ravishankar FailureOr<TilingResult> 5154794284SMatthias Springer getTiledImplementation(Operation *op, OpBuilder &b, 52fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> offsets, 5354794284SMatthias Springer ArrayRef<OpFoldResult> sizes) const { 54809e3d8cSMahesh Ravishankar FailureOr<TilingResult> result = 550edb4127SLei Zhang tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes); 56809e3d8cSMahesh Ravishankar if (failed(result)) 57809e3d8cSMahesh Ravishankar return failure(); 58809e3d8cSMahesh Ravishankar return result.value(); 590edb4127SLei Zhang } 60a235562cSMahesh Ravishankar 61a235562cSMahesh Ravishankar LogicalResult 62a235562cSMahesh Ravishankar getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 63a235562cSMahesh Ravishankar ArrayRef<OpFoldResult> offsets, 64a235562cSMahesh Ravishankar ArrayRef<OpFoldResult> sizes, 65f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultOffsets, 66f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultSizes) const { 67a235562cSMahesh Ravishankar resultOffsets.assign(offsets.begin(), offsets.end()); 68a235562cSMahesh Ravishankar resultSizes.assign(sizes.begin(), sizes.end()); 69a235562cSMahesh Ravishankar return success(); 70a235562cSMahesh Ravishankar } 7191e57c6fSQuinn Dawkins 7291e57c6fSQuinn Dawkins LogicalResult getIterationDomainTileFromResultTile( 7391e57c6fSQuinn Dawkins Operation *op, OpBuilder &b, unsigned resultNumber, 7491e57c6fSQuinn Dawkins ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 7591e57c6fSQuinn Dawkins SmallVectorImpl<OpFoldResult> &iterDomainOffsets, 7691e57c6fSQuinn Dawkins SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { 7791e57c6fSQuinn Dawkins iterDomainOffsets.assign(offsets.begin(), offsets.end()); 7891e57c6fSQuinn Dawkins iterDomainSizes.assign(sizes.begin(), sizes.end()); 7991e57c6fSQuinn Dawkins return success(); 8091e57c6fSQuinn Dawkins } 8191e57c6fSQuinn Dawkins 8291e57c6fSQuinn Dawkins FailureOr<TilingResult> 8391e57c6fSQuinn Dawkins generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 8491e57c6fSQuinn Dawkins ArrayRef<OpFoldResult> offsets, 8591e57c6fSQuinn Dawkins ArrayRef<OpFoldResult> sizes) const { 8691e57c6fSQuinn Dawkins return getTiledImplementation(op, b, offsets, sizes); 8791e57c6fSQuinn Dawkins } 880edb4127SLei Zhang }; 890edb4127SLei Zhang 9083396d85SHanhan Wang template <typename OpTy> 9183396d85SHanhan Wang static SmallVector<Range> getPackUnPackIterationDomain(OpTy op, 9283396d85SHanhan Wang OpBuilder &builder) { 9383396d85SHanhan Wang static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, 9483396d85SHanhan Wang "applies to only pack or unpack operations"); 9583396d85SHanhan Wang OpBuilder::InsertionGuard g(builder); 9683396d85SHanhan Wang int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() 9783396d85SHanhan Wang : op.getDestRank(); 983e12cf2aSMatthias Springer OpFoldResult zero = builder.getIndexAttr(0); 993e12cf2aSMatthias Springer OpFoldResult one = builder.getIndexAttr(1); 10083396d85SHanhan Wang ReifiedRankedShapedTypeDims resultShape; 101758329dcSMatthias Springer (void)reifyResultShapes(builder, op, resultShape); 10283396d85SHanhan Wang SmallVector<Range> loopBounds(rank); 10383396d85SHanhan Wang for (auto dim : llvm::seq<int64_t>(0, rank)) { 10483396d85SHanhan Wang loopBounds[dim].offset = zero; 10583396d85SHanhan Wang loopBounds[dim].stride = one; 10683396d85SHanhan Wang loopBounds[dim].size = resultShape[0][dim]; 10783396d85SHanhan Wang } 10883396d85SHanhan Wang return loopBounds; 10983396d85SHanhan Wang } 11083396d85SHanhan Wang 111d5a9fc13SLorenzo Chelini static void applyPermToRange(SmallVector<OpFoldResult> &offsets, 11283396d85SHanhan Wang SmallVector<OpFoldResult> &sizes, 11383396d85SHanhan Wang ArrayRef<int64_t> permutation) { 11483396d85SHanhan Wang if (permutation.empty()) 11583396d85SHanhan Wang return; 116d5a9fc13SLorenzo Chelini applyPermutationToVector<OpFoldResult>(offsets, permutation); 117d5a9fc13SLorenzo Chelini applyPermutationToVector<OpFoldResult>(sizes, permutation); 11883396d85SHanhan Wang } 11983396d85SHanhan Wang 1200d03ba62SHanhan Wang struct PackOpTiling 1210d03ba62SHanhan Wang : public TilingInterface::ExternalModel<PackOpTiling, PackOp> { 1220d03ba62SHanhan Wang 1230d03ba62SHanhan Wang SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 1240d03ba62SHanhan Wang // Note that here we only consider untiled dimensions and outer tiled data 1250d03ba62SHanhan Wang // dimensions, the inner tiled data dimensions are materialized when 1260d03ba62SHanhan Wang // building the body of the operation. 1270d03ba62SHanhan Wang auto packOp = cast<PackOp>(op); 1280d03ba62SHanhan Wang SmallVector<utils::IteratorType> iteratorTypes( 1290d03ba62SHanhan Wang packOp.getSourceRank(), utils::IteratorType::parallel); 1300d03ba62SHanhan Wang return iteratorTypes; 1310d03ba62SHanhan Wang } 1320d03ba62SHanhan Wang 1330d03ba62SHanhan Wang SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 13483396d85SHanhan Wang return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); 1350d03ba62SHanhan Wang } 1360d03ba62SHanhan Wang 137809e3d8cSMahesh Ravishankar FailureOr<TilingResult> 1380d03ba62SHanhan Wang getTiledImplementation(Operation *op, OpBuilder &b, 1390d03ba62SHanhan Wang ArrayRef<OpFoldResult> offsets, 1400d03ba62SHanhan Wang ArrayRef<OpFoldResult> sizes) const { 1410d03ba62SHanhan Wang auto packOp = cast<PackOp>(op); 1420d03ba62SHanhan Wang Location loc = packOp.getLoc(); 1430d03ba62SHanhan Wang 1440d03ba62SHanhan Wang // The tiling is applied on interchanged dimensions. We have to undo the 1450d03ba62SHanhan Wang // interchange to map sizes and offsets to the original input. 1460d03ba62SHanhan Wang int64_t inputRank = packOp.getSourceRank(); 1475262865aSKazu Hirata SmallVector<OpFoldResult> origOffsets(offsets); 1485262865aSKazu Hirata SmallVector<OpFoldResult> origSizes(sizes); 149d5a9fc13SLorenzo Chelini applyPermToRange(origOffsets, origSizes, 150d5a9fc13SLorenzo Chelini invertPermutationVector(packOp.getOuterDimsPerm())); 1510d03ba62SHanhan Wang 1520d03ba62SHanhan Wang DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 1530d03ba62SHanhan Wang packOp.getDimAndTileMapping(); 1540d03ba62SHanhan Wang SmallVector<OpFoldResult> srcDimValues = 1556596b0ddSMatthias Springer tensor::getMixedSizes(b, loc, packOp.getSource()); 1560d03ba62SHanhan Wang SmallVector<OpFoldResult> inputIndices, inputSizes; 1570d03ba62SHanhan Wang for (auto dim : llvm::seq<int64_t>(0, inputRank)) { 1584c48f016SMatthias Springer using AV = affine::AffineValueExpr; 1594c48f016SMatthias Springer affine::AffineBuilder ab(b, loc); 1600d03ba62SHanhan Wang AffineExpr dim0, dim1, sym; 1610d03ba62SHanhan Wang bindDims(b.getContext(), dim0, dim1); 1620d03ba62SHanhan Wang bindSymbols(b.getContext(), sym); 1630d03ba62SHanhan Wang if (dimAndTileMapping.count(dim)) { 1640d03ba62SHanhan Wang // If the data dimension is tiled, the i-th index is the product of 1650d03ba62SHanhan Wang // offset_i and tile_i, and the i-th size is the product of sizes_i and 1660d03ba62SHanhan Wang // tile_i. 1670d03ba62SHanhan Wang auto avOffset = AV(dim0).bind(origOffsets[dim]); 1680d03ba62SHanhan Wang auto avSize = AV(dim0).bind(origSizes[dim]); 1690d03ba62SHanhan Wang auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 1700d03ba62SHanhan Wang inputIndices.push_back(ab.mul(avOffset, avTileSize)); 1710d03ba62SHanhan Wang inputSizes.push_back(ab.mul(avSize, avTileSize)); 1720d03ba62SHanhan Wang } else { 1730d03ba62SHanhan Wang inputIndices.push_back(origOffsets[dim]); 1740d03ba62SHanhan Wang inputSizes.push_back(origSizes[dim]); 1750d03ba62SHanhan Wang } 1760d03ba62SHanhan Wang 1770d03ba62SHanhan Wang // Limit the size of the input operand for incomplete tiles. 1784d036510SHanhan Wang if (packOp.getPaddingValue()) { 1790d03ba62SHanhan Wang OpFoldResult dimSize = srcDimValues[dim]; 1800d03ba62SHanhan Wang auto avDimSize = AV(dim0).bind(dimSize); 1810d03ba62SHanhan Wang auto avInputIdx = AV(dim1).bind(inputIndices.back()); 1820d03ba62SHanhan Wang inputSizes.back() = 1830d03ba62SHanhan Wang ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)}); 1840d03ba62SHanhan Wang } 1854d036510SHanhan Wang } 1860d03ba62SHanhan Wang 1870d03ba62SHanhan Wang auto oneAttr = b.getI64IntegerAttr(1); 1880d03ba62SHanhan Wang SmallVector<OpFoldResult> strides(inputRank, oneAttr); 1890d03ba62SHanhan Wang 1900d03ba62SHanhan Wang SmallVector<Value> tiledOperands; 191d5f0969cSMaheshRavishankar auto sourceSlice = b.create<ExtractSliceOp>( 192d5f0969cSMaheshRavishankar loc, packOp.getSource(), inputIndices, inputSizes, strides); 193d5f0969cSMaheshRavishankar tiledOperands.push_back(sourceSlice); 1940d03ba62SHanhan Wang 1950d03ba62SHanhan Wang SmallVector<OpFoldResult> outputOffsets, outputSizes; 1960d03ba62SHanhan Wang if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets, 1970d03ba62SHanhan Wang outputSizes))) 1980d03ba62SHanhan Wang return {}; 1990d03ba62SHanhan Wang 2000d03ba62SHanhan Wang strides.append(packOp.getDestRank() - inputRank, oneAttr); 201d5f0969cSMaheshRavishankar auto outSlice = b.create<ExtractSliceOp>( 2020d03ba62SHanhan Wang loc, packOp.getDest(), outputOffsets, outputSizes, strides); 203d5f0969cSMaheshRavishankar tiledOperands.push_back(outSlice); 2040d03ba62SHanhan Wang 2050d03ba62SHanhan Wang if (auto val = packOp.getPaddingValue()) 2060d03ba62SHanhan Wang tiledOperands.push_back(val); 2070d03ba62SHanhan Wang for (auto tile : packOp.getInnerTiles()) 2080d03ba62SHanhan Wang tiledOperands.push_back(tile); 2090d03ba62SHanhan Wang 2100d03ba62SHanhan Wang Operation *tiledPackOp = b.create<PackOp>( 211d5f0969cSMaheshRavishankar loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); 2120d03ba62SHanhan Wang 213d5f0969cSMaheshRavishankar return TilingResult{ 214d5f0969cSMaheshRavishankar {tiledPackOp}, 215d5f0969cSMaheshRavishankar SmallVector<Value>(tiledPackOp->getResults()), 216d5f0969cSMaheshRavishankar llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})}; 2170d03ba62SHanhan Wang } 2180d03ba62SHanhan Wang 2190d03ba62SHanhan Wang LogicalResult 2200d03ba62SHanhan Wang getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 2210d03ba62SHanhan Wang ArrayRef<OpFoldResult> offsets, 2220d03ba62SHanhan Wang ArrayRef<OpFoldResult> sizes, 223f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultOffsets, 224f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultSizes) const { 2250d03ba62SHanhan Wang // The iteration domain is over outer dimensions of packed layout. In this 2260d03ba62SHanhan Wang // context, the outer dimensions of `resultOffsets` are `offsets`. The 2270d03ba62SHanhan Wang // inner dimensions of `resultOffsets` are zeros because tiling is not 2280d03ba62SHanhan Wang // applied to them. 2290d03ba62SHanhan Wang auto packOp = cast<PackOp>(op); 2300d03ba62SHanhan Wang int64_t inputRank = packOp.getSourceRank(); 2310d03ba62SHanhan Wang int64_t outputRank = packOp.getDestRank(); 2320d03ba62SHanhan Wang auto zeroAttr = b.getI64IntegerAttr(0); 2330d03ba62SHanhan Wang resultOffsets.assign(offsets.begin(), offsets.end()); 2340d03ba62SHanhan Wang resultOffsets.append(outputRank - inputRank, zeroAttr); 2350d03ba62SHanhan Wang 2360d03ba62SHanhan Wang ReifiedRankedShapedTypeDims outputShape; 237758329dcSMatthias Springer (void)reifyResultShapes(b, packOp, outputShape); 2380d03ba62SHanhan Wang resultSizes.assign(sizes.begin(), sizes.end()); 2390d03ba62SHanhan Wang for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank)) 2402a5b13e7SMatthias Springer resultSizes.push_back(outputShape[0][dataTileDim]); 2410d03ba62SHanhan Wang 2420d03ba62SHanhan Wang return success(); 2430d03ba62SHanhan Wang } 2448b68cec9SHanhan Wang 2458b68cec9SHanhan Wang FailureOr<TilingResult> 2468b68cec9SHanhan Wang generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 2478b68cec9SHanhan Wang ArrayRef<OpFoldResult> offsets, 2488b68cec9SHanhan Wang ArrayRef<OpFoldResult> sizes) const { 2498b68cec9SHanhan Wang auto packOp = cast<PackOp>(op); 2508b68cec9SHanhan Wang int64_t numTiles = packOp.getInnerDimsPos().size(); 2518b68cec9SHanhan Wang 2528b68cec9SHanhan Wang // tensor.pack op is fusible (as a producer) only if full inner tiles are 2538b68cec9SHanhan Wang // iterated or inner dims are not tiled. Otherwise, it will generate a 2548b68cec9SHanhan Wang // sequence of non-trivial ops (for partial tiles). 2558b68cec9SHanhan Wang for (auto offset : offsets.take_back(numTiles)) 2568b68cec9SHanhan Wang if (!isConstantIntValue(offset, 0)) 2578b68cec9SHanhan Wang return failure(); 2588b68cec9SHanhan Wang 2598b68cec9SHanhan Wang for (auto iter : 2608b68cec9SHanhan Wang llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles))) 2618b68cec9SHanhan Wang if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) 2628b68cec9SHanhan Wang return failure(); 2638b68cec9SHanhan Wang 2648b68cec9SHanhan Wang FailureOr<TilingResult> tilingResult = getTiledImplementation( 2658b68cec9SHanhan Wang op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles)); 2668b68cec9SHanhan Wang if (failed(tilingResult)) 2678b68cec9SHanhan Wang return failure(); 2688b68cec9SHanhan Wang return tilingResult.value(); 2698b68cec9SHanhan Wang } 270f06563a5SYun-Fly 271f06563a5SYun-Fly /// Method to return the position of iteration domain tile computed by the 272f06563a5SYun-Fly /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and 273f06563a5SYun-Fly /// `resultSizes` only cover outer dimensions. 274f06563a5SYun-Fly LogicalResult getIterationDomainTileFromOperandTile( 275f06563a5SYun-Fly Operation *op, OpBuilder &b, unsigned operandNumber, 276f06563a5SYun-Fly ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 277f06563a5SYun-Fly SmallVectorImpl<OpFoldResult> &resultOffsets, 278f06563a5SYun-Fly SmallVectorImpl<OpFoldResult> &resultSizes) const { 279f06563a5SYun-Fly if (operandNumber != 0) 280f06563a5SYun-Fly return failure(); 281f06563a5SYun-Fly 282f06563a5SYun-Fly auto packOp = cast<PackOp>(op); 283f06563a5SYun-Fly // It is not trivial to infer dest tile from source tile if `packOp` has 284f06563a5SYun-Fly // padding semantic. 285f06563a5SYun-Fly if (packOp.getPaddingValue()) 286f06563a5SYun-Fly return failure(); 287f06563a5SYun-Fly 288f06563a5SYun-Fly Location loc = packOp.getLoc(); 289f06563a5SYun-Fly 290f06563a5SYun-Fly SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; 291f06563a5SYun-Fly DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 292f06563a5SYun-Fly packOp.getDimAndTileMapping(); 293c8763f04SYun-Fly for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) { 294f06563a5SYun-Fly if (dimAndTileMapping.count(dim)) { 295f06563a5SYun-Fly FailureOr<int64_t> cstSize = 296f06563a5SYun-Fly ValueBoundsConstraintSet::computeConstantBound( 297f06563a5SYun-Fly presburger::BoundType::UB, sizes[dim], 298f06563a5SYun-Fly /*stopCondition=*/nullptr, /*closedUB=*/true); 299f06563a5SYun-Fly std::optional<int64_t> cstInnerSize = 300f06563a5SYun-Fly getConstantIntValue(dimAndTileMapping[dim]); 301f06563a5SYun-Fly // Currently fusing `packOp` as consumer only expects perfect tiling 302f06563a5SYun-Fly // scenario because even if without padding semantic, the `packOp` may 303f06563a5SYun-Fly // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, 304f06563a5SYun-Fly // where the `tileSize` from operand of `packOp` is 5, which is not 305f06563a5SYun-Fly // exactly divided by `innerTile`(=6) of `packOp`. As the result: 306f06563a5SYun-Fly // 1. the first slice is extracted from (0) to (4) and inserted into 307f06563a5SYun-Fly // (0,0)~(0,4) at first row. 308f06563a5SYun-Fly // 2. the second slice is extracted from (5) to (9) and SHOULD BE 309f06563a5SYun-Fly // respectively inserted into two rows with different length, including 310f06563a5SYun-Fly // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate 311f06563a5SYun-Fly // them, thus adding below constraint to bypass them temporarily. In 312f06563a5SYun-Fly // another word, we can only support tiling with consumer if the tile 313f06563a5SYun-Fly // size for the producer is a multiple of the inner tile size for the 314f06563a5SYun-Fly // packed dimensions at this moment. 315f06563a5SYun-Fly if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { 316f06563a5SYun-Fly return failure(); 317f06563a5SYun-Fly } 318f06563a5SYun-Fly 319f06563a5SYun-Fly using AV = affine::AffineValueExpr; 320f06563a5SYun-Fly affine::AffineBuilder ab(b, loc); 321f06563a5SYun-Fly AffineExpr dim0, sym; 322f06563a5SYun-Fly bindDims(b.getContext(), dim0); 323f06563a5SYun-Fly bindSymbols(b.getContext(), sym); 324f06563a5SYun-Fly auto avOffset = AV(dim0).bind(offsets[dim]); 325f06563a5SYun-Fly auto avSize = AV(dim0).bind(sizes[dim]); 326f06563a5SYun-Fly auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); 327f06563a5SYun-Fly outerDimOffsets.push_back(ab.floor(avOffset, avTileSize)); 328f06563a5SYun-Fly outerDimSizes.push_back(ab.ceil(avSize, avTileSize)); 329f06563a5SYun-Fly } else { 330f06563a5SYun-Fly outerDimOffsets.push_back(offsets[dim]); 331f06563a5SYun-Fly outerDimSizes.push_back(sizes[dim]); 332f06563a5SYun-Fly } 333f06563a5SYun-Fly } 334c8763f04SYun-Fly applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm()); 335f06563a5SYun-Fly resultOffsets = outerDimOffsets; 336f06563a5SYun-Fly resultSizes = outerDimSizes; 337f06563a5SYun-Fly return success(); 338f06563a5SYun-Fly } 339f06563a5SYun-Fly 340f06563a5SYun-Fly /// Method to return the tiled implementation of tensor.pack as a consumer. 341f06563a5SYun-Fly FailureOr<TilingResult> getTiledImplementationFromOperandTile( 342f06563a5SYun-Fly Operation *op, OpBuilder &b, unsigned operandNumber, 343f06563a5SYun-Fly ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 344f06563a5SYun-Fly if (operandNumber != 0) 345f06563a5SYun-Fly return failure(); 346f06563a5SYun-Fly 347f06563a5SYun-Fly auto packOp = cast<PackOp>(op); 348f06563a5SYun-Fly Location loc = packOp.getLoc(); 349f06563a5SYun-Fly 350f06563a5SYun-Fly int64_t inputRank = packOp.getSourceRank(); 351f06563a5SYun-Fly auto oneAttr = b.getI64IntegerAttr(1); 352f06563a5SYun-Fly SmallVector<OpFoldResult> strides(inputRank, oneAttr); 353f06563a5SYun-Fly 354f06563a5SYun-Fly SmallVector<Value> tiledOperands; 355d5f0969cSMaheshRavishankar auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(), 356d5f0969cSMaheshRavishankar offsets, sizes, strides); 357d5f0969cSMaheshRavishankar tiledOperands.push_back(sourceSlice); 358f06563a5SYun-Fly 359f06563a5SYun-Fly SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; 360f06563a5SYun-Fly if (failed(getIterationDomainTileFromOperandTile( 361f06563a5SYun-Fly op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets, 362f06563a5SYun-Fly outerDimSizes))) 363f06563a5SYun-Fly return failure(); 364f06563a5SYun-Fly 365f06563a5SYun-Fly SmallVector<OpFoldResult> outputOffsets, outputSizes; 366f06563a5SYun-Fly if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes, 367f06563a5SYun-Fly outputOffsets, outputSizes))) 368f06563a5SYun-Fly return failure(); 369f06563a5SYun-Fly 370f06563a5SYun-Fly strides.append(packOp.getDestRank() - inputRank, oneAttr); 371d5f0969cSMaheshRavishankar auto outSlice = b.create<ExtractSliceOp>( 372f06563a5SYun-Fly loc, packOp.getDest(), outputOffsets, outputSizes, strides); 373d5f0969cSMaheshRavishankar tiledOperands.push_back(outSlice); 374f06563a5SYun-Fly 375f06563a5SYun-Fly assert(!packOp.getPaddingValue() && "Expect no padding semantic"); 376f06563a5SYun-Fly for (auto tile : packOp.getInnerTiles()) 377f06563a5SYun-Fly tiledOperands.push_back(tile); 378f06563a5SYun-Fly 379f06563a5SYun-Fly Operation *tiledPackOp = b.create<PackOp>( 380d5f0969cSMaheshRavishankar loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); 381f06563a5SYun-Fly 382d5f0969cSMaheshRavishankar return TilingResult{ 383d5f0969cSMaheshRavishankar {tiledPackOp}, 384d5f0969cSMaheshRavishankar SmallVector<Value>(tiledPackOp->getResults()), 385d5f0969cSMaheshRavishankar llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})}; 386f06563a5SYun-Fly } 3870d03ba62SHanhan Wang }; 3880d03ba62SHanhan Wang 38983396d85SHanhan Wang struct UnpackTileDimInfo { 39083396d85SHanhan Wang bool isAlignedToInnerTileSize; 39183396d85SHanhan Wang OpFoldResult sourceOffset; 39283396d85SHanhan Wang OpFoldResult sourceSize; 39383396d85SHanhan Wang OpFoldResult resultOffset; 39483396d85SHanhan Wang OpFoldResult destExpandedSize; 39583396d85SHanhan Wang }; 39683396d85SHanhan Wang 39783396d85SHanhan Wang /// Returns the needed information for tiling unpack op on `tileDim` with given 39883396d85SHanhan Wang /// `tileOffset` and `tileSize`. For more details, see the comment of the 39983396d85SHanhan Wang /// `getTiledImplementation`. 40083396d85SHanhan Wang static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, 40183396d85SHanhan Wang int64_t tileDim, 40283396d85SHanhan Wang OpFoldResult tileOffset, 40383396d85SHanhan Wang OpFoldResult tileSize) { 40483396d85SHanhan Wang UnpackTileDimInfo info; 40583396d85SHanhan Wang Attribute zeroAttr = b.getIndexAttr(0); 40683396d85SHanhan Wang Attribute oneAttr = b.getIndexAttr(1); 40783396d85SHanhan Wang DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 40883396d85SHanhan Wang unpackOp.getDimAndTileMapping(); 40983396d85SHanhan Wang // The dimension is not one of packed data dimension. 41083396d85SHanhan Wang if (!dimAndTileMapping.count(tileDim)) { 41183396d85SHanhan Wang info.isAlignedToInnerTileSize = true; 41283396d85SHanhan Wang info.sourceOffset = tileOffset; 41383396d85SHanhan Wang info.sourceSize = tileSize; 41483396d85SHanhan Wang info.resultOffset = zeroAttr; 41583396d85SHanhan Wang info.destExpandedSize = tileSize; 41683396d85SHanhan Wang return info; 41783396d85SHanhan Wang } 41883396d85SHanhan Wang 41983396d85SHanhan Wang Location loc = unpackOp.getLoc(); 4204c48f016SMatthias Springer using AV = affine::AffineValueExpr; 4214c48f016SMatthias Springer affine::AffineBuilder ab(b, loc); 42283396d85SHanhan Wang AffineExpr dim0, dim1, sym0; 42383396d85SHanhan Wang bindDims(b.getContext(), dim0, dim1); 42483396d85SHanhan Wang bindSymbols(b.getContext(), sym0); 42583396d85SHanhan Wang 42683396d85SHanhan Wang OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; 42783396d85SHanhan Wang 42883396d85SHanhan Wang info.isAlignedToInnerTileSize = false; 429eabb6ccdSMatthias Springer FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( 43040dd3aa9SMatthias Springer presburger::BoundType::UB, tileSize, 431eabb6ccdSMatthias Springer /*stopCondition=*/nullptr, /*closedUB=*/true); 43222426110SRamkumar Ramachandra std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize); 43383396d85SHanhan Wang if (!failed(cstSize) && cstInnerSize) { 434cbb09813SFangrui Song if (*cstSize % *cstInnerSize == 0) 43583396d85SHanhan Wang info.isAlignedToInnerTileSize = true; 43683396d85SHanhan Wang 43783396d85SHanhan Wang // If the tiling size equals to the inner tiling size, the outer dims are 43883396d85SHanhan Wang // always 1. 439cbb09813SFangrui Song if (*cstInnerSize == *cstSize) { 44083396d85SHanhan Wang auto lhs = AV(dim0).bind(tileOffset); 44183396d85SHanhan Wang auto rhs = AV(dim1).bind(innerTileSize); 44283396d85SHanhan Wang info.sourceOffset = ab.floor(lhs, rhs); 44383396d85SHanhan Wang info.sourceSize = oneAttr; 44483396d85SHanhan Wang info.resultOffset = zeroAttr; 44583396d85SHanhan Wang info.destExpandedSize = tileSize; 44683396d85SHanhan Wang return info; 44783396d85SHanhan Wang } 44883396d85SHanhan Wang } 44983396d85SHanhan Wang 45083396d85SHanhan Wang if (info.isAlignedToInnerTileSize) { 45183396d85SHanhan Wang info.sourceOffset = 45283396d85SHanhan Wang ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize)); 45383396d85SHanhan Wang info.resultOffset = zeroAttr; 45483396d85SHanhan Wang info.destExpandedSize = tileSize; 45583396d85SHanhan Wang 45683396d85SHanhan Wang // The ceilDiv is needed here because there could be incomplete tile even 45783396d85SHanhan Wang // it is perfect tiling cases. E.g., 45883396d85SHanhan Wang // %0 = unpack tensor<33x2xf32> into tensor<64xf32> 45983396d85SHanhan Wang // If the tiling size is 32, there will be 3 tiles. Two of them have 46083396d85SHanhan Wang // size=32; one of them have size=2. The size is represented using 46183396d85SHanhan Wang // affine_min op; we need ceilDiv. 46283396d85SHanhan Wang info.sourceSize = 46383396d85SHanhan Wang ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize)); 46483396d85SHanhan Wang return info; 46583396d85SHanhan Wang } 46683396d85SHanhan Wang 4674c48f016SMatthias Springer affine::DivModValue firstCoord = affine::getDivMod( 4684c48f016SMatthias Springer b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset), 46983396d85SHanhan Wang getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 47083396d85SHanhan Wang OpFoldResult tileExclusiveBound = 47183396d85SHanhan Wang ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize)); 4724c48f016SMatthias Springer affine::DivModValue lastCoord = affine::getDivMod( 47383396d85SHanhan Wang b, loc, 47483396d85SHanhan Wang getValueOrCreateConstantIndexOp( 47583396d85SHanhan Wang b, loc, 47683396d85SHanhan Wang ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))), 47783396d85SHanhan Wang getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 47883396d85SHanhan Wang 47983396d85SHanhan Wang OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient), 48083396d85SHanhan Wang AV(dim1).bind(firstCoord.quotient)); 48183396d85SHanhan Wang info.sourceSize = 48283396d85SHanhan Wang ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr)); 48383396d85SHanhan Wang info.sourceOffset = firstCoord.quotient; 48483396d85SHanhan Wang info.resultOffset = firstCoord.remainder; 4855fa9933cSHanhan Wang // Do not create an Affine ops for expanded size because the affine op is too 4865fa9933cSHanhan Wang // complicated which would trigger an issue in affine ops simplification. 4875fa9933cSHanhan Wang info.destExpandedSize = b.createOrFold<arith::MulIOp>( 4885fa9933cSHanhan Wang loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize), 4895fa9933cSHanhan Wang getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); 49083396d85SHanhan Wang return info; 49183396d85SHanhan Wang } 49283396d85SHanhan Wang 49383396d85SHanhan Wang struct UnPackOpTiling 49483396d85SHanhan Wang : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> { 49583396d85SHanhan Wang 49683396d85SHanhan Wang SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { 49783396d85SHanhan Wang auto unpackOp = cast<UnPackOp>(op); 49883396d85SHanhan Wang SmallVector<utils::IteratorType> iteratorTypes( 49983396d85SHanhan Wang unpackOp.getDestRank(), utils::IteratorType::parallel); 50083396d85SHanhan Wang return iteratorTypes; 50183396d85SHanhan Wang } 50283396d85SHanhan Wang 50383396d85SHanhan Wang SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { 50483396d85SHanhan Wang return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b); 50583396d85SHanhan Wang } 50683396d85SHanhan Wang 50783396d85SHanhan Wang /// There are two cases in tiling unpack ops. If the tiling size is aligned to 50883396d85SHanhan Wang /// the inner tile size, the corresponding tiles of source are all complete. 50983396d85SHanhan Wang /// Otherwise, there are in-complete tiles. We will need to expand the slice 51083396d85SHanhan Wang /// of source for getting complete tiles. The tiled unpack op unpacks more 51183396d85SHanhan Wang /// data from source, so We'll need an extract_slice op to shift and truncate 51283396d85SHanhan Wang /// the output. 51383396d85SHanhan Wang /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The 51483396d85SHanhan Wang /// coordinates of second tile (i.e., result[15..31]) are 51583396d85SHanhan Wang /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last 51683396d85SHanhan Wang /// row are incomplete tiles. To represent the unpack op, we have to complete 51783396d85SHanhan Wang /// the rows. I.e., the input coordinates would start with (1, 0); end with 51883396d85SHanhan Wang /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements 51983396d85SHanhan Wang /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we 52083396d85SHanhan Wang /// can get the actual result. 521809e3d8cSMahesh Ravishankar FailureOr<TilingResult> 52283396d85SHanhan Wang getTiledImplementation(Operation *op, OpBuilder &b, 52383396d85SHanhan Wang ArrayRef<OpFoldResult> offsets, 52483396d85SHanhan Wang ArrayRef<OpFoldResult> sizes) const { 52583396d85SHanhan Wang auto unpackOp = cast<UnPackOp>(op); 52683396d85SHanhan Wang int64_t srcRank = unpackOp.getSourceRank(); 52783396d85SHanhan Wang int64_t destRank = unpackOp.getDestRank(); 52883396d85SHanhan Wang int64_t numInnerTiles = srcRank - destRank; 52983396d85SHanhan Wang Location loc = unpackOp.getLoc(); 53083396d85SHanhan Wang 53183396d85SHanhan Wang // The perfect tiling case indicates that the tiling sizes are multiple of 53283396d85SHanhan Wang // inner_tile_size. In this context, no extra data is needed when 53383396d85SHanhan Wang // representing the tiled unpack op. 53483396d85SHanhan Wang bool isPerfectTilingCase = true; 53583396d85SHanhan Wang Attribute oneAttr = b.getIndexAttr(1); 53683396d85SHanhan Wang SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr); 53783396d85SHanhan Wang SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes; 53883396d85SHanhan Wang SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest; 53983396d85SHanhan Wang for (auto dim : llvm::seq<int64_t>(0, destRank)) { 54083396d85SHanhan Wang UnpackTileDimInfo info = 54183396d85SHanhan Wang getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]); 54283396d85SHanhan Wang if (!info.isAlignedToInnerTileSize) 54383396d85SHanhan Wang isPerfectTilingCase = false; 54483396d85SHanhan Wang sliceSrcIndices.push_back(info.sourceOffset); 54583396d85SHanhan Wang sliceSrcSizes.push_back(info.sourceSize); 54683396d85SHanhan Wang destExpandedSizes.push_back(info.destExpandedSize); 54783396d85SHanhan Wang resultOffsetsFromDest.push_back(info.resultOffset); 54883396d85SHanhan Wang } 54983396d85SHanhan Wang 55083396d85SHanhan Wang // The tiling is applied on destination dimensions. We have to apply the 55183396d85SHanhan Wang // interchange on source dimensions if outer_dims_perm is set. 552d5a9fc13SLorenzo Chelini applyPermToRange(sliceSrcIndices, sliceSrcSizes, 55383396d85SHanhan Wang unpackOp.getOuterDimsPerm()); 55483396d85SHanhan Wang Attribute zeroAttr = b.getIndexAttr(0); 55583396d85SHanhan Wang sliceSrcIndices.append(numInnerTiles, zeroAttr); 55683396d85SHanhan Wang sliceSrcSizes.append(unpackOp.getMixedTiles()); 55783396d85SHanhan Wang sliceSrcStrides.append(numInnerTiles, oneAttr); 558f1595ecfSMax191 SmallVector<Operation *> generatedSlices; 559f1595ecfSMax191 ExtractSliceOp sliceSource = 56083396d85SHanhan Wang b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices, 56183396d85SHanhan Wang sliceSrcSizes, sliceSrcStrides); 562f1595ecfSMax191 generatedSlices.push_back(sliceSource); 56383396d85SHanhan Wang 56483396d85SHanhan Wang SmallVector<OpFoldResult> destStrides(destRank, oneAttr); 56583396d85SHanhan Wang Value sliceDest; 56683396d85SHanhan Wang if (isPerfectTilingCase) { 567d5f0969cSMaheshRavishankar auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), 568d5f0969cSMaheshRavishankar offsets, sizes, destStrides); 569d5f0969cSMaheshRavishankar sliceDest = destSliceOp; 570d5f0969cSMaheshRavishankar generatedSlices.push_back(destSliceOp); 57183396d85SHanhan Wang } else { 57283396d85SHanhan Wang sliceDest = b.create<EmptyOp>(loc, destExpandedSizes, 57383396d85SHanhan Wang unpackOp.getDestType().getElementType()); 57483396d85SHanhan Wang } 57583396d85SHanhan Wang 576f1595ecfSMax191 SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest}; 577be75cf93SHanhan Wang for (auto tile : unpackOp.getInnerTiles()) 578be75cf93SHanhan Wang tiledOperands.push_back(tile); 579be75cf93SHanhan Wang 580be75cf93SHanhan Wang Operation *tiledUnpackOp = b.create<UnPackOp>( 581be75cf93SHanhan Wang loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); 58283396d85SHanhan Wang 58383396d85SHanhan Wang if (isPerfectTilingCase) 584809e3d8cSMahesh Ravishankar return TilingResult{{tiledUnpackOp}, 585d5f0969cSMaheshRavishankar SmallVector<Value>(tiledUnpackOp->getResults()), 586d5f0969cSMaheshRavishankar generatedSlices}; 58783396d85SHanhan Wang 588809e3d8cSMahesh Ravishankar auto extractSlice = 58983396d85SHanhan Wang b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0), 59083396d85SHanhan Wang resultOffsetsFromDest, sizes, destStrides); 591d5f0969cSMaheshRavishankar return TilingResult{ 592d5f0969cSMaheshRavishankar {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices}; 59383396d85SHanhan Wang } 59483396d85SHanhan Wang 59583396d85SHanhan Wang LogicalResult 59683396d85SHanhan Wang getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, 59783396d85SHanhan Wang ArrayRef<OpFoldResult> offsets, 59883396d85SHanhan Wang ArrayRef<OpFoldResult> sizes, 599f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultOffsets, 600f220c359SOleksandr "Alex" Zinenko SmallVector<OpFoldResult> &resultSizes) const { 60183396d85SHanhan Wang resultOffsets = llvm::to_vector(offsets); 60283396d85SHanhan Wang resultSizes = llvm::to_vector(sizes); 60383396d85SHanhan Wang return success(); 60483396d85SHanhan Wang } 605ead535b2SHanhan Wang 606809e3d8cSMahesh Ravishankar FailureOr<TilingResult> 607809e3d8cSMahesh Ravishankar generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, 608ead535b2SHanhan Wang ArrayRef<OpFoldResult> offsets, 609ead535b2SHanhan Wang ArrayRef<OpFoldResult> sizes) const { 610809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tilingResult = 611809e3d8cSMahesh Ravishankar getTiledImplementation(op, b, offsets, sizes); 612809e3d8cSMahesh Ravishankar if (failed(tilingResult)) 613809e3d8cSMahesh Ravishankar return failure(); 614809e3d8cSMahesh Ravishankar return tilingResult.value(); 615ead535b2SHanhan Wang } 6162b2ce50fSAbhishek Varma 6172b2ce50fSAbhishek Varma /// Method to return the position of iteration domain tile computed by the 6182b2ce50fSAbhishek Varma /// tiled operation. 6192b2ce50fSAbhishek Varma LogicalResult getIterationDomainTileFromOperandTile( 6202b2ce50fSAbhishek Varma Operation *op, OpBuilder &b, unsigned operandNumber, 6212b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, 6222b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &resultOffsets, 6232b2ce50fSAbhishek Varma SmallVectorImpl<OpFoldResult> &resultSizes) const { 6242b2ce50fSAbhishek Varma auto unPackOp = cast<UnPackOp>(op); 6258cc616bcSMax191 // If the operand tile is the dest, then no adjustment is needed. 6268cc616bcSMax191 if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) { 6278cc616bcSMax191 resultOffsets = llvm::to_vector(offsets); 6288cc616bcSMax191 resultSizes = llvm::to_vector(sizes); 6298cc616bcSMax191 return success(); 6308cc616bcSMax191 } 6312b2ce50fSAbhishek Varma Location loc = unPackOp.getLoc(); 6322b2ce50fSAbhishek Varma 6332b2ce50fSAbhishek Varma int64_t numTiles = unPackOp.getInnerDimsPos().size(); 6342b2ce50fSAbhishek Varma auto destOffsets = offsets.drop_back(numTiles); 6352b2ce50fSAbhishek Varma auto destSizes = sizes.drop_back(numTiles); 6362b2ce50fSAbhishek Varma // The tiling is applied on interchanged dimensions. We have to undo the 6372b2ce50fSAbhishek Varma // interchange to map sizes and offsets to the original input. 6382b2ce50fSAbhishek Varma int64_t outputRank = unPackOp.getDestRank(); 6398cc616bcSMax191 ReifiedRankedShapedTypeDims reifiedReturnShapes; 6408cc616bcSMax191 if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes))) 6418cc616bcSMax191 return failure(); 6428cc616bcSMax191 SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front(); 6435262865aSKazu Hirata SmallVector<OpFoldResult> origOffsets(destOffsets); 6445262865aSKazu Hirata SmallVector<OpFoldResult> origSizes(destSizes); 6452b2ce50fSAbhishek Varma applyPermToRange(origOffsets, origSizes, 6462b2ce50fSAbhishek Varma invertPermutationVector(unPackOp.getOuterDimsPerm())); 6472b2ce50fSAbhishek Varma 6482b2ce50fSAbhishek Varma DenseMap<int64_t, OpFoldResult> dimAndTileMapping = 6492b2ce50fSAbhishek Varma unPackOp.getDimAndTileMapping(); 6502b2ce50fSAbhishek Varma 6512b2ce50fSAbhishek Varma for (auto dim : llvm::seq<int64_t>(0, outputRank)) { 6522b2ce50fSAbhishek Varma using AV = affine::AffineValueExpr; 6532b2ce50fSAbhishek Varma affine::AffineBuilder ab(b, loc); 6548cc616bcSMax191 AffineExpr dim0, dim1, sym0; 6552b2ce50fSAbhishek Varma bindDims(b.getContext(), dim0, dim1); 6568cc616bcSMax191 bindSymbols(b.getContext(), sym0); 6572b2ce50fSAbhishek Varma if (dimAndTileMapping.count(dim)) { 6582b2ce50fSAbhishek Varma // If the data dimension is tiled, the i-th index is the product of 6592b2ce50fSAbhishek Varma // offset_i and tile_i, and the i-th size is the product of sizes_i and 6608cc616bcSMax191 // tile_i. The sizes must be clamped to the sizes of the unpack result. 6612b2ce50fSAbhishek Varma auto avOffset = AV(dim0).bind(origOffsets[dim]); 6622b2ce50fSAbhishek Varma auto avSize = AV(dim0).bind(origSizes[dim]); 6638cc616bcSMax191 auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]); 6648cc616bcSMax191 auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]); 6652b2ce50fSAbhishek Varma resultOffsets.push_back(ab.mul(avOffset, avTileSize)); 6668cc616bcSMax191 auto avResultOffset = AV(dim1).bind(resultOffsets.back()); 6678cc616bcSMax191 resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize), 6688cc616bcSMax191 ab.sub(avResultSize, avResultOffset)})); 6692b2ce50fSAbhishek Varma } else { 6702b2ce50fSAbhishek Varma resultOffsets.push_back(origOffsets[dim]); 6712b2ce50fSAbhishek Varma resultSizes.push_back(origSizes[dim]); 6722b2ce50fSAbhishek Varma } 6732b2ce50fSAbhishek Varma } 6742b2ce50fSAbhishek Varma return success(); 6752b2ce50fSAbhishek Varma } 6762b2ce50fSAbhishek Varma 6772b2ce50fSAbhishek Varma /// Method to return the tiled implementation of tensor.unpack as a consumer. 6782b2ce50fSAbhishek Varma FailureOr<TilingResult> getTiledImplementationFromOperandTile( 6792b2ce50fSAbhishek Varma Operation *op, OpBuilder &b, unsigned operandNumber, 6802b2ce50fSAbhishek Varma ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { 6812b2ce50fSAbhishek Varma auto unPackOp = cast<UnPackOp>(op); 6822b2ce50fSAbhishek Varma // tensor.unpack op is fusible (as a consumer) only if inner dims are not 6832b2ce50fSAbhishek Varma // tiled. 6842b2ce50fSAbhishek Varma int64_t numTiles = unPackOp.getInnerDimsPos().size(); 6852b2ce50fSAbhishek Varma for (auto iter : 6862b2ce50fSAbhishek Varma llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) { 6872b2ce50fSAbhishek Varma if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) 6882b2ce50fSAbhishek Varma return failure(); 6892b2ce50fSAbhishek Varma } 6902b2ce50fSAbhishek Varma 6912b2ce50fSAbhishek Varma Location loc = unPackOp.getLoc(); 6922b2ce50fSAbhishek Varma 6932b2ce50fSAbhishek Varma // Fetch offset/size for creating the slice of the dest operand of 6942b2ce50fSAbhishek Varma // unpack op. 6952b2ce50fSAbhishek Varma SmallVector<OpFoldResult> outputOffsets, outputSizes; 6962b2ce50fSAbhishek Varma if (failed(getIterationDomainTileFromOperandTile( 6972b2ce50fSAbhishek Varma op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets, 6982b2ce50fSAbhishek Varma outputSizes))) 6992b2ce50fSAbhishek Varma return failure(); 7002b2ce50fSAbhishek Varma 7012b2ce50fSAbhishek Varma auto oneAttr = b.getI64IntegerAttr(1); 7022b2ce50fSAbhishek Varma int64_t outputRank = unPackOp.getDestRank(); 7032b2ce50fSAbhishek Varma SmallVector<OpFoldResult> strides(outputRank, oneAttr); 7042b2ce50fSAbhishek Varma 7052b2ce50fSAbhishek Varma SmallVector<Value> tiledOperands; 7062b2ce50fSAbhishek Varma // Create slice of the dest operand. 7072b2ce50fSAbhishek Varma auto extractDestSlice = b.create<ExtractSliceOp>( 7082b2ce50fSAbhishek Varma loc, unPackOp.getDest(), outputOffsets, outputSizes, strides); 7092b2ce50fSAbhishek Varma tiledOperands.push_back(extractDestSlice); 7102b2ce50fSAbhishek Varma 7112b2ce50fSAbhishek Varma SmallVector<OpFoldResult> inputOffsets, inputSizes; 7122b2ce50fSAbhishek Varma strides.append(unPackOp.getSourceRank() - outputRank, oneAttr); 7132b2ce50fSAbhishek Varma // Create slice of the source operand. 7142b2ce50fSAbhishek Varma auto extractSourceSlice = b.create<ExtractSliceOp>( 7152b2ce50fSAbhishek Varma loc, unPackOp.getSource(), offsets, sizes, strides); 7162b2ce50fSAbhishek Varma tiledOperands.insert(tiledOperands.begin(), extractSourceSlice); 7172b2ce50fSAbhishek Varma for (auto tile : unPackOp.getInnerTiles()) 7182b2ce50fSAbhishek Varma tiledOperands.push_back(tile); 7192b2ce50fSAbhishek Varma 7202b2ce50fSAbhishek Varma // Create tiled unpack op. 7212b2ce50fSAbhishek Varma Operation *tiledUnPackOp = 7222b2ce50fSAbhishek Varma b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()}, 7232b2ce50fSAbhishek Varma tiledOperands, op->getAttrs()); 7242b2ce50fSAbhishek Varma 7252b2ce50fSAbhishek Varma return TilingResult{{tiledUnPackOp}, 726d5f0969cSMaheshRavishankar SmallVector<Value>(tiledUnPackOp->getResults()), 727d5f0969cSMaheshRavishankar llvm::to_vector(ArrayRef<Operation *>{ 728d5f0969cSMaheshRavishankar extractSourceSlice, extractDestSlice})}; 7292b2ce50fSAbhishek Varma } 73083396d85SHanhan Wang }; 73183396d85SHanhan Wang 7320edb4127SLei Zhang } // namespace 7330edb4127SLei Zhang 734809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, 735809e3d8cSMahesh Ravishankar tensor::PadOp padOp, 7360edb4127SLei Zhang ArrayRef<OpFoldResult> offsets, 7370edb4127SLei Zhang ArrayRef<OpFoldResult> sizes, 7380edb4127SLei Zhang bool generateZeroSliceGuard) { 739fd0c6f53SAlexander Belyaev // Only constant padding value supported. 740fd0c6f53SAlexander Belyaev Value padValue = padOp.getConstantPaddingValue(); 741fd0c6f53SAlexander Belyaev if (!padValue) 742809e3d8cSMahesh Ravishankar return failure(); 743fd0c6f53SAlexander Belyaev 744fd0c6f53SAlexander Belyaev // Helper variables and functions for various arithmetic operations. These 745fd0c6f53SAlexander Belyaev // are used extensively for computing new offset/length and padding values. 7460edb4127SLei Zhang Location loc = padOp->getLoc(); 747fd0c6f53SAlexander Belyaev AffineExpr dim0, dim1; 748fd0c6f53SAlexander Belyaev bindDims(b.getContext(), dim0, dim1); 749fd0c6f53SAlexander Belyaev // Subtract two integers. 750fd0c6f53SAlexander Belyaev auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); 751d3fa067eSMahesh Ravishankar auto sub = [&](OpFoldResult v1, OpFoldResult v2) { 7524c48f016SMatthias Springer return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2}); 753fd0c6f53SAlexander Belyaev }; 754fd0c6f53SAlexander Belyaev // Take the minimum of two integers. 755fd0c6f53SAlexander Belyaev auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext()); 756d3fa067eSMahesh Ravishankar auto min = [&](OpFoldResult v1, OpFoldResult v2) { 7574c48f016SMatthias Springer return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2}); 758fd0c6f53SAlexander Belyaev }; 759fd0c6f53SAlexander Belyaev // Take the maximum of two integers. 760d3fa067eSMahesh Ravishankar auto max = [&](OpFoldResult v1, OpFoldResult v2) { 7614c48f016SMatthias Springer return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2}); 762fd0c6f53SAlexander Belyaev }; 763fd0c6f53SAlexander Belyaev // Zero index-typed integer. 764d3fa067eSMahesh Ravishankar OpFoldResult zero = b.getIndexAttr(0); 765fd0c6f53SAlexander Belyaev 766fd0c6f53SAlexander Belyaev // Compute new offsets, lengths, low padding, high padding. 767fd0c6f53SAlexander Belyaev SmallVector<OpFoldResult> newOffsets, newLengths, newStrides; 768d3fa067eSMahesh Ravishankar SmallVector<OpFoldResult> newLows, newHighs; 769fd0c6f53SAlexander Belyaev // Set to true if the original data source is not read at all. 770fd0c6f53SAlexander Belyaev bool hasZeroLen = false; 771fd0c6f53SAlexander Belyaev // Same as hasZeroLen, but for dynamic dimension sizes. This condition 772fd0c6f53SAlexander Belyaev // is true if the original data source turns out to be unused at runtime. 773fd0c6f53SAlexander Belyaev Value dynHasZeroLenCond; 774fd0c6f53SAlexander Belyaev 775fd0c6f53SAlexander Belyaev int64_t rank = padOp.getSourceType().getRank(); 776fd0c6f53SAlexander Belyaev for (unsigned dim = 0; dim < rank; ++dim) { 777d3fa067eSMahesh Ravishankar auto low = padOp.getMixedLowPad()[dim]; 778d3fa067eSMahesh Ravishankar bool hasLowPad = !isConstantIntValue(low, 0); 779d3fa067eSMahesh Ravishankar auto high = padOp.getMixedHighPad()[dim]; 780d3fa067eSMahesh Ravishankar bool hasHighPad = !isConstantIntValue(high, 0); 781d3fa067eSMahesh Ravishankar auto offset = offsets[dim]; 782d3fa067eSMahesh Ravishankar auto length = sizes[dim]; 7836596b0ddSMatthias Springer auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim); 784fd0c6f53SAlexander Belyaev 785fd0c6f53SAlexander Belyaev // The new amount of low padding is `low - offset`. Except for the case 786fd0c6f53SAlexander Belyaev // where none of the low padding is read. In that case, the new amount of 787fd0c6f53SAlexander Belyaev // low padding is zero. 788fd0c6f53SAlexander Belyaev // 789fd0c6f53SAlexander Belyaev // Optimization: If low = 0, then newLow = 0. 790d3fa067eSMahesh Ravishankar OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; 791d3fa067eSMahesh Ravishankar newLows.push_back(newLow); 792fd0c6f53SAlexander Belyaev 793fd0c6f53SAlexander Belyaev // Start reading the data from position `offset - low`. Since the original 794fd0c6f53SAlexander Belyaev // read may have started in the low padding zone, this value could be 795fd0c6f53SAlexander Belyaev // negative. Therefore, start reading from: 796fd0c6f53SAlexander Belyaev // 797fd0c6f53SAlexander Belyaev // max(offset - low, 0) 798fd0c6f53SAlexander Belyaev // 799fd0c6f53SAlexander Belyaev // The original read could also have started in the high padding zone. 800fd0c6f53SAlexander Belyaev // In that case, set the offset to the end of source tensor. The new 801fd0c6f53SAlexander Belyaev // ExtractSliceOp length will be zero in that case. (Effectively reading 802fd0c6f53SAlexander Belyaev // no data from the source.) 803fd0c6f53SAlexander Belyaev // 804fd0c6f53SAlexander Belyaev // Optimization: If low = 0, then the formula can be simplified. 805d3fa067eSMahesh Ravishankar OpFoldResult newOffset = hasLowPad 806d3fa067eSMahesh Ravishankar ? min(max(sub(offset, low), zero), srcSize) 807fd0c6f53SAlexander Belyaev : min(offset, srcSize); 808d3fa067eSMahesh Ravishankar newOffsets.push_back(newOffset); 809fd0c6f53SAlexander Belyaev 810fd0c6f53SAlexander Belyaev // The original ExtractSliceOp was reading until position `offset + 811fd0c6f53SAlexander Belyaev // length`. Therefore, the corresponding position within the source tensor 812fd0c6f53SAlexander Belyaev // is: 813fd0c6f53SAlexander Belyaev // 814fd0c6f53SAlexander Belyaev // offset + length - low 815fd0c6f53SAlexander Belyaev // 816fd0c6f53SAlexander Belyaev // In case the original ExtractSliceOp stopped reading within the low 817fd0c6f53SAlexander Belyaev // padding zone, this value can be negative. In that case, the end 818fd0c6f53SAlexander Belyaev // position of the read should be zero. (Similar to newOffset.) 819fd0c6f53SAlexander Belyaev // 820fd0c6f53SAlexander Belyaev // The original read could also have stopped in the high padding zone. 821fd0c6f53SAlexander Belyaev // In that case, set the end positition of the read should be the end of 822fd0c6f53SAlexander Belyaev // the source tensor. (Similar to newOffset.) 8233f136f7dSNirvedh Meshram // srcSize - newOffset represents how much length we have available 8243f136f7dSNirvedh Meshram // and length - newLow represents how much length we want at most. 82577400103SNirvedh Meshram // Note that there are many ways to order this indexing math to compute 82677400103SNirvedh Meshram // newLength, but we want to make sure that the final affine.min ops in the 82777400103SNirvedh Meshram // sequence are bounding the index to as small a value as possible. If 828*33927744SNirvedh // ValueBoundsOpInterface is used, this calculation will get upper bounds 82977400103SNirvedh Meshram // from the affine.min ops, so we want to use the smallest known value to 83077400103SNirvedh Meshram // set the bound at the end of the computation sequence. In this case, the 83177400103SNirvedh Meshram // index will be upper bounded by length - newLow. 8323f136f7dSNirvedh Meshram OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow)); 8333f136f7dSNirvedh Meshram // Optimization: If low = 0, then newLow = 0. then newLength >= 0 assuming 8343f136f7dSNirvedh Meshram // length >= 0. 8353f136f7dSNirvedh Meshram if (hasLowPad) 8363f136f7dSNirvedh Meshram newLength = max(newLength, zero); 837d3fa067eSMahesh Ravishankar newLengths.push_back(newLength); 838fd0c6f53SAlexander Belyaev 839fd0c6f53SAlexander Belyaev // Check if newLength is zero. In that case, no SubTensorOp should be 840fd0c6f53SAlexander Belyaev // executed. 841d3fa067eSMahesh Ravishankar if (isConstantIntValue(newLength, 0)) { 842d3fa067eSMahesh Ravishankar hasZeroLen = true; 843d3fa067eSMahesh Ravishankar } else if (!hasZeroLen) { 844d3fa067eSMahesh Ravishankar Value check = b.create<arith::CmpIOp>( 845d3fa067eSMahesh Ravishankar loc, arith::CmpIPredicate::eq, 846d3fa067eSMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, newLength), 847d3fa067eSMahesh Ravishankar getValueOrCreateConstantIndexOp(b, loc, zero)); 848fd0c6f53SAlexander Belyaev dynHasZeroLenCond = 849fd0c6f53SAlexander Belyaev dynHasZeroLenCond 850fd0c6f53SAlexander Belyaev ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond) 851fd0c6f53SAlexander Belyaev : check; 852fd0c6f53SAlexander Belyaev } 853fd0c6f53SAlexander Belyaev 854fd0c6f53SAlexander Belyaev // The amount of high padding is simply the number of elements remaining, 855fd0c6f53SAlexander Belyaev // so that the result has the same length as the original ExtractSliceOp. 856fd0c6f53SAlexander Belyaev // As an optimization, if the original high padding is zero, then the new 857fd0c6f53SAlexander Belyaev // high padding must also be zero. 858d3fa067eSMahesh Ravishankar OpFoldResult newHigh = 859d3fa067eSMahesh Ravishankar hasHighPad ? sub(sub(length, newLength), newLow) : zero; 860d3fa067eSMahesh Ravishankar newHighs.push_back(newHigh); 861fd0c6f53SAlexander Belyaev 862fd0c6f53SAlexander Belyaev // Only unit stride supported. 863fd0c6f53SAlexander Belyaev newStrides.push_back(b.getIndexAttr(1)); 864fd0c6f53SAlexander Belyaev } 865fd0c6f53SAlexander Belyaev 866fd0c6f53SAlexander Belyaev // The shape of the result can be obtained from the sizes passed in. 867fd0c6f53SAlexander Belyaev SmallVector<Value> dynDims; 868fd0c6f53SAlexander Belyaev SmallVector<int64_t> shape; 869ded75a28SAliia Khasanova dispatchIndexOpFoldResults(sizes, dynDims, shape); 870fd0c6f53SAlexander Belyaev RankedTensorType resultType = 871fd0c6f53SAlexander Belyaev RankedTensorType::get(shape, padOp.getResultType().getElementType()); 872fd0c6f53SAlexander Belyaev 873fd0c6f53SAlexander Belyaev // Insert cast to ensure that types match. (May be folded away.) 874809e3d8cSMahesh Ravishankar auto castResult = [&](Value val) -> Value { 875d3fa067eSMahesh Ravishankar if (resultType == val.getType()) 876809e3d8cSMahesh Ravishankar return val; 8770edb4127SLei Zhang return b.create<tensor::CastOp>(loc, resultType, val); 878fd0c6f53SAlexander Belyaev }; 879fd0c6f53SAlexander Belyaev 880fd0c6f53SAlexander Belyaev // In cases where the original data source is unused: Emit a GenerateOp and 881fd0c6f53SAlexander Belyaev // do not generate a SliceOp. (The result shape of the SliceOp would 882fd0c6f53SAlexander Belyaev // have a dimension of size 0, the semantics of which is unclear.) 883fd0c6f53SAlexander Belyaev auto createGenerateOp = [&]() { 884fd0c6f53SAlexander Belyaev // Create GenerateOp. 885fd0c6f53SAlexander Belyaev auto generateOp = b.create<tensor::GenerateOp>( 886fd0c6f53SAlexander Belyaev loc, resultType, dynDims, 887fd0c6f53SAlexander Belyaev [&](OpBuilder &builder, Location gLoc, ValueRange indices) { 888fd0c6f53SAlexander Belyaev builder.create<tensor::YieldOp>(gLoc, padValue); 889fd0c6f53SAlexander Belyaev }); 890809e3d8cSMahesh Ravishankar return generateOp; 891fd0c6f53SAlexander Belyaev }; 892fd0c6f53SAlexander Belyaev 893fd0c6f53SAlexander Belyaev // Emit a SliceOp and a PadOp. Should not be used in cases where 894fd0c6f53SAlexander Belyaev // the result shape of the new SliceOp has a zero dimension. 8950edb4127SLei Zhang auto createPadOfExtractSlice = [&]() { 8960edb4127SLei Zhang // Create pad(extract_slice(x)). 897d5f0969cSMaheshRavishankar auto newSliceOp = b.create<tensor::ExtractSliceOp>( 89804235d07SJacques Pienaar loc, padOp.getSource(), newOffsets, newLengths, newStrides); 899b0674405SMahesh Ravishankar auto newPadOp = b.create<PadOp>( 900b0674405SMahesh Ravishankar loc, Type(), newSliceOp, newLows, newHighs, 901b0674405SMahesh Ravishankar /*nofold=*/padOp.getNofold(), 902b0674405SMahesh Ravishankar getPrunedAttributeList(padOp, PadOp::getAttributeNames())); 903fd0c6f53SAlexander Belyaev 904fd0c6f53SAlexander Belyaev // Copy region to new PadOp. 9054d67b278SJeff Niu IRMapping bvm; 90604235d07SJacques Pienaar padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); 907fd0c6f53SAlexander Belyaev 908fd0c6f53SAlexander Belyaev // Cast result and return. 909d5f0969cSMaheshRavishankar return std::make_tuple(newPadOp, newSliceOp); 910fd0c6f53SAlexander Belyaev }; 911fd0c6f53SAlexander Belyaev 9120edb4127SLei Zhang // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that 9130edb4127SLei Zhang // the original data source x is not used. 914809e3d8cSMahesh Ravishankar if (hasZeroLen) { 915809e3d8cSMahesh Ravishankar Operation *generateOp = createGenerateOp(); 916d5f0969cSMaheshRavishankar return TilingResult{{generateOp}, 917d5f0969cSMaheshRavishankar {castResult(generateOp->getResult(0))}, 918d5f0969cSMaheshRavishankar /*generatedSlices=*/{}}; 919809e3d8cSMahesh Ravishankar } 920fd0c6f53SAlexander Belyaev 921fd0c6f53SAlexander Belyaev // If there are dynamic dimensions: Generate an scf.if check to avoid 922fd0c6f53SAlexander Belyaev // creating SliceOps with result dimensions of size 0 at runtime. 9230edb4127SLei Zhang if (generateZeroSliceGuard && dynHasZeroLenCond) { 924809e3d8cSMahesh Ravishankar Operation *thenOp; 925809e3d8cSMahesh Ravishankar Operation *elseOp; 926d5f0969cSMaheshRavishankar Operation *sliceOp; 927fd0c6f53SAlexander Belyaev auto result = b.create<scf::IfOp>( 9281125c5c0SFrederik Gossen loc, dynHasZeroLenCond, 929fd0c6f53SAlexander Belyaev /*thenBuilder=*/ 930fd0c6f53SAlexander Belyaev [&](OpBuilder &b, Location loc) { 931809e3d8cSMahesh Ravishankar thenOp = createGenerateOp(); 932809e3d8cSMahesh Ravishankar b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0))); 933fd0c6f53SAlexander Belyaev }, 934fd0c6f53SAlexander Belyaev /*elseBuilder=*/ 935fd0c6f53SAlexander Belyaev [&](OpBuilder &b, Location loc) { 936d5f0969cSMaheshRavishankar std::tie(elseOp, sliceOp) = createPadOfExtractSlice(); 937809e3d8cSMahesh Ravishankar b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0))); 938fd0c6f53SAlexander Belyaev }); 939d5f0969cSMaheshRavishankar return TilingResult{ 940d5f0969cSMaheshRavishankar {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}}; 941fd0c6f53SAlexander Belyaev } 942809e3d8cSMahesh Ravishankar 943d5f0969cSMaheshRavishankar auto [newPadOp, sliceOp] = createPadOfExtractSlice(); 944d5f0969cSMaheshRavishankar return TilingResult{ 945d5f0969cSMaheshRavishankar {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}}; 946fd0c6f53SAlexander Belyaev } 947fd0c6f53SAlexander Belyaev 948a235562cSMahesh Ravishankar void mlir::tensor::registerTilingInterfaceExternalModels( 949fd0c6f53SAlexander Belyaev DialectRegistry ®istry) { 95077eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 95177eee579SRiver Riddle tensor::PadOp::attachInterface<PadOpTiling>(*ctx); 9520d03ba62SHanhan Wang tensor::PackOp::attachInterface<PackOpTiling>(*ctx); 95383396d85SHanhan Wang tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); 95477eee579SRiver Riddle }); 955fd0c6f53SAlexander Belyaev } 956adcb9888SHanhan Wang 957adcb9888SHanhan Wang void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps( 958adcb9888SHanhan Wang DialectRegistry ®istry) { 959adcb9888SHanhan Wang registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) { 960adcb9888SHanhan Wang tensor::PackOp::attachInterface<PackOpTiling>(*ctx); 961adcb9888SHanhan Wang tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); 962adcb9888SHanhan Wang }); 963adcb9888SHanhan Wang } 964