12d3b54feSLei Zhang //===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
22d3b54feSLei Zhang //
32d3b54feSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42d3b54feSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52d3b54feSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62d3b54feSLei Zhang //
72d3b54feSLei Zhang //===----------------------------------------------------------------------===//
82d3b54feSLei Zhang //
92d3b54feSLei Zhang // This file implements rewrites that replace slices of reshape results with
102d3b54feSLei Zhang // aggregated slices of the reshape source.
112d3b54feSLei Zhang //
122d3b54feSLei Zhang //===----------------------------------------------------------------------===//
132d3b54feSLei Zhang #include "mlir/Dialect/Affine/IR/AffineOps.h"
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
162d3b54feSLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
172d3b54feSLei Zhang #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
182d3b54feSLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
192d3b54feSLei Zhang #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
202d3b54feSLei Zhang #include "mlir/Dialect/Utils/StaticValueUtils.h"
212d3b54feSLei Zhang #include "mlir/IR/BuiltinTypes.h"
222d3b54feSLei Zhang #include "mlir/IR/OpDefinition.h"
232d3b54feSLei Zhang #include "llvm/ADT/STLExtras.h"
242d3b54feSLei Zhang
252d3b54feSLei Zhang using namespace mlir;
264c48f016SMatthias Springer using namespace mlir::affine;
272d3b54feSLei Zhang using namespace mlir::tensor;
282d3b54feSLei Zhang
292d3b54feSLei Zhang /// A tuple that represents (dimension number, dimension value).
302d3b54feSLei Zhang using DimAndIndex = std::tuple<unsigned, Value>;
312d3b54feSLei Zhang
322d3b54feSLei Zhang /// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
332d3b54feSLei Zhang /// slice described by `sliceParams` into the input index space.
invertSliceIndexing(OpBuilder & b,Location loc,ArrayRef<Range> sliceParams,const DimAndIndex & dimAndIndex)342d3b54feSLei Zhang static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
352d3b54feSLei Zhang ArrayRef<Range> sliceParams,
362d3b54feSLei Zhang const DimAndIndex &dimAndIndex) {
372d3b54feSLei Zhang AffineExpr d0, s0, s1;
382d3b54feSLei Zhang bindDims(b.getContext(), d0);
392d3b54feSLei Zhang bindSymbols(b.getContext(), s0, s1);
402d3b54feSLei Zhang auto [dim, indexValue] = dimAndIndex;
412d3b54feSLei Zhang assert(dim < sliceParams.size() && "slice should be non rank-reducing");
422d3b54feSLei Zhang return std::make_pair(
43efc290ceSMatthias Springer dim, affine::makeComposedAffineApply(
442d3b54feSLei Zhang b, loc, s0 + d0 * s1,
45efc290ceSMatthias Springer {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
462d3b54feSLei Zhang }
472d3b54feSLei Zhang
482d3b54feSLei Zhang /// Transform `dimAndIndex` from the result tensor index space of a
492d3b54feSLei Zhang /// CollapseShapeOp to the source tensor index space.
invertCollapseShapeIndexing(OpBuilder & b,Location loc,ArrayRef<ReassociationIndices> reassociation,ArrayRef<OpFoldResult> reshapeSourceShape,const DimAndIndex & dimAndIndex)502d3b54feSLei Zhang static ValueRange invertCollapseShapeIndexing(
512d3b54feSLei Zhang OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
522d3b54feSLei Zhang ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
532d3b54feSLei Zhang const auto &[dim, indexValue] = dimAndIndex;
542d3b54feSLei Zhang SmallVector<OpFoldResult> basis;
552d3b54feSLei Zhang for (int64_t i : reassociation[dim])
562d3b54feSLei Zhang basis.push_back(reshapeSourceShape[i]);
572d3b54feSLei Zhang auto delinearized =
582d3b54feSLei Zhang b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
592d3b54feSLei Zhang return delinearized->getResults();
602d3b54feSLei Zhang }
612d3b54feSLei Zhang
622d3b54feSLei Zhang FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp collapseOp,tensor::ExtractSliceOp extractOp)632d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::create(
642d3b54feSLei Zhang OpBuilder &b, tensor::CollapseShapeOp collapseOp,
652d3b54feSLei Zhang tensor::ExtractSliceOp extractOp) {
662d3b54feSLei Zhang if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
672d3b54feSLei Zhang collapseOp)
682d3b54feSLei Zhang return failure();
692d3b54feSLei Zhang SmallVector<Range> ranges;
702d3b54feSLei Zhang ranges.reserve(extractOp.getSourceType().getRank());
712d3b54feSLei Zhang for (const auto &[o, s, st] :
722d3b54feSLei Zhang llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
732d3b54feSLei Zhang extractOp.getMixedStrides())) {
742d3b54feSLei Zhang ranges.push_back({o, s, st});
752d3b54feSLei Zhang }
762d3b54feSLei Zhang return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
772d3b54feSLei Zhang }
782d3b54feSLei Zhang
792d3b54feSLei Zhang FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp op,ArrayRef<Range> sliceParams)802d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
812d3b54feSLei Zhang tensor::CollapseShapeOp op,
822d3b54feSLei Zhang ArrayRef<Range> sliceParams) {
83446981bdSChristopher Bate // Don't perform this pattern if the collapse op can be simplified by
84446981bdSChristopher Bate // a rank-reducing extract slice.
85446981bdSChristopher Bate if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86446981bdSChristopher Bate op.getSrcType(), op.getReassociationIndices())))
87446981bdSChristopher Bate return failure();
882d3b54feSLei Zhang
892d3b54feSLei Zhang // Materialize the output shape of the collapse_shape operation. This will
902d3b54feSLei Zhang // create IR describing the output shape in terms of the input shape.
912d3b54feSLei Zhang ReifiedRankedShapedTypeDims reifiedShapes;
92758329dcSMatthias Springer if (failed(reifyResultShapes(b, op, reifiedShapes)))
932d3b54feSLei Zhang return failure();
942a5b13e7SMatthias Springer SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
952d3b54feSLei Zhang SmallVector<ReassociationIndices> reassociationIndices =
962d3b54feSLei Zhang op.getReassociationIndices();
972d3b54feSLei Zhang
982d3b54feSLei Zhang // Determine which of the CollapseShapeOp's result dimensions are sliced
992d3b54feSLei Zhang // and/or linearized.
1002d3b54feSLei Zhang llvm::SmallBitVector linearizedDimensions =
1012d3b54feSLei Zhang getLinearizedDimensions(reassociationIndices);
1022d3b54feSLei Zhang llvm::SmallBitVector slicedDimensions =
1032d3b54feSLei Zhang getSlicedDimensions(collapseShapeOutputShape, sliceParams);
1042d3b54feSLei Zhang
105*6596b0ddSMatthias Springer auto collapseShapeInputShape =
106*6596b0ddSMatthias Springer tensor::getMixedSizes(b, op.getLoc(), op.getSrc());
1072d3b54feSLei Zhang
1082d3b54feSLei Zhang SmallVector<Value> tileSizes;
1092d3b54feSLei Zhang for (unsigned i = 0; i < sliceParams.size(); i++) {
1102d3b54feSLei Zhang if (slicedDimensions[i] && linearizedDimensions[i])
1112d3b54feSLei Zhang tileSizes.push_back(
1122d3b54feSLei Zhang getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
1132d3b54feSLei Zhang }
1142d3b54feSLei Zhang
1152d3b54feSLei Zhang return ExtractSliceFromCollapseHelper(
1162d3b54feSLei Zhang op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
1172d3b54feSLei Zhang linearizedDimensions, slicedDimensions, tileSizes);
1182d3b54feSLei Zhang }
1192d3b54feSLei Zhang
1202d3b54feSLei Zhang std::pair<Value, SmallVector<Range>>
emitLoopNestBody(OpBuilder & builder,Location loc,ValueRange tileInductionVars)1212d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
1222d3b54feSLei Zhang OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
1232d3b54feSLei Zhang // Create the helper class for forming the slice parameters.
1242d3b54feSLei Zhang const SmallVector<ReassociationIndices> reassociationIndices =
1252d3b54feSLei Zhang collapseShapeOp.getReassociationIndices();
1262d3b54feSLei Zhang SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
1272d3b54feSLei Zhang collapseShapeOutputShape, sliceParams);
1282d3b54feSLei Zhang
1292d3b54feSLei Zhang // Get the indices of the tiled dims (linearized by the collapse_shape
1302d3b54feSLei Zhang // and sliced by the extract_slice) invert the index spaces
1312d3b54feSLei Zhang // transformations.
1322d3b54feSLei Zhang SmallVector<ValueRange> multiIndices;
1332d3b54feSLei Zhang unsigned loopIdx = 0;
1342d3b54feSLei Zhang for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
1352d3b54feSLei Zhang if (linearizedDimensions[i] && slicedDimensions[i]) {
1362d3b54feSLei Zhang DimAndIndex tb =
1372d3b54feSLei Zhang invertSliceIndexing(builder, loc, sliceParams,
1382d3b54feSLei Zhang std::make_tuple(i, tileInductionVars[loopIdx++]));
1392d3b54feSLei Zhang multiIndices.push_back(invertCollapseShapeIndexing(
1402d3b54feSLei Zhang builder, loc, reassociationIndices, collapseShapeInputShape, tb));
1412d3b54feSLei Zhang }
1422d3b54feSLei Zhang }
1432d3b54feSLei Zhang
1442d3b54feSLei Zhang SmallVector<Range> extractParams =
1452d3b54feSLei Zhang helper.getExtractSliceParams(builder.getContext(), multiIndices);
1462d3b54feSLei Zhang
1472d3b54feSLei Zhang Value subTileResult = builder.create<tensor::ExtractSliceOp>(
1482d3b54feSLei Zhang loc, collapseShapeOp.getSrc(), extractParams);
1492d3b54feSLei Zhang
1502d3b54feSLei Zhang SmallVector<Range> insertParams =
1512d3b54feSLei Zhang helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
1522d3b54feSLei Zhang
1532d3b54feSLei Zhang // Collapse the dimensions of the source slice back down.
1542d3b54feSLei Zhang Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
1552d3b54feSLei Zhang loc, subTileResult, reassociationIndices);
1562d3b54feSLei Zhang return std::make_pair(collapsedResult, insertParams);
1572d3b54feSLei Zhang }
158446981bdSChristopher Bate
159446981bdSChristopher Bate FailureOr<Operation *>
simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,RewriterBase & rewriter)160446981bdSChristopher Bate tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
161446981bdSChristopher Bate tensor::CollapseShapeOp op, RewriterBase &rewriter) {
162446981bdSChristopher Bate SmallVector<ReassociationIndices> reassociationIndices =
163446981bdSChristopher Bate op.getReassociationIndices();
164446981bdSChristopher Bate RankedTensorType sourceType = op.getSrcType();
165446981bdSChristopher Bate FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
166446981bdSChristopher Bate getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167446981bdSChristopher Bate reassociationIndices);
168446981bdSChristopher Bate if (failed(info))
169446981bdSChristopher Bate return failure();
170446981bdSChristopher Bate
171446981bdSChristopher Bate // Create the rank-reducing extract slice op.
172446981bdSChristopher Bate auto zero = rewriter.getIndexAttr(0);
173446981bdSChristopher Bate auto one = rewriter.getIndexAttr(1);
174446981bdSChristopher Bate SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
175446981bdSChristopher Bate SmallVector<OpFoldResult> sizes =
176*6596b0ddSMatthias Springer tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
177446981bdSChristopher Bate SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
178446981bdSChristopher Bate auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
179446981bdSChristopher Bate op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
180446981bdSChristopher Bate
181446981bdSChristopher Bate if (!info->newReassociationIndices.has_value()) {
182446981bdSChristopher Bate rewriter.replaceOp(op, sliceOp.getResult());
183446981bdSChristopher Bate return sliceOp.getOperation();
184446981bdSChristopher Bate }
185446981bdSChristopher Bate
186446981bdSChristopher Bate return rewriter
187446981bdSChristopher Bate .replaceOpWithNewOp<tensor::CollapseShapeOp>(
188cbb09813SFangrui Song op, sliceOp.getResult(), *info->newReassociationIndices)
189446981bdSChristopher Bate .getOperation();
190446981bdSChristopher Bate }
191