//===- VectorRewritePatterns.h - Vector rewrite patterns -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H

#include <limits>
#include <optional>
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
class NarrowTypeEmulationConverter;
class TruncIOp;
} // namespace arith

namespace vector {
struct VectorTransformsOptions;

/// Options that control the vector unrolling.
struct UnrollVectorOptions {
  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
  /// Callback function that indicates whether vector unrolling should be
  /// attempted on the operation.
  FilterConstraintFnType filterConstraint = nullptr;
  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
    filterConstraint = std::move(constraint);
    return *this;
  }

  using NativeShapeFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  /// Function that returns the shape of the vector to unroll to for a given
  /// operation. The unrolling is aborted if the function returns
  /// `std::nullopt`.
  NativeShapeFnType nativeShape = nullptr;
  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
    nativeShape = std::move(fn);
    return *this;
  }

  /// Set the native shape to use for unrolling.
  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
    SmallVector<int64_t> tsShape(shape);
    nativeShape = [=](Operation *) -> std::optional<SmallVector<int64_t>> {
      return tsShape;
    };
    return *this;
  }

  /// Function that returns the traversal order (in terms of "for loop order",
  /// i.e. slowest varying dimension to fastest varying dimension) that should
  /// be used when unrolling the given operation into units of the native
  /// vector size.
  using UnrollTraversalOrderFnType =
      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
  UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
  UnrollVectorOptions &
  setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
    traversalOrderCallback = std::move(traversalOrderFn);
    return *this;
  }
};

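// A minimal configuration sketch for `UnrollVectorOptions` (illustrative only;
// it assumes a `RewritePatternSet patterns` is already in scope and that only
// `vector.contract` ops should be unrolled, to a 4x4x4 native shape):
// ```
// UnrollVectorOptions options;
// options.setNativeShape({4, 4, 4})
//     .setFilterConstraint([](Operation *op) {
//       return success(isa<vector::ContractionOp>(op));
//     });
// populateVectorUnrollPatterns(patterns, options);
// ```
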
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction with MMT semantics (matrix matrix multiplication
/// with the RHS transposed). This specific form is meant to have the vector
/// operands organized such that the reduction dimension is contiguous.
/// Example:
/// ```
///   vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                                     affine_map<(m, n, k) -> (n, k)>,
///                                     affine_map<(m, n, k) -> (m, n)>],
///                    iterator_types = ["parallel", "parallel", "reduction"],
///                    kind = #vector.kind<add>} %a, %b, %c : ...
/// ```
///
/// The `constraint` predicate is used to decide which `vector.contraction` ops
/// to filter out.
void populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint =
        [](vector::ContractionOp) { return success(); },
    PatternBenefit = 1);

/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
///   - VectorTransferFullPartialRewriter
///
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///   %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///   %1:3 = scf.if (%inBounds) {
///     // fast path, direct cast
///     memref.cast %A: memref<A...> to compatibleMemRefType
///     scf.yield %view : compatibleMemRefType, index, index
///   } else {
///     // slow path, not in-bounds vector.transfer or linalg.copy.
///     memref.cast %alloc: memref<B...> to compatibleMemRefType
///     scf.yield %4 : compatibleMemRefType, index, index
///   }
///   %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, large
/// enough to hold one vector.
///
/// Preconditions:
///  1. `xferOp.permutation_map()` must be a minor identity map
///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
///     must be equal. This will be relaxed in the future but requires
///     rank-reducing subviews.
void populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options);

/// Collect a set of patterns to reduce the rank of the operands of vector
/// transfer ops so that they operate on the largest contiguous vector.
/// These patterns are useful when lowering to dialects with 1-D vector types,
/// such as LLVM, and will result in fewer memory reads.
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
///   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                   PatternBenefit benefit = 1);

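// A typical driver sketch for applying the pattern sets declared in this
// header (illustrative only; it assumes the code runs inside a pass over
// `funcOp` and that "mlir/Transforms/GreedyPatternRewriteDriver.h" is
// included):
// ```
// RewritePatternSet patterns(funcOp->getContext());
// populateSinkVectorOpsPatterns(patterns);
// if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//   signalPassFailure();
// ```
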
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
/// Note that these patterns change the order of reduction, which may not
/// always produce bit-identical results on some floating point inputs.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ```
/// is transformed into:
/// ```
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
void populateChainedVectorReductionFoldingPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit = 1);

/// Patterns to break down vector reductions into a series of arith reductions
/// over vector elements. This is intended to simplify code with reductions
/// over small vector types and to avoid more specialized reduction lowering
/// when possible.
///
/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
void populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract = 2,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// [DecomposeDifferentRankInsertStridedSlice]
/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination
/// vectors have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This
/// pattern only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
///      insert the k-D source.
///   2. k-D -> (n-1)-D InsertStridedSlice op
///   3. InsertOp that is the reverse of 1.
///
/// [DecomposeNDExtractStridedSlice]
/// ================================
/// For n-D extract_strided_slice ops, we can rewrite the op to
/// ExtractOp/ExtractElementOp + a lower-rank ExtractStridedSliceOp +
/// InsertOp/InsertElementOp.
void populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate `patterns` with a pattern to break down 1-D extract_strided_slice
/// ops into a chain of Extract ops to extract each element from the source,
/// and then a chain of Insert ops to insert them into the target vector.
///
/// If `controlFn` is not nullptr, the pattern will only be invoked on ops for
/// which `controlFn` returns true. Otherwise it runs on all
/// extract_strided_slice ops.
void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

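// An illustrative `controlFn` sketch: only decompose extract_strided_slice
// ops whose source vector has at most 4 elements (the threshold is an
// assumption made for this example, not a requirement of the pattern):
// ```
// populateVectorExtractStridedSliceToExtractInsertChainPatterns(
//     patterns, [](ExtractStridedSliceOp op) {
//       auto srcType = cast<VectorType>(op.getVector().getType());
//       return srcType.getNumElements() <= 4;
//     });
// ```
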
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
/// based on the native destination vector shape and inserted based on the
/// ratio of the bitwidths.
///
/// This acts as a last resort way to break down vector.bitcast ops to smaller
/// vector sizes. Because this pattern composes until it is bitcasting to a
/// single element of the higher bitwidth, there is an optional control
/// function. If `controlFn` is not nullptr, the pattern only applies to ops
/// for which `controlFn` returns true; otherwise it applies to all bitcast
/// ops.
void populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(BitCastOp)> controlFn = nullptr,
    PatternBenefit benefit = 1);

/// Populate `patterns` with the following patterns.
///
/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
///
/// [ConvertSameRankInsertStridedSliceIntoShuffle]
/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination
/// vectors have the same rank. For each outermost index in the slice:
///   begin    end             stride
///   [offset : offset+size*stride : stride]
///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice (k-1)-D into (n-1)-D
///   3. the destination subvector is inserted back in the proper place
///   4. InsertOp that is the reverse of 1.
///
/// [Convert1DExtractStridedSliceIntoShuffle]
/// =========================================
/// For 1-D extract_strided_slice ops, we can lower the op to a ShuffleOp.
void populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// The `options` structure controls which operations are unrolled and the
/// target shape.
/// `op` is unrolled to the `targetShape` as follows, for each of its operands:
///   1. the unrolled type `unrolledVectorType` and number of unrolled
///      instances `numUnrolledInstances` are computed from the `targetShape`.
///      For now it is assumed the unrolling factors divide the vector sizes.
///   2. ExtractStridedSlice are created to break-up the vector operands.
///   3. the original op is cloned `numUnrolledInstances` times, once for each
///      result.
///   4. InsertStridedSlice are inserted to re-assemble the slices into the
///      original vector shape.
///
/// Example:
///
///   opA(operand0, operand1)  // numUnrolledInstances = 3
///
///           operand0                   operand1
///              |                          |
///            fork                       fork
///     <----------gather all fork ops --------->
///             /|\                        /|\
///         f00 f01 f02                f10 f11 f12
///     <---------- clone op 3 times --------->
///       opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
///              \            |            /
///   <-------------------- join ------------------------->
///
/// Other local patterns then kick in iteratively (including DCE) and compose
/// to combine the ExtractStridedSlice/InsertStridedSlice.
void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                  const UnrollVectorOptions &options,
                                  PatternBenefit benefit = 1);

/// Collect a set of vector.shape_cast folding patterns.
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                      PatternBenefit benefit = 1);

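// An illustrative sketch of composing unrolling with the folding patterns
// above; the per-op native shape callback is an assumption chosen for this
// example:
// ```
// UnrollVectorOptions options;
// options.setNativeShapeFn(
//     [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//       if (isa<vector::TransferReadOp, vector::TransferWriteOp>(op))
//         return SmallVector<int64_t>{1, 4};
//       return std::nullopt; // Do not unroll any other op.
//     });
// populateVectorUnrollPatterns(patterns, options);
// populateShapeCastFoldingPatterns(patterns);
// ```
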
/// Collect a set of leading one dimension removal patterns.
///
/// These patterns insert vector.shape_cast to remove leading one dimensions
/// to expose more canonical forms of read/write/insert/extract operations.
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);

/// Collect a set of unit dimension removal patterns.
///
/// These patterns insert rank-reducing memref.subview ops to remove unit
/// dimensions. With them, there are more chances that we can avoid
/// potentially expensive vector.shape_cast operations.
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns that use vector.shape_cast to help fold unit
/// dims.
///
/// These patterns use vector.shape_cast to remove unit dims from e.g.
/// arithmetic operations on Vectors. The newly inserted shape_casts will
/// either cancel each other out or will be folded away when combined with
/// other patterns.
void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Collect a set of patterns to flatten n-D vector transfers on contiguous
/// memref.
///
/// These patterns insert memref.collapse_shape + vector.shape_cast ops to
/// transform multiple small n-D transfers into a larger 1-D transfer where
/// the memref contiguity properties allow it.
///
/// Flattening is only applied if the bitwidth of the trailing vector dimension
/// is smaller than or equal to `targetVectorBitwidth`.
void populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns,
    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
    PatternBenefit benefit = 1);

/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcasts will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);

/// These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                               bool force32BitVectorIndices,
                                               PatternBenefit benefit = 1);

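// An illustrative sketch of combining the transfer flattening and mask
// materialization patterns above (the 128-bit threshold and the choice of
// 32-bit vector indices are assumptions made for this example):
// ```
// populateFlattenVectorTransferPatterns(patterns,
//                                       /*targetVectorBitwidth=*/128);
// populateVectorMaskMaterializationPatterns(patterns,
//                                           /*force32BitVectorIndices=*/true);
// ```
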
/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
void populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                        vector::BitCastOp bitCastOp,
                                        arith::TruncIOp truncOp,
                                        vector::BroadcastOp maybeBroadcastOp);

/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
                                     vector::BitCastOp bitCastOp,
                                     vector::BroadcastOp maybeBroadcastOp);

/// Appends patterns for rewriting vector operations over narrow types with
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
/// provided ConversionTarget with the appropriate legality configuration so
/// that the relevant ops get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);

/// Populates patterns for linearizing ND (N >= 2) vector operations to 1-D
/// vector shuffle operations.
void populateVectorLinearizeShuffleLikeOpsPatterns(
    const TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, unsigned targetBitWidth);

} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H