xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp (revision 6596b0dde85888117bd230f64906a8c4de968b87)
12d3b54feSLei Zhang //===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
22d3b54feSLei Zhang //
32d3b54feSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42d3b54feSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
52d3b54feSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62d3b54feSLei Zhang //
72d3b54feSLei Zhang //===----------------------------------------------------------------------===//
82d3b54feSLei Zhang //
92d3b54feSLei Zhang // This file implements rewrites that replace slices of reshape results with
102d3b54feSLei Zhang // aggregated slices of the reshape source.
112d3b54feSLei Zhang //
122d3b54feSLei Zhang //===----------------------------------------------------------------------===//
132d3b54feSLei Zhang #include "mlir/Dialect/Affine/IR/AffineOps.h"
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
162d3b54feSLei Zhang #include "mlir/Dialect/Tensor/IR/Tensor.h"
172d3b54feSLei Zhang #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
182d3b54feSLei Zhang #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
192d3b54feSLei Zhang #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
202d3b54feSLei Zhang #include "mlir/Dialect/Utils/StaticValueUtils.h"
212d3b54feSLei Zhang #include "mlir/IR/BuiltinTypes.h"
222d3b54feSLei Zhang #include "mlir/IR/OpDefinition.h"
232d3b54feSLei Zhang #include "llvm/ADT/STLExtras.h"
242d3b54feSLei Zhang 
252d3b54feSLei Zhang using namespace mlir;
264c48f016SMatthias Springer using namespace mlir::affine;
272d3b54feSLei Zhang using namespace mlir::tensor;
282d3b54feSLei Zhang 
292d3b54feSLei Zhang /// A tuple that represents (dimension number, dimension value).
302d3b54feSLei Zhang using DimAndIndex = std::tuple<unsigned, Value>;
312d3b54feSLei Zhang 
322d3b54feSLei Zhang /// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
332d3b54feSLei Zhang /// slice described by `sliceParams` into the input index space.
invertSliceIndexing(OpBuilder & b,Location loc,ArrayRef<Range> sliceParams,const DimAndIndex & dimAndIndex)342d3b54feSLei Zhang static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
352d3b54feSLei Zhang                                        ArrayRef<Range> sliceParams,
362d3b54feSLei Zhang                                        const DimAndIndex &dimAndIndex) {
372d3b54feSLei Zhang   AffineExpr d0, s0, s1;
382d3b54feSLei Zhang   bindDims(b.getContext(), d0);
392d3b54feSLei Zhang   bindSymbols(b.getContext(), s0, s1);
402d3b54feSLei Zhang   auto [dim, indexValue] = dimAndIndex;
412d3b54feSLei Zhang   assert(dim < sliceParams.size() && "slice should be non rank-reducing");
422d3b54feSLei Zhang   return std::make_pair(
43efc290ceSMatthias Springer       dim, affine::makeComposedAffineApply(
442d3b54feSLei Zhang                b, loc, s0 + d0 * s1,
45efc290ceSMatthias Springer                {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
462d3b54feSLei Zhang }
472d3b54feSLei Zhang 
482d3b54feSLei Zhang /// Transform `dimAndIndex` from the result tensor index space of a
492d3b54feSLei Zhang /// CollapseShapeOp to the source tensor index space.
invertCollapseShapeIndexing(OpBuilder & b,Location loc,ArrayRef<ReassociationIndices> reassociation,ArrayRef<OpFoldResult> reshapeSourceShape,const DimAndIndex & dimAndIndex)502d3b54feSLei Zhang static ValueRange invertCollapseShapeIndexing(
512d3b54feSLei Zhang     OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
522d3b54feSLei Zhang     ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
532d3b54feSLei Zhang   const auto &[dim, indexValue] = dimAndIndex;
542d3b54feSLei Zhang   SmallVector<OpFoldResult> basis;
552d3b54feSLei Zhang   for (int64_t i : reassociation[dim])
562d3b54feSLei Zhang     basis.push_back(reshapeSourceShape[i]);
572d3b54feSLei Zhang   auto delinearized =
582d3b54feSLei Zhang       b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
592d3b54feSLei Zhang   return delinearized->getResults();
602d3b54feSLei Zhang }
612d3b54feSLei Zhang 
622d3b54feSLei Zhang FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp collapseOp,tensor::ExtractSliceOp extractOp)632d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::create(
642d3b54feSLei Zhang     OpBuilder &b, tensor::CollapseShapeOp collapseOp,
652d3b54feSLei Zhang     tensor::ExtractSliceOp extractOp) {
662d3b54feSLei Zhang   if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
672d3b54feSLei Zhang       collapseOp)
682d3b54feSLei Zhang     return failure();
692d3b54feSLei Zhang   SmallVector<Range> ranges;
702d3b54feSLei Zhang   ranges.reserve(extractOp.getSourceType().getRank());
712d3b54feSLei Zhang   for (const auto &[o, s, st] :
722d3b54feSLei Zhang        llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
732d3b54feSLei Zhang                  extractOp.getMixedStrides())) {
742d3b54feSLei Zhang     ranges.push_back({o, s, st});
752d3b54feSLei Zhang   }
762d3b54feSLei Zhang   return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
772d3b54feSLei Zhang }
782d3b54feSLei Zhang 
792d3b54feSLei Zhang FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp op,ArrayRef<Range> sliceParams)802d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
812d3b54feSLei Zhang                                                tensor::CollapseShapeOp op,
822d3b54feSLei Zhang                                                ArrayRef<Range> sliceParams) {
83446981bdSChristopher Bate   // Don't perform this pattern if the collapse op can be simplified by
84446981bdSChristopher Bate   // a rank-reducing extract slice.
85446981bdSChristopher Bate   if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86446981bdSChristopher Bate           op.getSrcType(), op.getReassociationIndices())))
87446981bdSChristopher Bate     return failure();
882d3b54feSLei Zhang 
892d3b54feSLei Zhang   // Materialize the output shape of the collapse_shape operation. This will
902d3b54feSLei Zhang   // create IR describing the output shape in terms of the input shape.
912d3b54feSLei Zhang   ReifiedRankedShapedTypeDims reifiedShapes;
92758329dcSMatthias Springer   if (failed(reifyResultShapes(b, op, reifiedShapes)))
932d3b54feSLei Zhang     return failure();
942a5b13e7SMatthias Springer   SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
952d3b54feSLei Zhang   SmallVector<ReassociationIndices> reassociationIndices =
962d3b54feSLei Zhang       op.getReassociationIndices();
972d3b54feSLei Zhang 
982d3b54feSLei Zhang   // Determine which of the CollapseShapeOp's result dimensions are sliced
992d3b54feSLei Zhang   // and/or linearized.
1002d3b54feSLei Zhang   llvm::SmallBitVector linearizedDimensions =
1012d3b54feSLei Zhang       getLinearizedDimensions(reassociationIndices);
1022d3b54feSLei Zhang   llvm::SmallBitVector slicedDimensions =
1032d3b54feSLei Zhang       getSlicedDimensions(collapseShapeOutputShape, sliceParams);
1042d3b54feSLei Zhang 
105*6596b0ddSMatthias Springer   auto collapseShapeInputShape =
106*6596b0ddSMatthias Springer       tensor::getMixedSizes(b, op.getLoc(), op.getSrc());
1072d3b54feSLei Zhang 
1082d3b54feSLei Zhang   SmallVector<Value> tileSizes;
1092d3b54feSLei Zhang   for (unsigned i = 0; i < sliceParams.size(); i++) {
1102d3b54feSLei Zhang     if (slicedDimensions[i] && linearizedDimensions[i])
1112d3b54feSLei Zhang       tileSizes.push_back(
1122d3b54feSLei Zhang           getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
1132d3b54feSLei Zhang   }
1142d3b54feSLei Zhang 
1152d3b54feSLei Zhang   return ExtractSliceFromCollapseHelper(
1162d3b54feSLei Zhang       op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
1172d3b54feSLei Zhang       linearizedDimensions, slicedDimensions, tileSizes);
1182d3b54feSLei Zhang }
1192d3b54feSLei Zhang 
1202d3b54feSLei Zhang std::pair<Value, SmallVector<Range>>
emitLoopNestBody(OpBuilder & builder,Location loc,ValueRange tileInductionVars)1212d3b54feSLei Zhang tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
1222d3b54feSLei Zhang     OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
1232d3b54feSLei Zhang   // Create the helper class for forming the slice parameters.
1242d3b54feSLei Zhang   const SmallVector<ReassociationIndices> reassociationIndices =
1252d3b54feSLei Zhang       collapseShapeOp.getReassociationIndices();
1262d3b54feSLei Zhang   SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
1272d3b54feSLei Zhang                                  collapseShapeOutputShape, sliceParams);
1282d3b54feSLei Zhang 
1292d3b54feSLei Zhang   // Get the indices of the tiled dims (linearized by the collapse_shape
1302d3b54feSLei Zhang   // and sliced by the extract_slice) invert the index spaces
1312d3b54feSLei Zhang   // transformations.
1322d3b54feSLei Zhang   SmallVector<ValueRange> multiIndices;
1332d3b54feSLei Zhang   unsigned loopIdx = 0;
1342d3b54feSLei Zhang   for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
1352d3b54feSLei Zhang     if (linearizedDimensions[i] && slicedDimensions[i]) {
1362d3b54feSLei Zhang       DimAndIndex tb =
1372d3b54feSLei Zhang           invertSliceIndexing(builder, loc, sliceParams,
1382d3b54feSLei Zhang                               std::make_tuple(i, tileInductionVars[loopIdx++]));
1392d3b54feSLei Zhang       multiIndices.push_back(invertCollapseShapeIndexing(
1402d3b54feSLei Zhang           builder, loc, reassociationIndices, collapseShapeInputShape, tb));
1412d3b54feSLei Zhang     }
1422d3b54feSLei Zhang   }
1432d3b54feSLei Zhang 
1442d3b54feSLei Zhang   SmallVector<Range> extractParams =
1452d3b54feSLei Zhang       helper.getExtractSliceParams(builder.getContext(), multiIndices);
1462d3b54feSLei Zhang 
1472d3b54feSLei Zhang   Value subTileResult = builder.create<tensor::ExtractSliceOp>(
1482d3b54feSLei Zhang       loc, collapseShapeOp.getSrc(), extractParams);
1492d3b54feSLei Zhang 
1502d3b54feSLei Zhang   SmallVector<Range> insertParams =
1512d3b54feSLei Zhang       helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
1522d3b54feSLei Zhang 
1532d3b54feSLei Zhang   // Collapse the dimensions of the source slice back down.
1542d3b54feSLei Zhang   Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
1552d3b54feSLei Zhang       loc, subTileResult, reassociationIndices);
1562d3b54feSLei Zhang   return std::make_pair(collapsedResult, insertParams);
1572d3b54feSLei Zhang }
158446981bdSChristopher Bate 
159446981bdSChristopher Bate FailureOr<Operation *>
simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,RewriterBase & rewriter)160446981bdSChristopher Bate tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
161446981bdSChristopher Bate     tensor::CollapseShapeOp op, RewriterBase &rewriter) {
162446981bdSChristopher Bate   SmallVector<ReassociationIndices> reassociationIndices =
163446981bdSChristopher Bate       op.getReassociationIndices();
164446981bdSChristopher Bate   RankedTensorType sourceType = op.getSrcType();
165446981bdSChristopher Bate   FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
166446981bdSChristopher Bate       getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167446981bdSChristopher Bate                                                         reassociationIndices);
168446981bdSChristopher Bate   if (failed(info))
169446981bdSChristopher Bate     return failure();
170446981bdSChristopher Bate 
171446981bdSChristopher Bate   // Create the rank-reducing extract slice op.
172446981bdSChristopher Bate   auto zero = rewriter.getIndexAttr(0);
173446981bdSChristopher Bate   auto one = rewriter.getIndexAttr(1);
174446981bdSChristopher Bate   SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
175446981bdSChristopher Bate   SmallVector<OpFoldResult> sizes =
176*6596b0ddSMatthias Springer       tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
177446981bdSChristopher Bate   SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
178446981bdSChristopher Bate   auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
179446981bdSChristopher Bate       op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
180446981bdSChristopher Bate 
181446981bdSChristopher Bate   if (!info->newReassociationIndices.has_value()) {
182446981bdSChristopher Bate     rewriter.replaceOp(op, sliceOp.getResult());
183446981bdSChristopher Bate     return sliceOp.getOperation();
184446981bdSChristopher Bate   }
185446981bdSChristopher Bate 
186446981bdSChristopher Bate   return rewriter
187446981bdSChristopher Bate       .replaceOpWithNewOp<tensor::CollapseShapeOp>(
188cbb09813SFangrui Song           op, sliceOp.getResult(), *info->newReassociationIndices)
189446981bdSChristopher Bate       .getOperation();
190446981bdSChristopher Bate }
191