xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision a4830d14edbb2a21eb35f3d79d1f64bd09db8b1c)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   unsigned numScalableDims = tp.getNumScalableDims();
30   if (tp.getShape().size() == numScalableDims)
31     --numScalableDims;
32   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
33                          numScalableDims);
34 }
35 
36 // Helper to reduce vector type by *all* but one rank at back.
37 static VectorType reducedVectorTypeBack(VectorType tp) {
38   assert((tp.getRank() > 1) && "unlowerable vector type");
39   unsigned numScalableDims = tp.getNumScalableDims();
40   if (numScalableDims > 0)
41     --numScalableDims;
42   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
43                          numScalableDims);
44 }
45 
46 // Helper that picks the proper sequence for inserting.
47 static Value insertOne(ConversionPatternRewriter &rewriter,
48                        LLVMTypeConverter &typeConverter, Location loc,
49                        Value val1, Value val2, Type llvmType, int64_t rank,
50                        int64_t pos) {
51   assert(rank > 0 && "0-D vector corner case should have been handled already");
52   if (rank == 1) {
53     auto idxType = rewriter.getIndexType();
54     auto constant = rewriter.create<LLVM::ConstantOp>(
55         loc, typeConverter.convertType(idxType),
56         rewriter.getIntegerAttr(idxType, pos));
57     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58                                                   constant);
59   }
60   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
61                                               rewriter.getI64ArrayAttr(pos));
62 }
63 
64 // Helper that picks the proper sequence for extracting.
65 static Value extractOne(ConversionPatternRewriter &rewriter,
66                         LLVMTypeConverter &typeConverter, Location loc,
67                         Value val, Type llvmType, int64_t rank, int64_t pos) {
68   if (rank <= 1) {
69     auto idxType = rewriter.getIndexType();
70     auto constant = rewriter.create<LLVM::ConstantOp>(
71         loc, typeConverter.convertType(idxType),
72         rewriter.getIntegerAttr(idxType, pos));
73     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
74                                                    constant);
75   }
76   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
77                                                rewriter.getI64ArrayAttr(pos));
78 }
79 
80 // Helper that returns data layout alignment of a memref.
81 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
82                                  MemRefType memrefType, unsigned &align) {
83   Type elementTy = typeConverter.convertType(memrefType.getElementType());
84   if (!elementTy)
85     return failure();
86 
87   // TODO: this should use the MLIR data layout when it becomes available and
88   // stop depending on translation.
89   llvm::LLVMContext llvmContext;
90   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
91               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
92   return success();
93 }
94 
95 // Add an index vector component to a base pointer. This almost always succeeds
96 // unless the last stride is non-unit or the memory space is not zero.
97 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
98                                     Location loc, Value memref, Value base,
99                                     Value index, MemRefType memRefType,
100                                     VectorType vType, Value &ptrs) {
101   int64_t offset;
102   SmallVector<int64_t, 4> strides;
103   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
104   if (failed(successStrides) || strides.back() != 1 ||
105       memRefType.getMemorySpaceAsInt() != 0)
106     return failure();
107   auto pType = MemRefDescriptor(memref).getElementPtrType();
108   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
109   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
110   return success();
111 }
112 
113 // Casts a strided element pointer to a vector pointer.  The vector pointer
114 // will be in the same address space as the incoming memref type.
115 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
116                          Value ptr, MemRefType memRefType, Type vt) {
117   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
118   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
119 }
120 
121 namespace {
122 
/// Trivial Vector to LLVM conversions
/// vector.vscale is rewritten one-to-one into LLVM::vscale through the generic
/// OneToOneConvertToLLVMPattern.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
126 
127 /// Conversion pattern for a vector.bitcast.
128 class VectorBitCastOpConversion
129     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
130 public:
131   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
132 
133   LogicalResult
134   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
135                   ConversionPatternRewriter &rewriter) const override {
136     // Only 0-D and 1-D vectors can be lowered to LLVM.
137     VectorType resultTy = bitCastOp.getResultVectorType();
138     if (resultTy.getRank() > 1)
139       return failure();
140     Type newResultTy = typeConverter->convertType(resultTy);
141     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
142                                                  adaptor.getOperands()[0]);
143     return success();
144   }
145 };
146 
147 /// Conversion pattern for a vector.matrix_multiply.
148 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
149 class VectorMatmulOpConversion
150     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
151 public:
152   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
153 
154   LogicalResult
155   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
156                   ConversionPatternRewriter &rewriter) const override {
157     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
158         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
159         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
160         matmulOp.lhs_columns(), matmulOp.rhs_columns());
161     return success();
162   }
163 };
164 
165 /// Conversion pattern for a vector.flat_transpose.
166 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
167 class VectorFlatTransposeOpConversion
168     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
169 public:
170   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
171 
172   LogicalResult
173   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
174                   ConversionPatternRewriter &rewriter) const override {
175     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
176         transOp, typeConverter->convertType(transOp.res().getType()),
177         adaptor.matrix(), transOp.rows(), transOp.columns());
178     return success();
179   }
180 };
181 
182 /// Overloaded utility that replaces a vector.load, vector.store,
183 /// vector.maskedload and vector.maskedstore with their respective LLVM
184 /// couterparts.
185 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
186                                  vector::LoadOpAdaptor adaptor,
187                                  VectorType vectorTy, Value ptr, unsigned align,
188                                  ConversionPatternRewriter &rewriter) {
189   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
190 }
191 
192 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
193                                  vector::MaskedLoadOpAdaptor adaptor,
194                                  VectorType vectorTy, Value ptr, unsigned align,
195                                  ConversionPatternRewriter &rewriter) {
196   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
197       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
198 }
199 
200 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
201                                  vector::StoreOpAdaptor adaptor,
202                                  VectorType vectorTy, Value ptr, unsigned align,
203                                  ConversionPatternRewriter &rewriter) {
204   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
205                                              ptr, align);
206 }
207 
208 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
209                                  vector::MaskedStoreOpAdaptor adaptor,
210                                  VectorType vectorTy, Value ptr, unsigned align,
211                                  ConversionPatternRewriter &rewriter) {
212   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
213       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
214 }
215 
216 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
217 /// vector.maskedstore.
218 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
219 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
220 public:
221   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
222 
223   LogicalResult
224   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
225                   typename LoadOrStoreOp::Adaptor adaptor,
226                   ConversionPatternRewriter &rewriter) const override {
227     // Only 1-D vectors can be lowered to LLVM.
228     VectorType vectorTy = loadOrStoreOp.getVectorType();
229     if (vectorTy.getRank() > 1)
230       return failure();
231 
232     auto loc = loadOrStoreOp->getLoc();
233     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
234 
235     // Resolve alignment.
236     unsigned align;
237     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
238       return failure();
239 
240     // Resolve address.
241     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
242                      .template cast<VectorType>();
243     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
244                                                adaptor.indices(), rewriter);
245     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
246 
247     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
248     return success();
249   }
250 };
251 
252 /// Conversion pattern for a vector.gather.
253 class VectorGatherOpConversion
254     : public ConvertOpToLLVMPattern<vector::GatherOp> {
255 public:
256   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
257 
258   LogicalResult
259   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
260                   ConversionPatternRewriter &rewriter) const override {
261     auto loc = gather->getLoc();
262     MemRefType memRefType = gather.getMemRefType();
263 
264     // Resolve alignment.
265     unsigned align;
266     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
267       return failure();
268 
269     // Resolve address.
270     Value ptrs;
271     VectorType vType = gather.getVectorType();
272     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
273                                      adaptor.indices(), rewriter);
274     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
275                               adaptor.index_vec(), memRefType, vType, ptrs)))
276       return failure();
277 
278     // Replace with the gather intrinsic.
279     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
280         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
281         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
282     return success();
283   }
284 };
285 
286 /// Conversion pattern for a vector.scatter.
287 class VectorScatterOpConversion
288     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
289 public:
290   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
291 
292   LogicalResult
293   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
294                   ConversionPatternRewriter &rewriter) const override {
295     auto loc = scatter->getLoc();
296     MemRefType memRefType = scatter.getMemRefType();
297 
298     // Resolve alignment.
299     unsigned align;
300     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
301       return failure();
302 
303     // Resolve address.
304     Value ptrs;
305     VectorType vType = scatter.getVectorType();
306     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
307                                      adaptor.indices(), rewriter);
308     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
309                               adaptor.index_vec(), memRefType, vType, ptrs)))
310       return failure();
311 
312     // Replace with the scatter intrinsic.
313     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
314         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
315         rewriter.getI32IntegerAttr(align));
316     return success();
317   }
318 };
319 
320 /// Conversion pattern for a vector.expandload.
321 class VectorExpandLoadOpConversion
322     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
323 public:
324   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
325 
326   LogicalResult
327   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
328                   ConversionPatternRewriter &rewriter) const override {
329     auto loc = expand->getLoc();
330     MemRefType memRefType = expand.getMemRefType();
331 
332     // Resolve address.
333     auto vtype = typeConverter->convertType(expand.getVectorType());
334     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
335                                      adaptor.indices(), rewriter);
336 
337     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
338         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
339     return success();
340   }
341 };
342 
343 /// Conversion pattern for a vector.compressstore.
344 class VectorCompressStoreOpConversion
345     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
346 public:
347   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
348 
349   LogicalResult
350   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
351                   ConversionPatternRewriter &rewriter) const override {
352     auto loc = compress->getLoc();
353     MemRefType memRefType = compress.getMemRefType();
354 
355     // Resolve address.
356     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
357                                      adaptor.indices(), rewriter);
358 
359     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
360         compress, adaptor.valueToStore(), ptr, adaptor.mask());
361     return success();
362   }
363 };
364 
365 /// Conversion pattern for all vector reductions.
366 class VectorReductionOpConversion
367     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
368 public:
369   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
370                                        bool reassociateFPRed)
371       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
372         reassociateFPReductions(reassociateFPRed) {}
373 
374   LogicalResult
375   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
376                   ConversionPatternRewriter &rewriter) const override {
377     auto kind = reductionOp.kind();
378     Type eltType = reductionOp.dest().getType();
379     Type llvmType = typeConverter->convertType(eltType);
380     Value operand = adaptor.getOperands()[0];
381     if (eltType.isIntOrIndex()) {
382       // Integer reductions: add/mul/min/max/and/or/xor.
383       if (kind == "add")
384         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
385                                                              llvmType, operand);
386       else if (kind == "mul")
387         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
388                                                              llvmType, operand);
389       else if (kind == "minui")
390         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
391             reductionOp, llvmType, operand);
392       else if (kind == "minsi")
393         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
394             reductionOp, llvmType, operand);
395       else if (kind == "maxui")
396         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
397             reductionOp, llvmType, operand);
398       else if (kind == "maxsi")
399         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
400             reductionOp, llvmType, operand);
401       else if (kind == "and")
402         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
403                                                              llvmType, operand);
404       else if (kind == "or")
405         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
406                                                             llvmType, operand);
407       else if (kind == "xor")
408         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
409                                                              llvmType, operand);
410       else
411         return failure();
412       return success();
413     }
414 
415     if (!eltType.isa<FloatType>())
416       return failure();
417 
418     // Floating-point reductions: add/mul/min/max
419     if (kind == "add") {
420       // Optional accumulator (or zero).
421       Value acc = adaptor.getOperands().size() > 1
422                       ? adaptor.getOperands()[1]
423                       : rewriter.create<LLVM::ConstantOp>(
424                             reductionOp->getLoc(), llvmType,
425                             rewriter.getZeroAttr(eltType));
426       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
427           reductionOp, llvmType, acc, operand,
428           rewriter.getBoolAttr(reassociateFPReductions));
429     } else if (kind == "mul") {
430       // Optional accumulator (or one).
431       Value acc = adaptor.getOperands().size() > 1
432                       ? adaptor.getOperands()[1]
433                       : rewriter.create<LLVM::ConstantOp>(
434                             reductionOp->getLoc(), llvmType,
435                             rewriter.getFloatAttr(eltType, 1.0));
436       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
437           reductionOp, llvmType, acc, operand,
438           rewriter.getBoolAttr(reassociateFPReductions));
439     } else if (kind == "minf")
440       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
441       // NaNs/-0.0/+0.0 in the same way.
442       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
443                                                             llvmType, operand);
444     else if (kind == "maxf")
445       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
446       // NaNs/-0.0/+0.0 in the same way.
447       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
448                                                             llvmType, operand);
449     else
450       return failure();
451     return success();
452   }
453 
454 private:
455   const bool reassociateFPReductions;
456 };
457 
458 class VectorShuffleOpConversion
459     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
460 public:
461   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
462 
463   LogicalResult
464   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
465                   ConversionPatternRewriter &rewriter) const override {
466     auto loc = shuffleOp->getLoc();
467     auto v1Type = shuffleOp.getV1VectorType();
468     auto v2Type = shuffleOp.getV2VectorType();
469     auto vectorType = shuffleOp.getVectorType();
470     Type llvmType = typeConverter->convertType(vectorType);
471     auto maskArrayAttr = shuffleOp.mask();
472 
473     // Bail if result type cannot be lowered.
474     if (!llvmType)
475       return failure();
476 
477     // Get rank and dimension sizes.
478     int64_t rank = vectorType.getRank();
479     assert(v1Type.getRank() == rank);
480     assert(v2Type.getRank() == rank);
481     int64_t v1Dim = v1Type.getDimSize(0);
482 
483     // For rank 1, where both operands have *exactly* the same vector type,
484     // there is direct shuffle support in LLVM. Use it!
485     if (rank == 1 && v1Type == v2Type) {
486       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
487           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
488       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
489       return success();
490     }
491 
492     // For all other cases, insert the individual values individually.
493     Type eltType;
494     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
495       eltType = arrayType.getElementType();
496     else
497       eltType = llvmType.cast<VectorType>().getElementType();
498     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
499     int64_t insPos = 0;
500     for (auto en : llvm::enumerate(maskArrayAttr)) {
501       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
502       Value value = adaptor.v1();
503       if (extPos >= v1Dim) {
504         extPos -= v1Dim;
505         value = adaptor.v2();
506       }
507       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
508                                  eltType, rank, extPos);
509       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
510                          llvmType, rank, insPos++);
511     }
512     rewriter.replaceOp(shuffleOp, insert);
513     return success();
514   }
515 };
516 
517 class VectorExtractElementOpConversion
518     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
519 public:
520   using ConvertOpToLLVMPattern<
521       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
522 
523   LogicalResult
524   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
525                   ConversionPatternRewriter &rewriter) const override {
526     auto vectorType = extractEltOp.getVectorType();
527     auto llvmType = typeConverter->convertType(vectorType.getElementType());
528 
529     // Bail if result type cannot be lowered.
530     if (!llvmType)
531       return failure();
532 
533     if (vectorType.getRank() == 0) {
534       Location loc = extractEltOp.getLoc();
535       auto idxType = rewriter.getIndexType();
536       auto zero = rewriter.create<LLVM::ConstantOp>(
537           loc, typeConverter->convertType(idxType),
538           rewriter.getIntegerAttr(idxType, 0));
539       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
540           extractEltOp, llvmType, adaptor.vector(), zero);
541       return success();
542     }
543 
544     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
545         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
546     return success();
547   }
548 };
549 
550 class VectorExtractOpConversion
551     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
552 public:
553   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
554 
555   LogicalResult
556   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
557                   ConversionPatternRewriter &rewriter) const override {
558     auto loc = extractOp->getLoc();
559     auto vectorType = extractOp.getVectorType();
560     auto resultType = extractOp.getResult().getType();
561     auto llvmResultType = typeConverter->convertType(resultType);
562     auto positionArrayAttr = extractOp.position();
563 
564     // Bail if result type cannot be lowered.
565     if (!llvmResultType)
566       return failure();
567 
568     // Extract entire vector. Should be handled by folder, but just to be safe.
569     if (positionArrayAttr.empty()) {
570       rewriter.replaceOp(extractOp, adaptor.vector());
571       return success();
572     }
573 
574     // One-shot extraction of vector from array (only requires extractvalue).
575     if (resultType.isa<VectorType>()) {
576       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
577           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
578       rewriter.replaceOp(extractOp, extracted);
579       return success();
580     }
581 
582     // Potential extraction of 1-D vector from array.
583     auto *context = extractOp->getContext();
584     Value extracted = adaptor.vector();
585     auto positionAttrs = positionArrayAttr.getValue();
586     if (positionAttrs.size() > 1) {
587       auto oneDVectorType = reducedVectorTypeBack(vectorType);
588       auto nMinusOnePositionAttrs =
589           ArrayAttr::get(context, positionAttrs.drop_back());
590       extracted = rewriter.create<LLVM::ExtractValueOp>(
591           loc, typeConverter->convertType(oneDVectorType), extracted,
592           nMinusOnePositionAttrs);
593     }
594 
595     // Remaining extraction of element from 1-D LLVM vector
596     auto position = positionAttrs.back().cast<IntegerAttr>();
597     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
598     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
599     extracted =
600         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
601     rewriter.replaceOp(extractOp, extracted);
602 
603     return success();
604   }
605 };
606 
607 /// Conversion pattern that turns a vector.fma on a 1-D vector
608 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
609 /// This does not match vectors of n >= 2 rank.
610 ///
611 /// Example:
612 /// ```
613 ///  vector.fma %a, %a, %a : vector<8xf32>
614 /// ```
615 /// is converted to:
616 /// ```
617 ///  llvm.intr.fmuladd %va, %va, %va:
618 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
619 ///    -> !llvm."<8 x f32>">
620 /// ```
621 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
622 public:
623   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
624 
625   LogicalResult
626   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
627                   ConversionPatternRewriter &rewriter) const override {
628     VectorType vType = fmaOp.getVectorType();
629     if (vType.getRank() != 1)
630       return failure();
631     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
632                                                  adaptor.rhs(), adaptor.acc());
633     return success();
634   }
635 };
636 
637 class VectorInsertElementOpConversion
638     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
639 public:
640   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
641 
642   LogicalResult
643   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
644                   ConversionPatternRewriter &rewriter) const override {
645     auto vectorType = insertEltOp.getDestVectorType();
646     auto llvmType = typeConverter->convertType(vectorType);
647 
648     // Bail if result type cannot be lowered.
649     if (!llvmType)
650       return failure();
651 
652     if (vectorType.getRank() == 0) {
653       Location loc = insertEltOp.getLoc();
654       auto idxType = rewriter.getIndexType();
655       auto zero = rewriter.create<LLVM::ConstantOp>(
656           loc, typeConverter->convertType(idxType),
657           rewriter.getIntegerAttr(idxType, 0));
658       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
659           insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
660       return success();
661     }
662 
663     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
664         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
665         adaptor.position());
666     return success();
667   }
668 };
669 
670 class VectorInsertOpConversion
671     : public ConvertOpToLLVMPattern<vector::InsertOp> {
672 public:
673   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
674 
675   LogicalResult
676   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
677                   ConversionPatternRewriter &rewriter) const override {
678     auto loc = insertOp->getLoc();
679     auto sourceType = insertOp.getSourceType();
680     auto destVectorType = insertOp.getDestVectorType();
681     auto llvmResultType = typeConverter->convertType(destVectorType);
682     auto positionArrayAttr = insertOp.position();
683 
684     // Bail if result type cannot be lowered.
685     if (!llvmResultType)
686       return failure();
687 
688     // Overwrite entire vector with value. Should be handled by folder, but
689     // just to be safe.
690     if (positionArrayAttr.empty()) {
691       rewriter.replaceOp(insertOp, adaptor.source());
692       return success();
693     }
694 
695     // One-shot insertion of a vector into an array (only requires insertvalue).
696     if (sourceType.isa<VectorType>()) {
697       Value inserted = rewriter.create<LLVM::InsertValueOp>(
698           loc, llvmResultType, adaptor.dest(), adaptor.source(),
699           positionArrayAttr);
700       rewriter.replaceOp(insertOp, inserted);
701       return success();
702     }
703 
704     // Potential extraction of 1-D vector from array.
705     auto *context = insertOp->getContext();
706     Value extracted = adaptor.dest();
707     auto positionAttrs = positionArrayAttr.getValue();
708     auto position = positionAttrs.back().cast<IntegerAttr>();
709     auto oneDVectorType = destVectorType;
710     if (positionAttrs.size() > 1) {
711       oneDVectorType = reducedVectorTypeBack(destVectorType);
712       auto nMinusOnePositionAttrs =
713           ArrayAttr::get(context, positionAttrs.drop_back());
714       extracted = rewriter.create<LLVM::ExtractValueOp>(
715           loc, typeConverter->convertType(oneDVectorType), extracted,
716           nMinusOnePositionAttrs);
717     }
718 
719     // Insertion of an element into a 1-D LLVM vector.
720     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
721     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
722     Value inserted = rewriter.create<LLVM::InsertElementOp>(
723         loc, typeConverter->convertType(oneDVectorType), extracted,
724         adaptor.source(), constant);
725 
726     // Potential insertion of resulting 1-D vector into array.
727     if (positionAttrs.size() > 1) {
728       auto nMinusOnePositionAttrs =
729           ArrayAttr::get(context, positionAttrs.drop_back());
730       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
731                                                       adaptor.dest(), inserted,
732                                                       nMinusOnePositionAttrs);
733     }
734 
735     rewriter.replaceOp(insertOp, inserted);
736     return success();
737   }
738 };
739 
740 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
741 ///
742 /// Example:
743 /// ```
744 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
745 /// ```
746 /// is rewritten into:
747 /// ```
748 ///  %r = splat %f0: vector<2x4xf32>
749 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
750 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
751 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
752 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
753 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
754 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
755 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
756 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
757 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
758 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
759 ///  // %r3 holds the final value.
760 /// ```
761 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
762 public:
763   using OpRewritePattern<FMAOp>::OpRewritePattern;
764 
765   void initialize() {
766     // This pattern recursively unpacks one dimension at a time. The recursion
767     // bounded as the rank is strictly decreasing.
768     setHasBoundedRewriteRecursion();
769   }
770 
771   LogicalResult matchAndRewrite(FMAOp op,
772                                 PatternRewriter &rewriter) const override {
773     auto vType = op.getVectorType();
774     if (vType.getRank() < 2)
775       return failure();
776 
777     auto loc = op.getLoc();
778     auto elemType = vType.getElementType();
779     Value zero = rewriter.create<arith::ConstantOp>(
780         loc, elemType, rewriter.getZeroAttr(elemType));
781     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
782     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
783       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
784       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
785       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
786       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
787       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
788     }
789     rewriter.replaceOp(op, desc);
790     return success();
791   }
792 };
793 
794 /// Returns the strides if the memory underlying `memRefType` has a contiguous
795 /// static layout.
796 static llvm::Optional<SmallVector<int64_t, 4>>
797 computeContiguousStrides(MemRefType memRefType) {
798   int64_t offset;
799   SmallVector<int64_t, 4> strides;
800   if (failed(getStridesAndOffset(memRefType, strides, offset)))
801     return None;
802   if (!strides.empty() && strides.back() != 1)
803     return None;
804   // If no layout or identity layout, this is contiguous by definition.
805   if (memRefType.getLayout().isIdentity())
806     return strides;
807 
808   // Otherwise, we must determine contiguity form shapes. This can only ever
809   // work in static cases because MemRefType is underspecified to represent
810   // contiguous dynamic shapes in other ways than with just empty/identity
811   // layout.
812   auto sizes = memRefType.getShape();
813   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
814     if (ShapedType::isDynamic(sizes[index + 1]) ||
815         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
816         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
817       return None;
818     if (strides[index] != strides[index + 1] * sizes[index + 1])
819       return None;
820   }
821   return strides;
822 }
823 
824 class VectorTypeCastOpConversion
825     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
826 public:
827   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
828 
829   LogicalResult
830   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
831                   ConversionPatternRewriter &rewriter) const override {
832     auto loc = castOp->getLoc();
833     MemRefType sourceMemRefType =
834         castOp.getOperand().getType().cast<MemRefType>();
835     MemRefType targetMemRefType = castOp.getType();
836 
837     // Only static shape casts supported atm.
838     if (!sourceMemRefType.hasStaticShape() ||
839         !targetMemRefType.hasStaticShape())
840       return failure();
841 
842     auto llvmSourceDescriptorTy =
843         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
844     if (!llvmSourceDescriptorTy)
845       return failure();
846     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
847 
848     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
849                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
850     if (!llvmTargetDescriptorTy)
851       return failure();
852 
853     // Only contiguous source buffers supported atm.
854     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
855     if (!sourceStrides)
856       return failure();
857     auto targetStrides = computeContiguousStrides(targetMemRefType);
858     if (!targetStrides)
859       return failure();
860     // Only support static strides for now, regardless of contiguity.
861     if (llvm::any_of(*targetStrides, [](int64_t stride) {
862           return ShapedType::isDynamicStrideOrOffset(stride);
863         }))
864       return failure();
865 
866     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
867 
868     // Create descriptor.
869     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
870     Type llvmTargetElementTy = desc.getElementPtrType();
871     // Set allocated ptr.
872     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
873     allocated =
874         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
875     desc.setAllocatedPtr(rewriter, loc, allocated);
876     // Set aligned ptr.
877     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
878     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
879     desc.setAlignedPtr(rewriter, loc, ptr);
880     // Fill offset 0.
881     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
882     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
883     desc.setOffset(rewriter, loc, zero);
884 
885     // Fill size and stride descriptors in memref.
886     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
887       int64_t index = indexedSize.index();
888       auto sizeAttr =
889           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
890       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
891       desc.setSize(rewriter, loc, index, size);
892       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
893                                                 (*targetStrides)[index]);
894       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
895       desc.setStride(rewriter, loc, index, stride);
896     }
897 
898     rewriter.replaceOp(castOp, {desc});
899     return success();
900   }
901 };
902 
903 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
904 public:
905   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
906 
907   // Proof-of-concept lowering implementation that relies on a small
908   // runtime support library, which only needs to provide a few
909   // printing methods (single value for all data types, opening/closing
910   // bracket, comma, newline). The lowering fully unrolls a vector
911   // in terms of these elementary printing operations. The advantage
912   // of this approach is that the library can remain unaware of all
913   // low-level implementation details of vectors while still supporting
914   // output of any shaped and dimensioned vector. Due to full unrolling,
915   // this approach is less suited for very large vectors though.
916   //
917   // TODO: rely solely on libc in future? something else?
918   //
919   LogicalResult
920   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
921                   ConversionPatternRewriter &rewriter) const override {
922     Type printType = printOp.getPrintType();
923 
924     if (typeConverter->convertType(printType) == nullptr)
925       return failure();
926 
927     // Make sure element type has runtime support.
928     PrintConversion conversion = PrintConversion::None;
929     VectorType vectorType = printType.dyn_cast<VectorType>();
930     Type eltType = vectorType ? vectorType.getElementType() : printType;
931     Operation *printer;
932     if (eltType.isF32()) {
933       printer =
934           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
935     } else if (eltType.isF64()) {
936       printer =
937           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
938     } else if (eltType.isIndex()) {
939       printer =
940           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
941     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
942       // Integers need a zero or sign extension on the operand
943       // (depending on the source type) as well as a signed or
944       // unsigned print method. Up to 64-bit is supported.
945       unsigned width = intTy.getWidth();
946       if (intTy.isUnsigned()) {
947         if (width <= 64) {
948           if (width < 64)
949             conversion = PrintConversion::ZeroExt64;
950           printer = LLVM::lookupOrCreatePrintU64Fn(
951               printOp->getParentOfType<ModuleOp>());
952         } else {
953           return failure();
954         }
955       } else {
956         assert(intTy.isSignless() || intTy.isSigned());
957         if (width <= 64) {
958           // Note that we *always* zero extend booleans (1-bit integers),
959           // so that true/false is printed as 1/0 rather than -1/0.
960           if (width == 1)
961             conversion = PrintConversion::ZeroExt64;
962           else if (width < 64)
963             conversion = PrintConversion::SignExt64;
964           printer = LLVM::lookupOrCreatePrintI64Fn(
965               printOp->getParentOfType<ModuleOp>());
966         } else {
967           return failure();
968         }
969       }
970     } else {
971       return failure();
972     }
973 
974     // Unroll vector into elementary print calls.
975     int64_t rank = vectorType ? vectorType.getRank() : 0;
976     Type type = vectorType ? vectorType : eltType;
977     emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
978               conversion);
979     emitCall(rewriter, printOp->getLoc(),
980              LLVM::lookupOrCreatePrintNewlineFn(
981                  printOp->getParentOfType<ModuleOp>()));
982     rewriter.eraseOp(printOp);
983     return success();
984   }
985 
986 private:
987   enum class PrintConversion {
988     // clang-format off
989     None,
990     ZeroExt64,
991     SignExt64
992     // clang-format on
993   };
994 
995   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
996                  Value value, Type type, Operation *printer, int64_t rank,
997                  PrintConversion conversion) const {
998     VectorType vectorType = type.dyn_cast<VectorType>();
999     Location loc = op->getLoc();
1000     if (!vectorType) {
1001       assert(rank == 0 && "The scalar case expects rank == 0");
1002       switch (conversion) {
1003       case PrintConversion::ZeroExt64:
1004         value = rewriter.create<arith::ExtUIOp>(
1005             loc, value, IntegerType::get(rewriter.getContext(), 64));
1006         break;
1007       case PrintConversion::SignExt64:
1008         value = rewriter.create<arith::ExtSIOp>(
1009             loc, value, IntegerType::get(rewriter.getContext(), 64));
1010         break;
1011       case PrintConversion::None:
1012         break;
1013       }
1014       emitCall(rewriter, loc, printer, value);
1015       return;
1016     }
1017 
1018     emitCall(rewriter, loc,
1019              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1020     Operation *printComma =
1021         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1022 
1023     if (rank <= 1) {
1024       auto reducedType = vectorType.getElementType();
1025       auto llvmType = typeConverter->convertType(reducedType);
1026       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1027       for (int64_t d = 0; d < dim; ++d) {
1028         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1029                                      llvmType, /*rank=*/0, /*pos=*/d);
1030         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1031                   conversion);
1032         if (d != dim - 1)
1033           emitCall(rewriter, loc, printComma);
1034       }
1035       emitCall(
1036           rewriter, loc,
1037           LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1038       return;
1039     }
1040 
1041     int64_t dim = vectorType.getDimSize(0);
1042     for (int64_t d = 0; d < dim; ++d) {
1043       auto reducedType = reducedVectorTypeFront(vectorType);
1044       auto llvmType = typeConverter->convertType(reducedType);
1045       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1046                                    llvmType, rank, d);
1047       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1048                 conversion);
1049       if (d != dim - 1)
1050         emitCall(rewriter, loc, printComma);
1051     }
1052     emitCall(rewriter, loc,
1053              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1054   }
1055 
1056   // Helper to emit a call.
1057   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1058                        Operation *ref, ValueRange params = ValueRange()) {
1059     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1060                                   params);
1061   }
1062 };
1063 
1064 } // namespace
1065 
1066 /// Populate the given list with patterns that convert from Vector to LLVM.
1067 void mlir::populateVectorToLLVMConversionPatterns(
1068     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1069     bool reassociateFPReductions) {
1070   MLIRContext *ctx = converter.getDialect()->getContext();
1071   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1072   populateVectorInsertExtractStridedSliceTransforms(patterns);
1073   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1074   patterns
1075       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1076            VectorExtractElementOpConversion, VectorExtractOpConversion,
1077            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1078            VectorInsertOpConversion, VectorPrintOpConversion,
1079            VectorTypeCastOpConversion, VectorScaleOpConversion,
1080            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1081            VectorLoadStoreConversion<vector::MaskedLoadOp,
1082                                      vector::MaskedLoadOpAdaptor>,
1083            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1084            VectorLoadStoreConversion<vector::MaskedStoreOp,
1085                                      vector::MaskedStoreOpAdaptor>,
1086            VectorGatherOpConversion, VectorScatterOpConversion,
1087            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1088           converter);
1089   // Transfer ops with rank > 1 are handled by VectorToSCF.
1090   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1091 }
1092 
1093 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1094     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1095   patterns.add<VectorMatmulOpConversion>(converter);
1096   patterns.add<VectorFlatTransposeOpConversion>(converter);
1097 }
1098