//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

namespace mlir {
class RewritePatternSet;

namespace vector {

//===----------------------------------------------------------------------===//
// Lowering pattern populate functions
//===----------------------------------------------------------------------===//

/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
///
/// [ContractionOpLowering]
/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///
/// [ContractionOpToMatmulOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.shape_cast` + `vector.matrix_multiply` on the way to
/// `llvm.matrix.multiply`.
///
/// [ContractionOpToDotLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.reduction` + `vector.insert`.
///
/// [ContractionOpToOuterProductOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
void populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options,
    PatternBenefit benefit = 1, bool disableOuterProductLowering = false);

/// Populate the pattern set with the following patterns:
///
/// [OuterProductOpLowering]
/// Progressively lower a `vector.outerproduct` to linearized
/// `vector.extract` + `vector.fma` + `vector.insert`.
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);

/// Collect a set of patterns to convert vector.multi_reduction ops into a
/// sequence of vector.reduction ops. The patterns comprise:
///
/// [InnerOuterDimReductionConversion]
/// Rewrites vector.multi_reduction such that all reduction dimensions are
/// either innermost or outermost, by adding the proper vector.transpose
/// operations.
///
/// [ReduceMultiDimReductionRank]
/// Once in innermost or outermost reduction form, rewrites n-D
/// vector.multi_reduction into 2-D vector.multi_reduction, by introducing
/// vector.shape_cast ops to collapse + multi-reduce + expand back.
///
/// [TwoDimMultiReductionToElementWise]
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
/// ops. This also leaves an opportunity for tree reduction in the future.
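/// For illustration, a hand-written sketch (op syntax is approximate and may
/// differ across MLIR versions) of an outermost 2-D reduction:
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///       : vector<2x4xf32> to vector<4xf32>
/// unrolls roughly into:
///   %r0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
///   %r1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
///   %s0 = arith.addf %r0, %r1 : vector<4xf32>
///   %0  = arith.addf %s0, %acc : vector<4xf32>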
///
/// [TwoDimMultiReductionToReduction]
/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction
/// dimension, unroll the outer dimension to obtain a sequence of extract +
/// vector.reduction + insert. This can further lower to horizontal reduction
/// ops.
///
/// [OneDimMultiReductionToTwoDim]
/// For cases that reduce to a 1-D vector<k> reduction (and are thus missing
/// either a parallel or a reduction dimension), lift back up to 2-D with a
/// simple vector.shape_cast to vector<1xk> so that the other patterns can
/// kick in, thus fully exiting the vector.multi_reduction abstraction.
void populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [BroadcastOpLowering]
/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D
/// BroadcastOp until dim 1.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [CreateMaskOp]
/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1.
///
/// [ConstantMaskOp]
/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until
/// dim 1.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Collect patterns that lower scalar vector transfer ops to memref loads and
/// stores when beneficial. If `allowMultipleUses` is set to true, the patterns
/// are applied to vector transfer reads with any number of uses. Otherwise,
/// only vector transfer reads with a single use will be lowered.
void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit,
                                                  bool allowMultipleUses);

/// Populate the pattern set with the following patterns:
///
/// [ShapeCastOp2DDownCastRewritePattern]
/// The ShapeCastOp 2-D -> 1-D downcast serves the purpose of flattening 2-D
/// to 1-D vectors progressively.
///
/// [ShapeCastOp2DUpCastRewritePattern]
/// The ShapeCastOp 1-D -> 2-D upcast serves the purpose of unflattening 2-D
/// from 1-D vectors progressively.
///
/// [ShapeCastOpRewritePattern]
/// Reference lowering to fully unrolled sequences of single-element
/// ExtractOp + InsertOp. Note that applying this pattern can almost always be
/// considered a performance bug.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransposeOpLowering]
/// Progressive lowering of TransposeOp to fully unrolled extract/insert ops.
///
/// [TransposeOp2DToShuffleLowering]
/// Rewrite a 2-D TransposeOp as shape_cast + shuffle ops on 1-D vectors.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
                                             VectorTransformsOptions options,
                                             PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [TransferReadToVectorLoadLowering]
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast`.
///
/// [TransferWriteToVectorStoreLowering]
/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store`.
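///
/// For illustration, a sketch (assuming an in-bounds, minor-identity
/// transfer; op syntax is approximate):
///   vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>
/// becomes:
///   vector.store %v, %mem[%i] : memref<8xf32>, vector<4xf32>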
///
/// [VectorLoadToMemrefLoadLowering]
/// Replace a 0-D vector.load with a memref.load + vector.broadcast.
///
/// [VectorStoreToMemrefStoreLowering]
/// Replace a 0-D vector.store with a vector.extractelement + memref.store.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
/// of at most `maxTransferRank` are lowered. This is useful when combined
/// with VectorToSCF, which reduces the rank of vector transfer ops.
void populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns,
    std::optional<unsigned> maxTransferRank = std::nullopt,
    PatternBenefit benefit = 1);

/// Collect a set of transfer read/write lowering patterns that simplify the
/// permutation map (e.g., converting it to a minor identity map) by inserting
/// broadcasts and transposes. More specifically:
///
/// [TransferReadPermutationLowering]
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///   vector.transfer_read ...
///       permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///   %v = vector.transfer_read ...
///       permutation_map: (d0, d1, d2) -> (d1, 0)
///   vector.transpose %v, [1, 0]
///
///   vector.transfer_read ...
///       permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///   %v = vector.transfer_read ...
///       permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///   vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
///
/// [TransferWritePermutationLowering]
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have
/// broadcasts.)
/// Ex:
///   vector.transfer_write %v ...
///       permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///   %tmp = vector.transpose %v, [2, 0, 1]
///   vector.transfer_write %tmp ...
///       permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///   vector.transfer_write %v ...
///       permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///   %tmp = vector.transpose %v, [1, 0]
///   vector.transfer_write %tmp ...
///       permutation_map: (d0, d1, d2, d3) -> (d2, d3)
///
/// [TransferOpReduceRank]
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex:
///   vector.transfer_read ...
///       permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///   %v = vector.transfer_read ...
///       permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///   vector.broadcast %v
void populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [ScanToArithOps]
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
                                        PatternBenefit benefit = 1);
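// All `populate*` entry points in this header follow the same calling
// convention. A sketch of typical use from a pass (`MyLoweringPass` is
// hypothetical; the greedy driver entry point is the standard MLIR one and
// may vary across versions):
//
//   void MyLoweringPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorScanLoweringPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }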
/// Populate the pattern set with the following patterns:
///
/// [StepToArithConstantOp]
/// Convert vector.step op into an arith.constant when the result vector is
/// not scalable.
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
                                        PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [FlattenGather]
/// Flattens `vector.gather` ops of rank 2 or higher by unrolling the
/// outermost dimension.
///
/// [Gather1DToConditionalLoads]
/// Turns 1-D `vector.gather` into a scalarized sequence of `vector.load`s or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);

/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
/// not its nested `MaskableOpInterface`.
void populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns);

/// Populate the pattern set with the following patterns:
///
/// [VectorMaskedLoadOpConverter]
/// Turns vector.maskedload into scf.if + memref.load.
///
/// [VectorMaskedStoreOpConverter]
/// Turns vector.maskedstore into scf.if + memref.store.
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit = 1);

/// Populate the pattern set with the following patterns:
///
/// [UnrollInterleaveOp]
/// A one-shot unrolling of InterleaveOp to (one or more) ExtractOp +
/// InterleaveOp (of `targetRank`) + InsertOp.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
                                              int64_t targetRank = 1,
                                              PatternBenefit benefit = 1);

/// Populate the pattern set with a pattern that lowers 1-D
/// `vector.interleave` ops to `vector.shuffle` ops.
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
                                               PatternBenefit benefit = 1);

/// Populates the pattern set with the following patterns:
///
/// [UnrollBitCastOp]
/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
/// BitCastOp (of `targetRank`) + InsertOp.
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                           int64_t targetRank = 1,
                                           PatternBenefit benefit = 1);

/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H