xref: /llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h (revision c8b5d30f707757a4fe4d9d0bb01f762665f6942f)
1 //===- IndexingUtils.h - Helpers related to index computations --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities related to index computations: suffix
10 // products and strides, (de)linearization, permutations and tile offsets.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_UTILS_INDEXINGUTILS_H
15 #define MLIR_DIALECT_UTILS_INDEXINGUTILS_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Support/LLVM.h"
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/iterator.h"
22 #include <optional>
23 #include <utility>
24 
25 namespace mlir {
26 class ArrayAttr;
27 
28 //===----------------------------------------------------------------------===//
29 // Utils that operate on static integer values.
30 //===----------------------------------------------------------------------===//
31 
32 /// Given a set of static sizes, return their suffix product.
33 ///
34 /// When applied to slicing, this is the calculation needed to derive the
35 /// strides (i.e. the number of linear indices to skip along the (k-1) most
36 /// minor dimensions to get the next k-slice).
37 ///
38 /// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
39 ///
40 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
41 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`,
42 ///   e.g. `sizes = [2, 3, 4]` yields `[12, 4, 1]`.
43 /// `sizes` elements are asserted to be non-negative.
44 ///
45 /// Return an empty vector if `sizes` is empty.
46 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
47 inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
48   return computeSuffixProduct(sizes); // Strides are the suffix product.
49 }
50 
51 /// Return the vector of products `v1[i] * v2[i]` over llvm::zip_equal(v1, v2).
52 /// NOTE(review): zip_equal implies `v1` and `v2` have equal sizes — confirm.
53 /// Return an empty vector if `v1` and `v2` are empty.
54 SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
55                                            ArrayRef<int64_t> v2);
56 
57 /// Return the sum of all elements of `basis`.
58 int64_t computeSum(ArrayRef<int64_t> basis);
59 
60 /// Return the product of all elements of `basis`.
61 int64_t computeProduct(ArrayRef<int64_t> basis);
62 
63 /// Return the number of elements of basis (i.e. the max linear index).
64 /// This is the product of all `basis` entries, see `computeProduct`.
65 ///
66 /// `basis` elements are asserted to be non-negative.
67 ///
68 /// Return `0` if `basis` is empty.
69 inline int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
70   return computeProduct(basis);
71 }
72 
73 /// Return the linearized index of `offsets` w.r.t. `basis`.
74 /// NOTE(review): presumably `o0 * b0 + .. + on * bn`, as for AffineExpr — confirm.
75 /// `basis` elements are asserted to be non-negative.
76 int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
77 
78 /// Given `strides` together with a `linearIndex` in the dimension space,
79 /// return the vector-space offsets in each dimension for a de-linearized index.
80 /// `strides` elements are asserted to be non-negative.
81 /// NOTE(review): verify the div/mod order below against the implementation.
82 /// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
83 /// SmallVector<int64_t>
84 ///   `[li % s0, (li / s0) % s1, ..., (li / s0 / .. / sn-1) % sn]`
85 SmallVector<int64_t> delinearize(int64_t linearIndex,
86                                  ArrayRef<int64_t> strides);
87 
88 /// Return the multi-dimensional integral ratio of `subShape` to the trailing
89 /// dimensions of `shape`. This represents how many times `subShape` fits
90 /// within `shape`. If integral division is not possible, return std::nullopt.
91 /// The trailing `subShape.size()` entries of both shapes are assumed (and
92 /// enforced) to only contain non-negative values.
93 ///
94 /// Examples:
95 ///   - computeShapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
96 ///   - computeShapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt because
97 ///     `subShape` has a higher rank than `shape`, so it cannot fit in its
98 ///     trailing dimensions.
99 ///   - computeShapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16},
100 ///     derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
101 ///   - computeShapeRatio({42, 2, 11, 32}, {2, 5, 2}) returns std::nullopt,
102 ///     derived as {42(leading shape dim), 2/2, 11/5(not divisible), 32/2}.
103 std::optional<SmallVector<int64_t>>
104 computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape);
105 
106 //===----------------------------------------------------------------------===//
107 // Utils that operate on AffineExpr.
108 //===----------------------------------------------------------------------===//
109 
110 /// Given a set of sizes expressed as AffineExpr, return the suffix product.
111 ///
112 /// When applied to slicing, this is the calculation needed to derive the
113 /// strides (i.e. the number of linear indices to skip along the (k-1) most
114 /// minor dimensions to get the next k-slice).
115 ///
116 /// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
117 ///
118 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<AffineExpr>
119 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
120 ///
121 /// It is the caller's responsibility to pass AffineExpr kinds that result
122 /// in a valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
123 /// by an AffineDimExpr).
124 ///
125 /// `sizes` elements are expected to bind to non-negative values.
126 ///
127 /// Return an empty vector if `sizes` is empty.
128 SmallVector<AffineExpr> computeSuffixProduct(ArrayRef<AffineExpr> sizes);
129 inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
130   return computeSuffixProduct(sizes); // Strides are the suffix product.
131 }
132 
133 /// Return the vector of products `v1[i] * v2[i]` over llvm::zip_equal(v1, v2).
134 /// NOTE(review): zip_equal implies `v1` and `v2` have equal sizes — confirm.
135 /// It is the caller's responsibility to pass AffineExpr kinds that result
136 /// in a valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
137 /// by an AffineDimExpr).
138 ///
139 /// Return an empty vector if `v1` and `v2` are empty.
140 SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
141                                               ArrayRef<AffineExpr> v2);
142 
143 /// Return the sum of all `basis` expressions as a single AffineExpr.
144 AffineExpr computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
145 
146 /// Return the product of all `basis` expressions as a single AffineExpr.
147 AffineExpr computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
148 
149 /// Return the number of elements of basis (i.e. the max linear index).
150 /// This is the product of all `basis` entries, see `computeProduct`.
151 ///
152 /// It is the caller's responsibility to pass AffineExpr kinds that result
153 /// in a valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide
154 /// by an AffineDimExpr).
155 ///
156 /// `basis` elements are expected to bind to non-negative values.
157 ///
158 /// Return the `0` AffineConstantExpr if `basis` is empty.
159 inline AffineExpr computeMaxLinearIndex(MLIRContext *ctx,
160                                         ArrayRef<AffineExpr> basis) {
161   return computeProduct(ctx, basis);
162 }
163 
164 /// Return the linearized index of `offsets` w.r.t. `basis`.
165 /// `basis` may be given either as AffineExpr or as static int64_t values.
166 /// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the
167 /// AffineExpr `o0 * b0 + .. + on * bn`.
168 ///
169 /// It is the caller's responsibility to pass AffineExpr kinds that result
170 /// in a valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
171 /// AffineDimExpr).
172 ///
173 /// `basis` elements are expected to bind to non-negative values.
174 AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
175                      ArrayRef<AffineExpr> basis);
176 AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
177                      ArrayRef<int64_t> basis);
178 
179 /// Given `strides` together with a `linearIndex` in the dimension space,
180 /// return the vector-space offsets in each dimension for a de-linearized index.
181 /// `strides` may be given either as AffineExpr or as static int64_t values.
182 /// Let `li = linearIndex`, assuming `strides` are `[s0, .. sn]`, return the
183 /// vector of AffineExpr
184 ///   `[li % s0, (li / s0) % s1, ..., (li / s0 / .. / sn-1) % sn]`
185 /// NOTE(review): verify the div/mod order above against the implementation.
186 /// It is the caller's responsibility to pass AffineExpr kinds that result
187 /// in a valid AffineExpr (i.e. cannot multiply 2 AffineDimExpr or divide by an
188 /// AffineDimExpr).
189 ///
190 /// `strides` elements are expected to bind to non-negative values.
191 SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
192                                     ArrayRef<AffineExpr> strides);
193 SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
194                                     ArrayRef<int64_t> strides);
195 
196 //===----------------------------------------------------------------------===//
197 // Permutation utils.
198 //===----------------------------------------------------------------------===//
199 
200 template <typename T>
201 SmallVector<T> applyPermutation(ArrayRef<T> input,
202                                 ArrayRef<int64_t> permutation) {
203   assert(input.size() == permutation.size() &&
204          "expected input rank to equal permutation rank");
205   assert(
206       llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
207       "permutation must be within input bounds");
208   auto permutationRange = llvm::map_range(
209       llvm::seq<unsigned>(0, input.size()),
210       [&](int64_t idx) -> T { return input[permutation[idx]]; });
211   return llvm::to_vector(permutationRange);
212 }
213 
214 template <typename T>
215 SmallVector<T> applyPermutation(const SmallVectorImpl<T> &input,
216                                 ArrayRef<int64_t> permutation) {
217   return applyPermutation(ArrayRef(input), permutation);
218 }
219 
220 /// Apply the permutation defined by `permutation` to `inVec`.
221 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
222 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
223 /// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
224 /// 'b']`.
225 template <typename T, unsigned N>
226 void applyPermutationToVector(SmallVector<T, N> &inVec,
227                               ArrayRef<int64_t> permutation) {
228   inVec = applyPermutation(inVec, permutation);
229 }
230 
231 /// Return the inverse of the permutation vector `permutation`.
232 SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
233 
234 /// Return true if `permutation` is the identity permutation.
235 bool isIdentityPermutation(ArrayRef<int64_t> permutation);
236 
237 /// Return true if the `interchange` vector is a permutation.
238 bool isPermutationVector(ArrayRef<int64_t> interchange);
239 
240 /// Return a permutation vector of size `permSize` that would result in moving
241 /// each `positions[i]` into `desiredPositions[i]`.
242 ///
243 /// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
244 /// would result in a {4, 2, 0, 1, 3} permutation vector.
245 SmallVector<int64_t>
246 computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
247                          ArrayRef<int64_t> desiredPositions);
248 
249 /// Return a permutation vector that drops the input dims listed in
250 /// `dropPositions` from `inputPerm`.
251 ///
252 /// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
253 /// result in a {2, 0, 1} permutation vector.
254 SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
255                               ArrayRef<int64_t> dropPositions);
256 
257 /// Helper to return a subset of `arrayAttr` as a vector of int64_t; presumably the first `dropFront` and last `dropBack` elements are skipped (TODO confirm).
258 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
259 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
260                                     unsigned dropBack = 0);
261 
262 /// Compute linear index from provided strides and indices, assuming strided
263 /// layout. `strides` may be OpFoldResults or static int64_t values.
264 /// Returns an AffineExpr and the list of values to apply it to, e.g.:
265 ///
266 /// auto &&[expr, values] = computeLinearIndex(...);
267 /// offset = affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
268 std::pair<AffineExpr, SmallVector<OpFoldResult>>
269 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
270                    ArrayRef<OpFoldResult> indices);
271 std::pair<AffineExpr, SmallVector<OpFoldResult>>
272 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
273                    ArrayRef<Value> indices);
274 
275 //===----------------------------------------------------------------------===//
276 // Utilities for decomposing larger shapes
277 //===----------------------------------------------------------------------===//
278 
279 namespace detail {
280 /// Encapsulates the set of parameters that are used to make tile offset
281 /// calculations in the TileOffsetRangeIterator.
282 class TileOffsetRangeImpl {
283 public:
284   TileOffsetRangeImpl(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
285                       ArrayRef<int64_t> loopOrder);
286 
287   int64_t getMaxLinearIndex() const { return maxLinearIndex; }
288 
289   SmallVector<int64_t> getStaticTileOffsets(int64_t linearIndex) const;
290 
291   SmallVector<AffineExpr> getDynamicTileOffsets(AffineExpr linearIndex) const;
292 
293   template <typename T>
294   SmallVector<T> getTileOffsets(T linearIndex) const {
295     if constexpr (std::is_same_v<T, int64_t>)
296       return getStaticTileOffsets(linearIndex);
297     else
298       return getDynamicTileOffsets(linearIndex);
299   }
300 
301   size_t getRank() const { return tileShape.size(); }
302 
303 private:
304   /// The sub-shape that divides the larger outer shape (which is provided to
305   /// the constructor).
306   SmallVector<int64_t> tileShape;
307   /// The inverse permutation to the `loopOrder` permutation provided in the
308   /// constructor.
309   SmallVector<int64_t> inverseLoopOrder;
310   /// The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`.
311   SmallVector<int64_t> sliceStrides;
312   /// The maximum linear index in the iteration space given by basis 'div(shape,
313   /// tileShape)'.
314   int64_t maxLinearIndex;
315 };
316 
317 /// The STL-style iterator implementation for StaticTileOffsetRange.
318 template <typename ElementType>
319 class TileOffsetRangeIterator
320     : public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>,
321                                         std::forward_iterator_tag,
322                                         SmallVector<ElementType>> {
323 public:
324   TileOffsetRangeIterator(const TileOffsetRangeImpl &params, ElementType index)
325       : params(params), index(index) {}
326 
327   void operator++() { incrementIndex(1); }
328   TileOffsetRangeIterator operator++(int) {
329     const auto copy = *this;
330     ++*this;
331     return copy;
332   }
333 
334   bool operator==(const TileOffsetRangeIterator &other) const {
335     return index == other.index;
336   }
337   bool operator!=(const TileOffsetRangeIterator &other) const {
338     return index != other.index;
339   }
340 
341   SmallVector<ElementType> operator*() const {
342     return params.getTileOffsets(index);
343   }
344   void operator+=(int64_t offset) { incrementIndex(offset); }
345 
346 private:
347   void incrementIndex(int64_t offset) { index = index + offset; }
348   const TileOffsetRangeImpl params;
349   int64_t index;
350 };
351 } // namespace detail
352 
353 /// A range-style iterator that allows for iterating over the offsets of all
354 /// potential tiles of size `tileShape` within the larger shape `shape`, using
355 /// an ordering specified by `loopOrder`. The `loopOrder` specifies the order of
356 /// unrolling by numbering the dimensions in order from "outer most for loop"
357 /// (slowest changing) to "inner most for loop" (fastest changing).
358 ///
359 /// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and
360 /// `loopOrder={2, 0, 1}`, the iterating over this range will yield offsets:
361 ///
362 /// ```
363 /// {0, 0,  0}, {0, 10,  0}, {5, 0,  0}, {5, 10,  0}, {0, 0, 15},
364 /// {0, 10, 15}, {5, 0, 15}, {0, 10, 15}, {5, 10, 15}
365 /// ```
366 ///
367 /// This is useful in contexts where a vector computation over a larger shape
368 /// needs to be unrolled to a set of operations on subsets of the original
369 /// operands, such as during the "vector unrolling" transformations.
370 ///
371 /// The size of `tileShape` must be less-than-or-equal-to the size of `shape`.a
372 /// If the rank of `tileShape` is smaller than `shape`, then `tileShape`
373 /// elements correspond to the trailing dimensions of `shape`, and the leading
374 /// dimensions are considered untiled and `tileShape` is effectively prepended
375 /// with the leading dims of `shape`.
376 class StaticTileOffsetRange {
377 public:
378   using IteratorTy = detail::TileOffsetRangeIterator<int64_t>;
379   using ParamsTy = detail::TileOffsetRangeImpl;
380 
381   StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
382                         ArrayRef<int64_t> loopOrder)
383       : params(shape, tileShape, loopOrder), beginValue(params, 0),
384         pastEndValue(params, params.getMaxLinearIndex()) {
385     assert(shape.size() >= tileShape.size());
386     assert(loopOrder.size() == shape.size());
387   }
388 
389   /// Create the range with identity loop order.
390   StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape)
391       : params(shape, tileShape,
392                llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))),
393         beginValue(params, 0),
394         pastEndValue(params, params.getMaxLinearIndex()) {
395     assert(shape.size() >= tileShape.size());
396   }
397 
398   IteratorTy begin() const { return beginValue; }
399   IteratorTy end() const { return pastEndValue; }
400 
401   /// Returns the total number of tiles that fit in the larger shape.
402   size_t size() const { return params.getMaxLinearIndex(); }
403 
404   /// Returns rank of the iterator's shape.
405   size_t getRank() const { return params.getRank(); }
406 
407 private:
408   const ParamsTy params;
409   IteratorTy beginValue;
410   IteratorTy pastEndValue;
411 };
412 } // namespace mlir
413 
414 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
415