//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {

// Forward declarations.
class AffineMap;
class Block;
class Location;
class OpBuilder;
class Operation;
class ShapedType;
class Value;
class VectorType;
class VectorTransferOpInterface;

namespace affine {
class AffineApplyOp;
class AffineForOp;
} // namespace affine

namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
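///
/// A minimal usage sketch (assuming an OpBuilder `b`, a Location `loc` and a
/// memref- or tensor-typed Value `src` are available in the calling code):
/// ```c++
/// // Size of dim 0 of `src`; folds to a constant when the size is static.
/// Value dimSize = vector::createOrFoldDimOp(b, loc, src, /*dim=*/0);
/// ```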
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Returns the two dims that are greater than one, provided the transposition
/// is applied on a 2D slice. Otherwise, returns failure.
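///
/// A minimal usage sketch (assuming `transposeOp` is a vector::TransposeOp
/// obtained by the caller):
/// ```c++
/// if (auto dims = vector::isTranspose2DSlice(transposeOp); succeeded(dims)) {
///   // `dims->first` and `dims->second` are the two non-unit dims.
/// }
/// ```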
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
/// checked (the other dims are not relevant). Note that for `vectorType` to be
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
/// to be contiguous - this is checked by looking at the corresponding strides.
///
/// There may be restrictions on the leading dims of `vectorType`:
///
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
///         of `memrefType` then the leading dim of `vectorType` can be
///         arbitrary.
///
///        Ex. 1.1 contiguous slice, perfect match
///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
/// Case 2. If an "internal" dim of `vectorType` does not match the
///         corresponding trailing dim in `memrefType` then the remaining
///         leading dims of `vectorType` have to be 1 (the first non-matching
///         dim can be arbitrary).
///
///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.3 contiguous slice, 2 != 3 and the leading dims == <1x1>
///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
///        Ex. 2.4 non-contiguous slice, 2 != 3 and the leading dims != <1x1>
///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
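///
/// A minimal usage sketch (assuming `memrefType` and `vectorType` were taken
/// from the Op under inspection, e.g. a vector.transfer_read):
/// ```c++
/// // E.g. vector<1x2x2xi32> from memref<5x4x3x2xi32> -> true (Ex. 2.2).
/// if (vector::isContiguousSlice(memrefType, vectorType)) {
///   // Safe to treat the trailing dims as one flat, contiguous run.
/// }
/// ```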
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);

/// Returns an iterator for all positions in the leading dimensions of `vType`
/// up to the `targetRank`. If any leading dimension before the `targetRank` is
/// scalable (so cannot be unrolled), it will return an iterator for positions
/// up to the first scalable dimension.
///
/// If no leading dimensions can be unrolled, an empty optional is returned.
///
/// Examples:
///
///   For vType = vector<2x3x4> and targetRank = 1
///
///   The resulting iterator will yield:
///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
///
///   For vType = vector<3x[4]x5> and targetRank = 0
///
///   The scalable dimension blocks unrolling so the iterator yields only:
///     [0], [1], [2]
///
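/// A minimal usage sketch (assuming `vType` and `targetRank` come from the
/// value being unrolled):
/// ```c++
/// if (auto unrollIterator = vector::createUnrollIterator(vType, targetRank)) {
///   for (auto position : *unrollIterator) {
///     // `position` holds the leading-dim offsets of one unrolled element,
///     // e.g. usable with vector.extract/vector.insert.
///   }
/// }
/// ```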
std::optional<StaticTileOffsetRange>
createUnrollIterator(VectorType vType, int64_t targetRank = 1);

/// Returns a functor (int64_t -> Value) which returns a constant vscale
/// multiple.
///
/// Example:
/// ```c++
/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
/// ```
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
  Value vscale = nullptr;
  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
    if (!vscale)
      vscale = rewriter.create<vector::VectorScaleOp>(loc);
    return rewriter.create<arith::MulIOp>(
        loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
  };
}

/// Returns a range over the dims (size and scalability) of a VectorType.
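///
/// A minimal usage sketch (assuming `vType` is a VectorType in scope):
/// ```c++
/// for (auto [size, scalable] : vector::getDims(vType)) {
///   // For vector<2x[4]xf32>: (2, false), then (4, true).
/// }
/// ```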
inline auto getDims(VectorType vType) {
  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
}

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
/// Tensor and MemRef types implement their own, very similar version of
/// getMixedSizes. This method will call the appropriate version (depending on
/// `hasTensorSemantics`). It will also automatically extract the operand on
/// which to call it (source for "read" and destination for "write" ops).
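///
/// A minimal usage sketch (assuming `xferOp` is a vector.transfer_read or
/// vector.transfer_write on a tensor, and `rewriter` is in scope):
/// ```c++
/// SmallVector<OpFoldResult> sizes =
///     vector::getMixedSizesXfer(/*hasTensorSemantics=*/true, xferOp,
///                               rewriter);
/// ```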
SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
                                            Operation *xfer,
                                            RewriterBase &rewriter);

/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
/// masked (i.e. inside a `vector.mask` Op region). In particular:
///   1. Matches `SourceOp` operation, Op.
///   2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
///     insertion point to avoid inserting new ops into the `vector.mask` Op
///     region (which only allows one Op).
///   2.2. If Op is not masked, this step is skipped.
///   3. Invokes `matchAndRewriteMaskableOp` on Op and, optionally, on maskOp
///     if one was found in step 2.1.
///
/// This wrapper frees patterns from re-implementing the logic to update the
/// insertion point when a maskable Op is masked. Such patterns are still
/// responsible for providing an updated ("rewritten") version of:
///   a. the source Op when the mask _is not_ present,
///   b. the source Op and the masking Op when the mask _is_ present.
/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
/// the return value will depend on the case above.
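///
/// A minimal sketch of a derived pattern (`MyOpLowering` and its body are
/// hypothetical; only the override structure is prescribed):
/// ```c++
/// struct MyOpLowering : MaskableOpRewritePattern<vector::ContractionOp> {
///   using MaskableOpRewritePattern::MaskableOpRewritePattern;
///
///   FailureOr<Value>
///   matchAndRewriteMaskableOp(vector::ContractionOp op,
///                             MaskingOpInterface maskingOp,
///                             PatternRewriter &rewriter) const override {
///     // Build the replacement, masking it when `maskingOp` is non-null,
///     // and return the replacement value (or failure() to abort).
///     return failure();
///   }
/// };
/// ```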
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(SourceOp sourceOp,
                                PatternRewriter &rewriter) const final {
    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
    if (!maskableOp)
      return failure();

    Operation *rootOp = sourceOp;

    // If this Op is masked, update the insertion point to avoid inserting into
    // the vector.mask Op region.
    OpBuilder::InsertionGuard guard(rewriter);
    MaskingOpInterface maskOp;
    if (maskableOp.isMasked()) {
      maskOp = maskableOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
    }

    FailureOr<Value> newOp =
        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    if (failed(newOp))
      return failure();

    // Rewriting succeeded but there are no values to replace.
    if (rootOp->getNumResults() == 0) {
      rewriter.eraseOp(rootOp);
    } else {
      assert(*newOp != Value() &&
             "Cannot replace an op's use with an empty value.");
      rewriter.replaceOp(rootOp, *newOp);
    }
    return success();
  }

public:
  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
  // latter is present, returns a replacement for `maskingOp`. Otherwise,
  // returns a replacement for `sourceOp`.
  virtual FailureOr<Value>
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};

/// Returns true if the input Vector type can be linearized.
///
/// Linearization is meant in the sense of flattening vectors, e.g.:
///   * vector<NxMxKxi32> -> vector<N*M*Kxi32>
/// In this sense, Vectors that either:
///   * are already linearized, or
///   * contain more than one scalable dimension,
/// are not linearizable.
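///
/// A minimal usage sketch (assuming `vecType` is a VectorType in scope):
/// ```c++
/// // E.g. true for vector<2x3x4xi32>, false for vector<[4]x[4]xf32>
/// // (two scalable dims).
/// bool canLinearize = vector::isLinearizableVector(vecType);
/// ```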
bool isLinearizableVector(VectorType type);

/// Creates a TransferReadOp from `source` with static shape `readShape`. If
/// the vector type for the read does not match the type of `source`, the read
/// is masked, provided that the use of a mask is requested or that the bounds
/// along a dimension differ.
///
/// If `useInBoundsInsteadOfMasking` is true, no mask is created; instead, the
/// `in_bounds` attribute values are set based on the dims of the source and
/// destination tensors, and it is that comparison which determines whether the
/// read stays in bounds.
///
/// Note that the internal `vector::TransferReadOp` always reads at index zero
/// along each dimension of the passed-in tensor.
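///
/// A minimal usage sketch (assuming `builder`, `loc`, a tensor-typed `source`
/// and a scalar `padValue` are in scope):
/// ```c++
/// // Read a 4x8 slice starting at index 0; a mask is generated if the
/// // static sizes of `source` differ from {4, 8}.
/// Value read = vector::createReadOrMaskedRead(
///     builder, loc, source, /*readShape=*/{4, 8}, padValue,
///     /*useInBoundsInsteadOfMasking=*/false);
/// ```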
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                             ArrayRef<int64_t> readShape, Value padValue,
                             bool useInBoundsInsteadOfMasking);

/// Returns success if `inputVectorSizes` is a valid masking configuration for
/// the given `shape`, i.e., it meets:
///   1. The numbers of elements in both arrays are equal.
///   2. `inputVectorSizes` does not have dynamic dimensions.
///   3. All the values in `inputVectorSizes` are greater than or equal to
///      the corresponding static sizes in `shape`.
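///
/// A minimal usage sketch (assuming `shape` and `inputVectorSizes` come from
/// the op being vectorized):
/// ```c++
/// // E.g. shape = {8, ShapedType::kDynamic}, inputVectorSizes = {8, 16}
/// // -> success.
/// if (failed(vector::isValidMaskedInputVector(shape, inputVectorSizes)))
///   return failure();
/// ```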
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> inputVectorSizes);
} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
/// dimension.
///
/// If no index is found to be invariant, 0 is added to the permutation_map and
/// corresponds to a vector broadcast along that dimension.
///
/// The implementation uses the knowledge of the mapping of loops to
/// vector dimension. `loopToVectorDim` carries this information as a map with:
///   - keys representing "vectorized enclosing loops";
///   - values representing the corresponding vector dimension.
/// Note that `loopToVectorDim` is a whole-function map from which only
/// enclosing loop information is extracted.
///
/// Prerequisites: `indices` belong to a vectorizable load or store operation
/// (i.e. at most one invariant index along each AffineForOp of
/// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
/// load or store operation.
///
/// Example 1:
/// The following MLIR snippet:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 {
///          %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
///    }}}
/// ```
///
/// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
///
/// ```mlir
///    affine.for %i3 = 0 to %0 step 32 {
///      affine.for %i4 = 0 to %1 {
///        affine.for %i5 = 0 to %2 step 256 {
///          %4 = vector.transfer_read %arg0, %i4, %i5, %i3
///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
///               (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
///    }}}
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
///
/// Example 2:
/// The following MLIR snippet:
///
/// ```mlir
///    %cst0 = arith.constant 0 : index
///    affine.for %i0 = 0 to %0 {
///      %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
///    }
/// ```
///
/// may vectorize with {permutation_map: (d0) -> (0)} into:
///
/// ```mlir
///    affine.for %i0 = 0 to %0 step 128 {
///      %3 = vector.transfer_read %arg0, %c0_0, %c0_0
///           {permutation_map: (d0, d1) -> (0)} :
///           (memref<?x?xf32>, index, index) -> vector<128xf32>
///    }
/// ```
///
/// Meaning that vector.transfer_read will be responsible for reading the slice
/// `%arg0[%c0, %c0]` into vector<128xf32>, which needs a 1-D vector broadcast.
///
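/// A minimal usage sketch (assuming `loadOp`, its `indices` and the
/// `loopToVectorDim` map are supplied by the surrounding vectorizer):
/// ```c++
/// AffineMap permutationMap =
///     makePermutationMap(loadOp, indices, loopToVectorDim);
/// ```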
AffineMap
makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);
AffineMap
makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
                   const DenseMap<Operation *, unsigned> &loopToVectorDim);

namespace matcher {

/// Matches vector.transfer_read, vector.transfer_write and ops that return a
/// vector type that is a multiple of the sub-vector type. This allows passing
/// over other smaller vector types in the function and avoids interfering with
/// operations on those.
/// This is a first approximation; it can easily be extended in the future.
/// TODO: this could all be much simpler if we added a bit to vector types to
/// mark that a vector is a strict super-vector, but it still does not warrant
/// adding even 1 extra bit in the IR for now.
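///
/// A minimal usage sketch (assuming `op` is the Operation being inspected and
/// `subVectorType` the vector type being vectorized towards):
/// ```c++
/// if (matcher::operatesOnSuperVectorsOf(op, subVectorType)) {
///   // `op` operates on a strict super-vector of `subVectorType`.
/// }
/// ```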
bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);

} // namespace matcher
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_