//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include <numeric>
#include <optional>

using namespace mlir;

std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  if (sourceType.getRank() > targetType.getRank())
    return getReassociationIndicesForCollapse(sourceType.getShape(),
                                              targetType.getShape());
  if (sourceType.getRank() < targetType.getRank())
    return getReassociationIndicesForCollapse(targetType.getShape(),
                                              sourceType.getShape());
  return std::nullopt;
}

std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  if (sourceShape.size() <= targetShape.size())
    return std::nullopt;
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetShape.size());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();
    // If we have mapped all the target dimensions, stop and handle the
    // remaining tail of size-1 dimensions explicitly.
    if (targetDim == targetShape.size())
      break;

    int64_t currTargetShape = targetShape[targetDim];
    while (sourceDim < (sourceShape.size() - 1) &&
           sourceShape[sourceDim] != ShapedType::kDynamic &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // If the current expanded dimension is dynamic, then the collapsed
    // dimension should also be dynamic, and the product of the expanded
    // dimensions gathered so far for this group should be 1.
    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
      return std::nullopt;

    // If the collapsed dim is dynamic, the current expanded dim should also
    // be dynamic.
    if (currTargetShape == ShapedType::kDynamic &&
        sourceShape[sourceDim] != ShapedType::kDynamic)
      return std::nullopt;

    // For static shapes, the product of the dimensions of the expanded group
    // must match the collapsed dimension size.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return std::nullopt;

    currIndices.push_back(sourceDim++);
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the target must have been processed.
  if (reassociationMap.size() != targetShape.size())
    return std::nullopt;
  // Process any remaining entries in the source shape. They all need to be
  // 1 or dynamic.
  for (; sourceDim < sourceShape.size(); sourceDim++) {
    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
        sourceShape[sourceDim] != 1)
      return std::nullopt;
    // The map is empty when the target type is a scalar.
    if (!reassociationMap.empty())
      reassociationMap.back().push_back(sourceDim);
  }
  return reassociationMap;
}
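// Worked example for getReassociationIndicesForCollapse (illustrative only,
// using fully static shapes):
//   sourceShape = [2, 3, 4, 5], targetShape = [6, 20]
//     -> [[0, 1], [2, 3]]      (2 * 3 == 6 and 4 * 5 == 20)
//   sourceShape = [2, 3, 1, 1], targetShape = [6]
//     -> [[0, 1, 2, 3]]        (the trailing unit dims join the last group)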

std::optional<SmallVector<ReassociationIndices>>
mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of the same size,
  // the resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return std::nullopt;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(), 0,
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  if (producerReassociations.size() != consumerDims)
    return std::nullopt;

  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      llvm::append_range(reassociations, producerReassociations[consumerIndex]);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}
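// Worked example for composeReassociationIndices (illustrative only): a
// rank-4 -> rank-3 collapse followed by a rank-3 -> rank-2 collapse.
//   producer = [[0], [1], [2, 3]], consumer = [[0, 1], [2]]
// The consumer group [0, 1] gathers producer groups 0 and 1, and [2] gathers
// producer group 2, so the composition is [[0, 1], [2, 3]].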

SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
  for (const auto &indices : reassociationIndices) {
    SmallVector<AffineExpr, 2> reassociationMap;
    reassociationMap.reserve(indices.size());
    for (int64_t index : indices)
      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
    reassociationMaps.push_back(std::move(reassociationMap));
  }
  return reassociationMaps;
}
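// For example (illustrative only), the reassociation indices [[0, 1], [2]]
// become the affine expression groups [[d0, d1], [d2]].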

template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned pos = 0;
  for (const auto &exprs : exprArrays) {
    for (auto expr : exprs) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = dyn_cast<AffineExprTy>(e))
          pos = std::max(pos, d.getPosition());
      });
    }
  }
  return pos;
}

ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr =
      llvm::to_vector<4>(llvm::map_range(
          reassociation, [&](const ReassociationIndices &indices) -> Attribute {
            return cast<Attribute>(b.getI64ArrayAttr(indices));
          }));
  return b.getArrayAttr(reassociationAttr);
}

SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(cast<AffineDimExpr>(expr).getPosition());
    reassociationIndices.push_back(indices);
  }
  return reassociationIndices;
}

SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}
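// For example (illustrative only), the expression groups [[d0, d1], [d2]]
// produce the maps (d0, d1, d2) -> (d0, d1) and (d0, d1, d2) -> (d2), i.e.
// every map uses the same total dimension count maxDim + 1.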

bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;
  unsigned nDims = reassociation[0].getNumDims();
  unsigned nextExpectedDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    auto m = it.value();
    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
      if (invalidIndex)
        *invalidIndex = it.index();
      return false;
    }
    for (auto e : m.getResults()) {
      auto d = dyn_cast<AffineDimExpr>(e);
      if (!d || d.getPosition() != nextExpectedDim++) {
        if (invalidIndex)
          *invalidIndex = it.index();
        return false;
      }
    }
  }
  if (nextExpectedDim != nDims) {
    if (invalidIndex)
      *invalidIndex = reassociation.size() - 1;
    return false;
  }
  return true;
}
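// Worked example for isReassociationValid (illustrative only), with nDims = 3:
//   [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)]  is valid: the results
//   cover d0, d1, d2 exactly once and in order.
//   [(d0, d1, d2) -> (d0), (d0, d1, d2) -> (d2)]      is invalid: d1 is
//   skipped, so the second map is reported via invalidIndex = 1.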

LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned expandedDimStart = 0;
  for (const auto &map : llvm::enumerate(reassociationMaps)) {
    bool foundDynamicShape = false;
    int64_t linearizedStaticShape = 1;

    for (const auto &dim : llvm::enumerate(
             expandedShape.slice(expandedDimStart, map.value().size()))) {
      if (ShapedType::isDynamic(dim.value()))
        foundDynamicShape = true;
      else
        linearizedStaticShape *= dim.value();
    }
    if (foundDynamicShape) {
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    } else {
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return emitError("expected dimension " + Twine(map.index()) +
                         " of collapsed type to be static value of " +
                         Twine(linearizedStaticShape));
      }
    }
    expandedDimStart += map.value().size();
  }
  return success();
}
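// Worked example for reshapeLikeShapesAreCompatible (illustrative only), with
// reassociation [[0, 1], [2, 3]]:
//   collapsedShape = [6, 20], expandedShape = [2, 3, 4, 5]  -> success
//   collapsedShape = [6, 20], expandedShape = [2, 3, 4, ?]  -> error, since
//   group 1 contains a dynamic dim but collapsed dim 1 is static.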

bool mlir::hasNonIdentityLayout(Type type) {
  if (auto memrefType = dyn_cast<MemRefType>(type))
    return !memrefType.getLayout().isIdentity();
  return false;
}

llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                          ArrayRef<Range> sliceParams) {
  assert(sliceParams.size() == sliceInputShape.size() &&
         "only supports non rank-reducing case");
  llvm::SmallBitVector mask(sliceInputShape.size());
  unsigned idx = 0;
  for (const auto &[offset, size, stride] : sliceParams) {
    std::optional<int64_t> offsetConst = getConstantIntValue(offset);
    std::optional<int64_t> strideConst = getConstantIntValue(stride);
    mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
                (!strideConst || *strideConst != 1) ||
                (!offsetConst || *offsetConst != 0);
    idx++;
  }
  return mask;
}
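// For example (illustrative only), with sliceInputShape = [10, 20], slice
// params {0, 10, 1} along dim 0 and {2, 5, 1} along dim 1 yield the mask
// [false, true]: dim 0 is taken in full with offset 0 and stride 1, while
// dim 1 is genuinely sliced.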

llvm::SmallBitVector mlir::getLinearizedDimensions(
    ArrayRef<ReassociationIndices> reassociationIndices) {
  llvm::SmallBitVector result(reassociationIndices.size());
  for (const auto &it : llvm::enumerate(reassociationIndices))
    result[it.index()] = it.value().size() > 1;
  return result;
}

SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
  unsigned loopIdx = 0;
  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> offsetsSizesAndStrides;
  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
  for (const auto &it : llvm::enumerate(reassociationIndices)) {
    // Case 1: Linearized dimensions that have also been sliced. These
    // have size 1 because we are iterating over these dimensions. The
    // offsets are exactly the de-linearized multi-indices.
    if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
            return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
          }));
      continue;
    }

    // Case 2: One or possibly multiple combined input dimensions, but we
    // have proven that these are not sliced. In this case we just take
    // the full extent of each dimension in the reassociation list.
    if (linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
          }));
      continue;
    }

    // Case 3: A single index, but it may be sliced.
    offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
  }
  return offsetsSizesAndStrides;
}
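// Sketch of the three cases above (illustrative only). For reassociation
// [[0, 1], [2]] where group 0 is both linearized and sliced and dim 2 is a
// single (possibly sliced) index, the returned ranges are
//   {multiIndex0, 1, 1}, {multiIndex1, 1, 1}, sliceParams[1]
// i.e. one unit-size range per source dim of group 0, whose offsets are the
// de-linearized indices supplied in multiIndices, followed by the original
// slice parameters for the remaining single-dim group.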

SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
                                              ValueRange tileIndices) {
  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> insertParams;
  insertParams.reserve(linearizedDimensions.size());
  unsigned loopIdx = 0;
  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
    if (linearizedDimensions[i] && slicedDimensions[i]) {
      insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
      continue;
    }
    insertParams.push_back(Range{zero, sliceParams[i].size, one});
  }
  return insertParams;
}
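// For example (illustrative only), with linearizedDimensions = [1, 0] and
// slicedDimensions = [1, 1] the insert parameters are
//   {tileIndices[0], 1, 1}, {0, sliceParams[1].size, 1}
// i.e. the sliced-and-linearized dim is written one tile index at a time,
// while every other dim keeps its sliced size at offset zero.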

/// Returns the index of the only non-unit dimension among `indices` of
/// `shape`, if such a dimension exists and `indices` has more than one
/// element. Otherwise, returns std::nullopt.
static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
                                                  ArrayRef<int64_t> shape) {
  // Return std::nullopt if more than one of the dimensions in this group is
  // not 1.
  std::optional<int64_t> dimIndex;
  if (indices.size() < 2)
    return std::nullopt;
  for (int64_t idx : indices) {
    if (shape[idx] != 1) {
      if (dimIndex != std::nullopt)
        return std::nullopt;
      dimIndex = idx;
    }
  }
  return dimIndex;
}
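// For example (illustrative only), with shape = [1, 4, 1]:
//   indices = [0, 1, 2] -> 1         (dim 1 is the only non-unit dim)
//   indices = [0]       -> nullopt   (fewer than two indices)
// and with shape = [2, 4], indices = [0, 1] -> nullopt (two non-unit dims).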

// For each segment in the reassociation indices, check whether we can
// simplify that segment with a rank-reducing extract slice. We can do this if
// all but (exactly) one of the corresponding source dims is 1.
static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments;
  for (const auto &indices : reassociationIndices)
    trivialSegments.push_back(
        getUniqueNonUnitDim(indices, sourceType.getShape()));
  return trivialSegments;
}

/// Returns the trivial segments of the reassociation indices if at least one
/// segment of the collapsing reshape can be simplified using a rank-reducing
/// slice; otherwise returns failure.
static FailureOr<SmallVector<std::optional<int64_t>>>
canCollapseShapeBeSimplifiedByRankReducingSlice(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments =
      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
        return idx.has_value();
      }))
    return failure();
  return trivialSegments;
}

FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
                                                      reassociationIndices);
  if (failed(trivialSegments))
    return failure();

  // Create the expected result shape of the rank-reducing slice.
  SmallVector<int64_t> sliceShape;
  for (const auto &[nonUnitDim, indices] :
       llvm::zip(*trivialSegments, reassociationIndices)) {
    if (nonUnitDim) {
      sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
      continue;
    }
    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
                         return sourceType.getDimSize(idx);
                       }));
  }
  auto sliceType =
      RankedTensorType::get(sliceShape, sourceType.getElementType());

  // If the rank-reducing slice simplified every segment, then we are done.
  if (sliceShape.size() == reassociationIndices.size())
    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
                                                            std::nullopt};

  // Otherwise, we need to create a new collapse_shape op for the segments that
  // weren't covered by the slice. By design, the new reassociation indices
  // have the same number of groups as the old reassociation indices.
  SmallVector<ReassociationIndices> newReassociationIndices;
  SmallVector<int64_t, 2> reassociation;
  int64_t groupIdx = 0;
  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
    reassociation.push_back(dimIdx);
    if ((*trivialSegments)[groupIdx] ||
        reassociation.size() == reassociationIndices[groupIdx].size()) {
      newReassociationIndices.push_back(reassociation);
      reassociation.clear();
      groupIdx++;
    }
  }

  return CollapseShapeRankReducingSliceSimplificationInfo{
      sliceType, newReassociationIndices};
}
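// Worked example (illustrative only): collapsing tensor<1x5x2x6xf32> with
// reassociation [[0, 1], [2, 3]]. Group 0 has the unique non-unit dim 1 while
// group 1 has two non-unit dims, so trivialSegments = [1, nullopt]. The
// rank-reducing slice type is tensor<5x2x6xf32>, and since only group 0 was
// simplified, the residual collapse uses newReassociationIndices =
// [[0], [1, 2]].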

PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                                             ArrayRef<int64_t> innerDimPos) {
  PackingMetadata res;
  res.insertPositions.reserve(innerDimPos.size());
  // The pack insert position is the position + the number of previously
  // inserted positions + offset.
  // The offset controls whether the packing dimension is the first or last.
  //
  // Example
  // =======
  // Consider packing from a hypothetical ABCD layout to ABCDba whose
  // pack.inner_dims is [1, 0]. The first step consists of undoing the
  // permutation and producing AaBbCD. This is achieved purely by computing the
  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
  // possibility is to produce insert positions [2, 0], which would result in
  // an aAbBCD layout (i.e. offset 0). The other possibility is to produce
  // insert positions [3, 1], which would result in an AaBbCD layout (i.e.
  // offset 1). The latter is what we expect from packing.
  int64_t offset = 1;
  for (int64_t pos : innerDimPos) {
    int64_t numInsertedBefore = llvm::count_if(
        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
    res.insertPositions.push_back(pos + numInsertedBefore + offset);
  }

  DenseSet<int64_t> posSet(res.insertPositions.begin(),
                           res.insertPositions.end());
  res.reassociations.reserve(packedRank);
  for (int64_t i = 1; i <= packedRank; ++i) {
    res.outerPositions.push_back(i - 1);
    if (!posSet.contains(i)) {
      res.reassociations.push_back(ReassociationIndices{i - 1});
      continue;
    }
    res.reassociations.push_back(ReassociationIndices{i - 1, i});
    ++i;
  }
  return res;
}
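// Numeric trace of the ABCDba example above (illustrative only): with
// packedRank = 6 and innerDimPos = [1, 0], the insert positions come out as
// [3, 1], so posSet = {1, 3} and the loop produces
//   outerPositions = [0, 2, 4, 5]
//   reassociations = [[0, 1], [2, 3], [4], [5]]
// i.e. the two packed dims are folded back into their outer counterparts.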
486