xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp (revision 6596b0dde85888117bd230f64906a8c4de968b87)
1 //===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites that replace slices of reshape results with
10 // aggregated slices of the reshape source.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::affine;
27 using namespace mlir::tensor;
28 
29 /// A tuple that represents (dimension number, dimension value).
30 using DimAndIndex = std::tuple<unsigned, Value>;
31 
32 /// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
33 /// slice described by `sliceParams` into the input index space.
invertSliceIndexing(OpBuilder & b,Location loc,ArrayRef<Range> sliceParams,const DimAndIndex & dimAndIndex)34 static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
35                                        ArrayRef<Range> sliceParams,
36                                        const DimAndIndex &dimAndIndex) {
37   AffineExpr d0, s0, s1;
38   bindDims(b.getContext(), d0);
39   bindSymbols(b.getContext(), s0, s1);
40   auto [dim, indexValue] = dimAndIndex;
41   assert(dim < sliceParams.size() && "slice should be non rank-reducing");
42   return std::make_pair(
43       dim, affine::makeComposedAffineApply(
44                b, loc, s0 + d0 * s1,
45                {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
46 }
47 
48 /// Transform `dimAndIndex` from the result tensor index space of a
49 /// CollapseShapeOp to the source tensor index space.
invertCollapseShapeIndexing(OpBuilder & b,Location loc,ArrayRef<ReassociationIndices> reassociation,ArrayRef<OpFoldResult> reshapeSourceShape,const DimAndIndex & dimAndIndex)50 static ValueRange invertCollapseShapeIndexing(
51     OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
52     ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
53   const auto &[dim, indexValue] = dimAndIndex;
54   SmallVector<OpFoldResult> basis;
55   for (int64_t i : reassociation[dim])
56     basis.push_back(reshapeSourceShape[i]);
57   auto delinearized =
58       b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
59   return delinearized->getResults();
60 }
61 
62 FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp collapseOp,tensor::ExtractSliceOp extractOp)63 tensor::ExtractSliceFromCollapseHelper::create(
64     OpBuilder &b, tensor::CollapseShapeOp collapseOp,
65     tensor::ExtractSliceOp extractOp) {
66   if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
67       collapseOp)
68     return failure();
69   SmallVector<Range> ranges;
70   ranges.reserve(extractOp.getSourceType().getRank());
71   for (const auto &[o, s, st] :
72        llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
73                  extractOp.getMixedStrides())) {
74     ranges.push_back({o, s, st});
75   }
76   return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
77 }
78 
79 FailureOr<ExtractSliceFromCollapseHelper>
create(OpBuilder & b,tensor::CollapseShapeOp op,ArrayRef<Range> sliceParams)80 tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
81                                                tensor::CollapseShapeOp op,
82                                                ArrayRef<Range> sliceParams) {
83   // Don't perform this pattern if the collapse op can be simplified by
84   // a rank-reducing extract slice.
85   if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86           op.getSrcType(), op.getReassociationIndices())))
87     return failure();
88 
89   // Materialize the output shape of the collapse_shape operation. This will
90   // create IR describing the output shape in terms of the input shape.
91   ReifiedRankedShapedTypeDims reifiedShapes;
92   if (failed(reifyResultShapes(b, op, reifiedShapes)))
93     return failure();
94   SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
95   SmallVector<ReassociationIndices> reassociationIndices =
96       op.getReassociationIndices();
97 
98   // Determine which of the CollapseShapeOp's result dimensions are sliced
99   // and/or linearized.
100   llvm::SmallBitVector linearizedDimensions =
101       getLinearizedDimensions(reassociationIndices);
102   llvm::SmallBitVector slicedDimensions =
103       getSlicedDimensions(collapseShapeOutputShape, sliceParams);
104 
105   auto collapseShapeInputShape =
106       tensor::getMixedSizes(b, op.getLoc(), op.getSrc());
107 
108   SmallVector<Value> tileSizes;
109   for (unsigned i = 0; i < sliceParams.size(); i++) {
110     if (slicedDimensions[i] && linearizedDimensions[i])
111       tileSizes.push_back(
112           getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
113   }
114 
115   return ExtractSliceFromCollapseHelper(
116       op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
117       linearizedDimensions, slicedDimensions, tileSizes);
118 }
119 
120 std::pair<Value, SmallVector<Range>>
emitLoopNestBody(OpBuilder & builder,Location loc,ValueRange tileInductionVars)121 tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
122     OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
123   // Create the helper class for forming the slice parameters.
124   const SmallVector<ReassociationIndices> reassociationIndices =
125       collapseShapeOp.getReassociationIndices();
126   SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
127                                  collapseShapeOutputShape, sliceParams);
128 
129   // Get the indices of the tiled dims (linearized by the collapse_shape
130   // and sliced by the extract_slice) invert the index spaces
131   // transformations.
132   SmallVector<ValueRange> multiIndices;
133   unsigned loopIdx = 0;
134   for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
135     if (linearizedDimensions[i] && slicedDimensions[i]) {
136       DimAndIndex tb =
137           invertSliceIndexing(builder, loc, sliceParams,
138                               std::make_tuple(i, tileInductionVars[loopIdx++]));
139       multiIndices.push_back(invertCollapseShapeIndexing(
140           builder, loc, reassociationIndices, collapseShapeInputShape, tb));
141     }
142   }
143 
144   SmallVector<Range> extractParams =
145       helper.getExtractSliceParams(builder.getContext(), multiIndices);
146 
147   Value subTileResult = builder.create<tensor::ExtractSliceOp>(
148       loc, collapseShapeOp.getSrc(), extractParams);
149 
150   SmallVector<Range> insertParams =
151       helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
152 
153   // Collapse the dimensions of the source slice back down.
154   Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
155       loc, subTileResult, reassociationIndices);
156   return std::make_pair(collapsedResult, insertParams);
157 }
158 
159 FailureOr<Operation *>
simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,RewriterBase & rewriter)160 tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
161     tensor::CollapseShapeOp op, RewriterBase &rewriter) {
162   SmallVector<ReassociationIndices> reassociationIndices =
163       op.getReassociationIndices();
164   RankedTensorType sourceType = op.getSrcType();
165   FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
166       getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167                                                         reassociationIndices);
168   if (failed(info))
169     return failure();
170 
171   // Create the rank-reducing extract slice op.
172   auto zero = rewriter.getIndexAttr(0);
173   auto one = rewriter.getIndexAttr(1);
174   SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
175   SmallVector<OpFoldResult> sizes =
176       tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
177   SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
178   auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
179       op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
180 
181   if (!info->newReassociationIndices.has_value()) {
182     rewriter.replaceOp(op, sliceOp.getResult());
183     return sliceOp.getOperation();
184   }
185 
186   return rewriter
187       .replaceOpWithNewOp<tensor::CollapseShapeOp>(
188           op, sliceOp.getResult(), *info->newReassociationIndices)
189       .getOperation();
190 }
191