//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
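///
/// A minimal configuration sketch (illustrative only; the 4x4x4 tile and the
/// contraction-only filter are assumptions, not defaults):
/// ```
/// UnrollVectorOptions options;
/// options
///     .setNativeShapeFn(
///         [](Operation *op) -> std::optional<SmallVector<int64_t>> {
///           // Unroll only contractions, to a hypothetical
///           // (m, n, k) = (4, 4, 4) tile; abort unrolling otherwise.
///           if (isa<vector::ContractionOp>(op))
///             return SmallVector<int64_t>{4, 4, 4};
///           return std::nullopt;
///         })
///     .setFilterConstraint([](Operation *op) {
///       return success(isa<vector::ContractionOp>(op));
///     });
/// ```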
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape);
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native vector
  /// size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
/// vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                   affine_map<(m, n, k) -> (n, k)>,
///                                   affine_map<(m, n, k) -> (m, n)>],
///                  iterator_types = ["parallel", "parallel", "reduction"],
///                  kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
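///
/// A hedged usage sketch (the additive-kind filter below is illustrative, not
/// the default constraint):
/// ```
/// populateVectorContractCanonicalizeMatmulToMMT(
///     patterns, [](vector::ContractionOp op) {
///       // Only canonicalize additive contractions.
///       return success(op.getKind() == vector::CombiningKind::ADD);
///     });
/// ```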
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
///   - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///  must be equal. This will be relaxed in the future but requires
///  rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops so that they operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1-D vector types,
/// such as LLVM, and result in fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
/// Note that these patterns change the reduction order and thus may not
/// produce bit-identical results for some floating-point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and to avoid more specialized reduction lowerings
/// when possible.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op
///   3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
/// ops into a chain of Extract ops to extract each element from the source, and
/// then a chain of Insert ops to insert to the target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true; otherwise it runs on all such ops.
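///
/// A hedged usage sketch (the rank/size filter below is purely illustrative):
/// ```
/// populateVectorExtractStridedSliceToExtractInsertChainPatterns(
///     patterns, [](ExtractStridedSliceOp op) {
///       // Hypothetical criterion: only decompose small 1-D slices.
///       auto resultTy = cast<VectorType>(op.getResult().getType());
///       return resultTy.getRank() == 1 && resultTy.getNumElements() <= 8;
///     });
/// ```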
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the ratio
/// of the bitwidths.
///
/// This acts as a last-resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control function.
/// If `controlFn` is not nullptr, the pattern will only apply to ops where
/// `controlFn` returns true; otherwise it applies to all bitcast ops.
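///
/// A hedged usage sketch (the 32-bit threshold is an example, not a default):
/// ```
/// populateBreakDownVectorBitCastOpPatterns(
///     patterns, [](vector::BitCastOp op) {
///       // Hypothetical criterion: only break down bitcasts that widen the
///       // element type to at least 32 bits.
///       return cast<VectorType>(op.getResult().getType())
///                  .getElementTypeBitWidth() >= 32;
///     });
/// ```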
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
///   begin    end             stride
/// [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D
///   3. the destination subvector is inserted back in the proper place
///   4. InsertOp that is the reverse of 1.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// For such cases, we can lower it to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
///   assumed the unrolling factors divide the vector sizes.
///   2. ExtractStridedSlice ops are created to break up the vector operands.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///   result.
///   4. InsertStridedSlice ops are inserted to re-assemble the slices into the
///   original vector shape.
///
/// Example:
///
///    opA(operand0, operand1)  // numUnrolledInstances = 3
///
///            operand0                   operand1
///               |                          |
///             fork                       fork
///        <----------gather all fork ops --------->
///              /|\                        /|\
///          f00 f01 f02                f10 f11 f12
///        <---------- clone op 3 times --------->
///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///                 \            |            /
///      <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
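///
/// A minimal end-to-end sketch (the 2x2 native shape and the greedy driver,
/// declared in "mlir/Transforms/GreedyPatternRewriteDriver.h", are
/// illustrative assumptions):
/// ```
/// RewritePatternSet patterns(funcOp.getContext());
/// vector::populateVectorUnrollPatterns(
///     patterns, UnrollVectorOptions().setNativeShape({2, 2}));
/// (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
/// ```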
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);

/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

/// Collect a set of patterns that remove leading unit (size-1) dimensions.
///
/// These patterns insert vector.shape_cast ops to remove leading unit
/// dimensions and expose more canonical forms of read/write/insert/extract
/// operations. With them, there are more chances that we can cancel out
/// extract-insert pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of patterns that remove unit (size-1) dimensions.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns that use vector.shape_cast to help fold unit dims.
///
/// These patterns use vector.shape_cast to remove unit dims from e.g.
/// arithmetic operations on Vectors. The newly inserted shape_casts will either
/// cancel each other out or will be folded away when combined with other
/// patterns.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast ops to
/// transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller than or equal to `targetVectorBitwidth`.
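///
/// A hedged usage sketch (the 128-bit threshold is illustrative, not a
/// recommended value):
/// ```
/// populateFlattenVectorTransferPatterns(patterns,
///                                       /*targetVectorBitwidth=*/128);
/// ```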
void populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns,
    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
void populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                        vector::BitCastOp bitCastOp,
                                        arith::TruncIOp truncOp,
                                        vector::BroadcastOp maybeBroadcastOp);

/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                     vector::BitCastOp bitCastOp,
                                     vector::BroadcastOp maybeBroadcastOp);

/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
/// provided ConversionTarget with the appropriate legality configuration for
/// the ops to get converted properly.
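///
/// A hedged usage sketch (assumes this runs inside a pass; the
/// partial-conversion driver call and the 512-bit width are illustrative):
/// ```
/// MLIRContext *ctx = op->getContext();
/// TypeConverter typeConverter;
/// ConversionTarget target(*ctx);
/// RewritePatternSet patterns(ctx);
/// vector::populateVectorLinearizeTypeConversionsAndLegality(
///     typeConverter, patterns, target, /*targetBitWidth=*/512);
/// if (failed(applyPartialConversion(op, target, std::move(patterns))))
///   signalPassFailure();
/// ```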
void populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);

/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H