1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   unsigned numScalableDims = tp.getNumScalableDims();
30   if (tp.getShape().size() == numScalableDims)
31     --numScalableDims;
32   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
33                          numScalableDims);
34 }
35 
36 // Helper to reduce vector type by *all* but one rank at back.
37 static VectorType reducedVectorTypeBack(VectorType tp) {
38   assert((tp.getRank() > 1) && "unlowerable vector type");
39   unsigned numScalableDims = tp.getNumScalableDims();
40   if (numScalableDims > 0)
41     --numScalableDims;
42   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
43                          numScalableDims);
44 }
45 
46 // Helper that picks the proper sequence for inserting.
47 static Value insertOne(ConversionPatternRewriter &rewriter,
48                        LLVMTypeConverter &typeConverter, Location loc,
49                        Value val1, Value val2, Type llvmType, int64_t rank,
50                        int64_t pos) {
51   assert(rank > 0 && "0-D vector corner case should have been handled already");
52   if (rank == 1) {
53     auto idxType = rewriter.getIndexType();
54     auto constant = rewriter.create<LLVM::ConstantOp>(
55         loc, typeConverter.convertType(idxType),
56         rewriter.getIntegerAttr(idxType, pos));
57     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58                                                   constant);
59   }
60   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
61                                               rewriter.getI64ArrayAttr(pos));
62 }
63 
64 // Helper that picks the proper sequence for extracting.
65 static Value extractOne(ConversionPatternRewriter &rewriter,
66                         LLVMTypeConverter &typeConverter, Location loc,
67                         Value val, Type llvmType, int64_t rank, int64_t pos) {
68   if (rank <= 1) {
69     auto idxType = rewriter.getIndexType();
70     auto constant = rewriter.create<LLVM::ConstantOp>(
71         loc, typeConverter.convertType(idxType),
72         rewriter.getIntegerAttr(idxType, pos));
73     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
74                                                    constant);
75   }
76   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
77                                                rewriter.getI64ArrayAttr(pos));
78 }
79 
80 // Helper that returns data layout alignment of a memref.
81 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
82                                  MemRefType memrefType, unsigned &align) {
83   Type elementTy = typeConverter.convertType(memrefType.getElementType());
84   if (!elementTy)
85     return failure();
86 
87   // TODO: this should use the MLIR data layout when it becomes available and
88   // stop depending on translation.
89   llvm::LLVMContext llvmContext;
90   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
91               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
92   return success();
93 }
94 
95 // Add an index vector component to a base pointer. This succeeds unless the
96 // last stride is non-unit or the memory space is not zero.
97 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
98                                     Location loc, Value memref, Value base,
99                                     Value index, MemRefType memRefType,
100                                     VectorType vType, Value &ptrs) {
101   int64_t offset;
102   SmallVector<int64_t, 4> strides;
103   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
104   if (failed(successStrides) || strides.back() != 1 ||
105       memRefType.getMemorySpaceAsInt() != 0)
106     return failure();
107   auto pType = MemRefDescriptor(memref).getElementPtrType();
108   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
109   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
110   return success();
111 }
112 
113 // Casts a strided element pointer to a vector pointer.  The vector pointer
114 // will be in the same address space as the incoming memref type.
115 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
116                          Value ptr, MemRefType memRefType, Type vt) {
117   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
118   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
119 }
120 
121 namespace {
122 
123 /// Trivial Vector to LLVM conversions
124 using VectorScaleOpConversion =
125     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
126 
127 /// Conversion pattern for a vector.bitcast.
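/// For illustration only (the shapes and element types below are made up), the
/// kind of IR this pattern rewrites is roughly:
/// ```
///  %1 = vector.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```
/// which becomes a direct `llvm.bitcast` on the converted operand:
/// ```
///  %1 = llvm.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```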
128 class VectorBitCastOpConversion
129     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
130 public:
131   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
132 
133   LogicalResult
134   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
135                   ConversionPatternRewriter &rewriter) const override {
136     // Only 0-D and 1-D vectors can be lowered to LLVM.
137     VectorType resultTy = bitCastOp.getResultVectorType();
138     if (resultTy.getRank() > 1)
139       return failure();
140     Type newResultTy = typeConverter->convertType(resultTy);
141     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
142                                                  adaptor.getOperands()[0]);
143     return success();
144   }
145 };
146 
147 /// Conversion pattern for a vector.matrix_multiply.
148 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
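/// For illustration only (the sizes and attribute values below are made up),
/// a typical op this pattern matches looks roughly like:
/// ```
///  %c = vector.matrix_multiply %a, %b
///         {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///         : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// and is replaced by the analogous `llvm.intr.matrix.multiply` carrying the
/// same attributes.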
149 class VectorMatmulOpConversion
150     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
151 public:
152   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
153 
154   LogicalResult
155   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
156                   ConversionPatternRewriter &rewriter) const override {
157     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
158         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
159         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
160         matmulOp.lhs_columns(), matmulOp.rhs_columns());
161     return success();
162   }
163 };
164 
165 /// Conversion pattern for a vector.flat_transpose.
166 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
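/// For illustration only (the sizes below are made up), this pattern rewrites,
/// e.g.:
/// ```
///  %0 = vector.flat_transpose %mat {rows = 4 : i32, columns = 4 : i32}
///         : vector<16xf32> -> vector<16xf32>
/// ```
/// into the analogous `llvm.intr.matrix.transpose` with the same attributes.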
167 class VectorFlatTransposeOpConversion
168     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
169 public:
170   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
171 
172   LogicalResult
173   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
174                   ConversionPatternRewriter &rewriter) const override {
175     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
176         transOp, typeConverter->convertType(transOp.res().getType()),
177         adaptor.matrix(), transOp.rows(), transOp.columns());
178     return success();
179   }
180 };
181 
182 /// Overloaded utility that replaces a vector.load, vector.store,
183 /// vector.maskedload and vector.maskedstore with their respective LLVM
184 /// counterparts.
185 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
186                                  vector::LoadOpAdaptor adaptor,
187                                  VectorType vectorTy, Value ptr, unsigned align,
188                                  ConversionPatternRewriter &rewriter) {
189   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
190 }
191 
192 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
193                                  vector::MaskedLoadOpAdaptor adaptor,
194                                  VectorType vectorTy, Value ptr, unsigned align,
195                                  ConversionPatternRewriter &rewriter) {
196   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
197       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
198 }
199 
200 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
201                                  vector::StoreOpAdaptor adaptor,
202                                  VectorType vectorTy, Value ptr, unsigned align,
203                                  ConversionPatternRewriter &rewriter) {
204   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
205                                              ptr, align);
206 }
207 
208 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
209                                  vector::MaskedStoreOpAdaptor adaptor,
210                                  VectorType vectorTy, Value ptr, unsigned align,
211                                  ConversionPatternRewriter &rewriter) {
212   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
213       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
214 }
215 
216 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
217 /// vector.maskedstore.
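/// For illustration only (the memref shape, element type and alignment are
/// made up), a 1-D load such as:
/// ```
///  %0 = vector.load %base[%i] : memref<100xf32>, vector<8xf32>
/// ```
/// is rewritten into an address computation on the memref descriptor, a
/// bitcast of the element pointer to a vector pointer, and a plain `llvm.load`
/// with the data-layout-preferred alignment; the masked variants map to
/// `llvm.intr.masked.load` / `llvm.intr.masked.store` in the same way.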
218 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
219 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
220 public:
221   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
222 
223   LogicalResult
224   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
225                   typename LoadOrStoreOp::Adaptor adaptor,
226                   ConversionPatternRewriter &rewriter) const override {
227     // Only 1-D vectors can be lowered to LLVM.
228     VectorType vectorTy = loadOrStoreOp.getVectorType();
229     if (vectorTy.getRank() > 1)
230       return failure();
231 
232     auto loc = loadOrStoreOp->getLoc();
233     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
234 
235     // Resolve alignment.
236     unsigned align;
237     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
238       return failure();
239 
240     // Resolve address.
241     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
242                      .template cast<VectorType>();
243     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
244                                                adaptor.indices(), rewriter);
245     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
246 
247     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
248     return success();
249   }
250 };
251 
252 /// Conversion pattern for a vector.gather.
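/// For illustration only (the shapes and element type below are made up), the
/// source IR handled here looks roughly like:
/// ```
///  %g = vector.gather %base[%c0][%idxs], %mask, %pass_thru
///         : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///           into vector<16xf32>
/// ```
/// The index vector is turned into a vector of pointers with
/// `llvm.getelementptr` (see getIndexedPtrs above), and the op becomes an
/// `llvm.intr.masked.gather`.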
253 class VectorGatherOpConversion
254     : public ConvertOpToLLVMPattern<vector::GatherOp> {
255 public:
256   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
257 
258   LogicalResult
259   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
260                   ConversionPatternRewriter &rewriter) const override {
261     auto loc = gather->getLoc();
262     MemRefType memRefType = gather.getMemRefType();
263 
264     // Resolve alignment.
265     unsigned align;
266     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
267       return failure();
268 
269     // Resolve address.
270     Value ptrs;
271     VectorType vType = gather.getVectorType();
272     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
273                                      adaptor.indices(), rewriter);
274     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
275                               adaptor.index_vec(), memRefType, vType, ptrs)))
276       return failure();
277 
278     // Replace with the gather intrinsic.
279     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
280         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
281         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
282     return success();
283   }
284 };
285 
286 /// Conversion pattern for a vector.scatter.
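/// For illustration only (the shapes and element type below are made up), the
/// source IR handled here looks roughly like:
/// ```
///  vector.scatter %base[%c0][%idxs], %mask, %value
///    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
/// ```
/// As for vector.gather, a vector of pointers is materialized first and the op
/// is then replaced by an `llvm.intr.masked.scatter`.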
287 class VectorScatterOpConversion
288     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
289 public:
290   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
291 
292   LogicalResult
293   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
294                   ConversionPatternRewriter &rewriter) const override {
295     auto loc = scatter->getLoc();
296     MemRefType memRefType = scatter.getMemRefType();
297 
298     // Resolve alignment.
299     unsigned align;
300     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
301       return failure();
302 
303     // Resolve address.
304     Value ptrs;
305     VectorType vType = scatter.getVectorType();
306     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
307                                      adaptor.indices(), rewriter);
308     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
309                               adaptor.index_vec(), memRefType, vType, ptrs)))
310       return failure();
311 
312     // Replace with the scatter intrinsic.
313     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
314         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
315         rewriter.getI32IntegerAttr(align));
316     return success();
317   }
318 };
319 
320 /// Conversion pattern for a vector.expandload.
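/// For illustration only (the shapes and element type below are made up),
/// e.g.:
/// ```
///  %0 = vector.expandload %base[%i], %mask, %pass_thru
///         : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// maps to an `llvm.intr.masked.expandload` on the strided element pointer.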
321 class VectorExpandLoadOpConversion
322     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
323 public:
324   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
325 
326   LogicalResult
327   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
328                   ConversionPatternRewriter &rewriter) const override {
329     auto loc = expand->getLoc();
330     MemRefType memRefType = expand.getMemRefType();
331 
332     // Resolve address.
333     auto vtype = typeConverter->convertType(expand.getVectorType());
334     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
335                                      adaptor.indices(), rewriter);
336 
337     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
338         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
339     return success();
340   }
341 };
342 
343 /// Conversion pattern for a vector.compressstore.
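/// For illustration only (the shapes and element type below are made up),
/// e.g.:
/// ```
///  vector.compressstore %base[%i], %mask, %value
///    : memref<?xf32>, vector<16xi1>, vector<16xf32>
/// ```
/// maps to an `llvm.intr.masked.compressstore` on the strided element pointer.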
344 class VectorCompressStoreOpConversion
345     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
346 public:
347   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
348 
349   LogicalResult
350   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
351                   ConversionPatternRewriter &rewriter) const override {
352     auto loc = compress->getLoc();
353     MemRefType memRefType = compress.getMemRefType();
354 
355     // Resolve address.
356     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
357                                      adaptor.indices(), rewriter);
358 
359     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
360         compress, adaptor.valueToStore(), ptr, adaptor.mask());
361     return success();
362   }
363 };
364 
365 /// Conversion pattern for all vector reductions.
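/// For illustration only (the kind, shape and element type below are made up),
/// a floating-point reduction such as:
/// ```
///  %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// becomes an `llvm.intr.vector.reduce.fadd` whose start value is either the
/// optional accumulator operand or a zero constant, and whose reassociation
/// flag is controlled by the `reassociateFPReductions` option below; the
/// integer kinds map one-to-one onto the corresponding
/// `llvm.intr.vector.reduce.*` intrinsics.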
366 class VectorReductionOpConversion
367     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
368 public:
369   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
370                                        bool reassociateFPRed)
371       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
372         reassociateFPReductions(reassociateFPRed) {}
373 
374   LogicalResult
375   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
376                   ConversionPatternRewriter &rewriter) const override {
377     auto kind = reductionOp.kind();
378     Type eltType = reductionOp.dest().getType();
379     Type llvmType = typeConverter->convertType(eltType);
380     Value operand = adaptor.getOperands()[0];
381     if (eltType.isIntOrIndex()) {
382       // Integer reductions: add/mul/min/max/and/or/xor.
383       if (kind == "add")
384         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
385                                                              llvmType, operand);
386       else if (kind == "mul")
387         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
388                                                              llvmType, operand);
389       else if (kind == "minui")
390         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
391             reductionOp, llvmType, operand);
392       else if (kind == "minsi")
393         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
394             reductionOp, llvmType, operand);
395       else if (kind == "maxui")
396         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
397             reductionOp, llvmType, operand);
398       else if (kind == "maxsi")
399         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
400             reductionOp, llvmType, operand);
401       else if (kind == "and")
402         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
403                                                              llvmType, operand);
404       else if (kind == "or")
405         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
406                                                             llvmType, operand);
407       else if (kind == "xor")
408         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
409                                                              llvmType, operand);
410       else
411         return failure();
412       return success();
413     }
414 
415     if (!eltType.isa<FloatType>())
416       return failure();
417 
418     // Floating-point reductions: add/mul/min/max
419     if (kind == "add") {
420       // Optional accumulator (or zero).
421       Value acc = adaptor.getOperands().size() > 1
422                       ? adaptor.getOperands()[1]
423                       : rewriter.create<LLVM::ConstantOp>(
424                             reductionOp->getLoc(), llvmType,
425                             rewriter.getZeroAttr(eltType));
426       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
427           reductionOp, llvmType, acc, operand,
428           rewriter.getBoolAttr(reassociateFPReductions));
429     } else if (kind == "mul") {
430       // Optional accumulator (or one).
431       Value acc = adaptor.getOperands().size() > 1
432                       ? adaptor.getOperands()[1]
433                       : rewriter.create<LLVM::ConstantOp>(
434                             reductionOp->getLoc(), llvmType,
435                             rewriter.getFloatAttr(eltType, 1.0));
436       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
437           reductionOp, llvmType, acc, operand,
438           rewriter.getBoolAttr(reassociateFPReductions));
439     } else if (kind == "minf")
440       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
441       // NaNs/-0.0/+0.0 in the same way.
442       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
443                                                             llvmType, operand);
444     else if (kind == "maxf")
445       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
446       // NaNs/-0.0/+0.0 in the same way.
447       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
448                                                             llvmType, operand);
449     else
450       return failure();
451     return success();
452   }
453 
454 private:
455   const bool reassociateFPReductions;
456 };
457 
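/// Conversion pattern for a vector.shuffle. For illustration only (the shapes
/// and mask below are made up), a rank-1 shuffle of two identically typed
/// vectors, e.g.:
/// ```
///  %2 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xf32>, vector<2xf32>
/// ```
/// maps directly onto `llvm.shufflevector`; all other cases are unrolled into
/// per-element extract/insert pairs as described in matchAndRewrite below.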
458 class VectorShuffleOpConversion
459     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
460 public:
461   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
462 
463   LogicalResult
464   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
465                   ConversionPatternRewriter &rewriter) const override {
466     auto loc = shuffleOp->getLoc();
467     auto v1Type = shuffleOp.getV1VectorType();
468     auto v2Type = shuffleOp.getV2VectorType();
469     auto vectorType = shuffleOp.getVectorType();
470     Type llvmType = typeConverter->convertType(vectorType);
471     auto maskArrayAttr = shuffleOp.mask();
472 
473     // Bail if result type cannot be lowered.
474     if (!llvmType)
475       return failure();
476 
477     // Get rank and dimension sizes.
478     int64_t rank = vectorType.getRank();
479     assert(v1Type.getRank() == rank);
480     assert(v2Type.getRank() == rank);
481     int64_t v1Dim = v1Type.getDimSize(0);
482 
483     // For rank 1, where both operands have *exactly* the same vector type,
484     // there is direct shuffle support in LLVM. Use it!
485     if (rank == 1 && v1Type == v2Type) {
486       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
487           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
488       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
489       return success();
490     }
491 
492     // For all other cases, extract and insert the elements one by one.
493     Type eltType;
494     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
495       eltType = arrayType.getElementType();
496     else
497       eltType = llvmType.cast<VectorType>().getElementType();
498     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
499     int64_t insPos = 0;
500     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
501       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
502       Value value = adaptor.v1();
503       if (extPos >= v1Dim) {
504         extPos -= v1Dim;
505         value = adaptor.v2();
506       }
507       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
508                                  eltType, rank, extPos);
509       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
510                          llvmType, rank, insPos++);
511     }
512     rewriter.replaceOp(shuffleOp, insert);
513     return success();
514   }
515 };
516 
517 class VectorExtractElementOpConversion
518     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
519 public:
520   using ConvertOpToLLVMPattern<
521       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
522 
523   LogicalResult
524   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
525                   ConversionPatternRewriter &rewriter) const override {
526     auto vectorType = extractEltOp.getVectorType();
527     auto llvmType = typeConverter->convertType(vectorType.getElementType());
528 
529     // Bail if result type cannot be lowered.
530     if (!llvmType)
531       return failure();
532 
533     if (vectorType.getRank() == 0) {
534       Location loc = extractEltOp.getLoc();
535       auto idxType = rewriter.getIndexType();
536       auto zero = rewriter.create<LLVM::ConstantOp>(
537           loc, typeConverter->convertType(idxType),
538           rewriter.getIntegerAttr(idxType, 0));
539       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
540           extractEltOp, llvmType, adaptor.vector(), zero);
541       return success();
542     }
543 
544     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
545         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
546     return success();
547   }
548 };
549 
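/// Conversion pattern for a vector.extract. For illustration only (the shape
/// and positions below are made up), e.g.:
/// ```
///  %1 = vector.extract %0[0, 2] : vector<4x8xf32>
/// ```
/// is lowered to an `llvm.extractvalue` for the leading (array) positions
/// followed by an `llvm.extractelement` for the trailing 1-D position.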
550 class VectorExtractOpConversion
551     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
552 public:
553   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
554 
555   LogicalResult
556   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
557                   ConversionPatternRewriter &rewriter) const override {
558     auto loc = extractOp->getLoc();
559     auto vectorType = extractOp.getVectorType();
560     auto resultType = extractOp.getResult().getType();
561     auto llvmResultType = typeConverter->convertType(resultType);
562     auto positionArrayAttr = extractOp.position();
563 
564     // Bail if result type cannot be lowered.
565     if (!llvmResultType)
566       return failure();
567 
568     // Extract entire vector. Should be handled by folder, but just to be safe.
569     if (positionArrayAttr.empty()) {
570       rewriter.replaceOp(extractOp, adaptor.vector());
571       return success();
572     }
573 
574     // One-shot extraction of vector from array (only requires extractvalue).
575     if (resultType.isa<VectorType>()) {
576       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
577           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
578       rewriter.replaceOp(extractOp, extracted);
579       return success();
580     }
581 
582     // Potential extraction of 1-D vector from array.
583     auto *context = extractOp->getContext();
584     Value extracted = adaptor.vector();
585     auto positionAttrs = positionArrayAttr.getValue();
586     if (positionAttrs.size() > 1) {
587       auto oneDVectorType = reducedVectorTypeBack(vectorType);
588       auto nMinusOnePositionAttrs =
589           ArrayAttr::get(context, positionAttrs.drop_back());
590       extracted = rewriter.create<LLVM::ExtractValueOp>(
591           loc, typeConverter->convertType(oneDVectorType), extracted,
592           nMinusOnePositionAttrs);
593     }
594 
595     // Remaining extraction of element from 1-D LLVM vector
596     auto position = positionAttrs.back().cast<IntegerAttr>();
597     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
598     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
599     extracted =
600         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
601     rewriter.replaceOp(extractOp, extracted);
602 
603     return success();
604   }
605 };
606 
607 /// Conversion pattern that turns a vector.fma on a 1-D vector
608 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
609 /// This does not match vectors of n >= 2 rank.
610 ///
611 /// Example:
612 /// ```
613 ///  vector.fma %a, %a, %a : vector<8xf32>
614 /// ```
615 /// is converted to:
616 /// ```
617 ///  llvm.intr.fmuladd %va, %va, %va :
618 ///    (vector<8xf32>, vector<8xf32>, vector<8xf32>)
619 ///    -> vector<8xf32>
620 /// ```
621 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
622 public:
623   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
624 
625   LogicalResult
626   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
627                   ConversionPatternRewriter &rewriter) const override {
628     VectorType vType = fmaOp.getVectorType();
629     if (vType.getRank() != 1)
630       return failure();
631     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
632                                                  adaptor.rhs(), adaptor.acc());
633     return success();
634   }
635 };
636 
637 class VectorInsertElementOpConversion
638     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
639 public:
640   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
641 
642   LogicalResult
643   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
644                   ConversionPatternRewriter &rewriter) const override {
645     auto vectorType = insertEltOp.getDestVectorType();
646     auto llvmType = typeConverter->convertType(vectorType);
647 
648     // Bail if result type cannot be lowered.
649     if (!llvmType)
650       return failure();
651 
652     if (vectorType.getRank() == 0) {
653       Location loc = insertEltOp.getLoc();
654       auto idxType = rewriter.getIndexType();
655       auto zero = rewriter.create<LLVM::ConstantOp>(
656           loc, typeConverter->convertType(idxType),
657           rewriter.getIntegerAttr(idxType, 0));
658       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
659           insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
660       return success();
661     }
662 
663     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
664         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
665         adaptor.position());
666     return success();
667   }
668 };
669 
670 class VectorInsertOpConversion
671     : public ConvertOpToLLVMPattern<vector::InsertOp> {
672 public:
673   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
674 
675   LogicalResult
676   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
677                   ConversionPatternRewriter &rewriter) const override {
678     auto loc = insertOp->getLoc();
679     auto sourceType = insertOp.getSourceType();
680     auto destVectorType = insertOp.getDestVectorType();
681     auto llvmResultType = typeConverter->convertType(destVectorType);
682     auto positionArrayAttr = insertOp.position();
683 
684     // Bail if result type cannot be lowered.
685     if (!llvmResultType)
686       return failure();
687 
688     // Overwrite entire vector with value. Should be handled by folder, but
689     // just to be safe.
690     if (positionArrayAttr.empty()) {
691       rewriter.replaceOp(insertOp, adaptor.source());
692       return success();
693     }
694 
695     // One-shot insertion of a vector into an array (only requires insertvalue).
696     if (sourceType.isa<VectorType>()) {
697       Value inserted = rewriter.create<LLVM::InsertValueOp>(
698           loc, llvmResultType, adaptor.dest(), adaptor.source(),
699           positionArrayAttr);
700       rewriter.replaceOp(insertOp, inserted);
701       return success();
702     }
703 
704     // Potential extraction of 1-D vector from array.
705     auto *context = insertOp->getContext();
706     Value extracted = adaptor.dest();
707     auto positionAttrs = positionArrayAttr.getValue();
708     auto position = positionAttrs.back().cast<IntegerAttr>();
709     auto oneDVectorType = destVectorType;
710     if (positionAttrs.size() > 1) {
711       oneDVectorType = reducedVectorTypeBack(destVectorType);
712       auto nMinusOnePositionAttrs =
713           ArrayAttr::get(context, positionAttrs.drop_back());
714       extracted = rewriter.create<LLVM::ExtractValueOp>(
715           loc, typeConverter->convertType(oneDVectorType), extracted,
716           nMinusOnePositionAttrs);
717     }
718 
719     // Insertion of an element into a 1-D LLVM vector.
720     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
721     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
722     Value inserted = rewriter.create<LLVM::InsertElementOp>(
723         loc, typeConverter->convertType(oneDVectorType), extracted,
724         adaptor.source(), constant);
725 
726     // Potential insertion of resulting 1-D vector into array.
727     if (positionAttrs.size() > 1) {
728       auto nMinusOnePositionAttrs =
729           ArrayAttr::get(context, positionAttrs.drop_back());
730       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
731                                                       adaptor.dest(), inserted,
732                                                       nMinusOnePositionAttrs);
733     }
734 
735     rewriter.replaceOp(insertOp, inserted);
736     return success();
737   }
738 };
739 
740 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
741 ///
742 /// Example:
743 /// ```
744 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
745 /// ```
746 /// is rewritten into:
747 /// ```
748 ///  %r = splat %f0 : vector<2x4xf32>
749 ///  %va = vector.extract %a[0] : vector<2x4xf32>
750 ///  %vb = vector.extract %b[0] : vector<2x4xf32>
751 ///  %vc = vector.extract %c[0] : vector<2x4xf32>
752 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
753 ///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
754 ///  %va2 = vector.extract %a[1] : vector<2x4xf32>
755 ///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
756 ///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
757 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
758 ///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
759 ///  // %r3 holds the final value.
760 /// ```
761 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
762 public:
763   using OpRewritePattern<FMAOp>::OpRewritePattern;
764 
765   void initialize() {
766     // This pattern recursively unpacks one dimension at a time. The recursion
767     // is bounded because the rank is strictly decreasing.
768     setHasBoundedRewriteRecursion();
769   }
770 
771   LogicalResult matchAndRewrite(FMAOp op,
772                                 PatternRewriter &rewriter) const override {
773     auto vType = op.getVectorType();
774     if (vType.getRank() < 2)
775       return failure();
776 
777     auto loc = op.getLoc();
778     auto elemType = vType.getElementType();
779     Value zero = rewriter.create<arith::ConstantOp>(
780         loc, elemType, rewriter.getZeroAttr(elemType));
781     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
782     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
783       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
784       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
785       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
786       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
787       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
788     }
789     rewriter.replaceOp(op, desc);
790     return success();
791   }
792 };
793 
794 /// Returns the strides if the memory underlying `memRefType` has a contiguous
795 /// static layout.
796 static llvm::Optional<SmallVector<int64_t, 4>>
797 computeContiguousStrides(MemRefType memRefType) {
798   int64_t offset;
799   SmallVector<int64_t, 4> strides;
800   if (failed(getStridesAndOffset(memRefType, strides, offset)))
801     return None;
802   if (!strides.empty() && strides.back() != 1)
803     return None;
804   // If no layout or identity layout, this is contiguous by definition.
805   if (memRefType.getLayout().isIdentity())
806     return strides;
807 
808   // Otherwise, we must determine contiguity from the shape. This can only ever
809   // work in static cases because MemRefType is underspecified: it cannot
810   // represent contiguous dynamic shapes in any way other than with an
811   // empty/identity layout.
812   auto sizes = memRefType.getShape();
813   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
814     if (ShapedType::isDynamic(sizes[index + 1]) ||
815         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
816         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
817       return None;
818     if (strides[index] != strides[index + 1] * sizes[index + 1])
819       return None;
820   }
821   return strides;
822 }
823 
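/// Conversion pattern for a vector.type_cast. For illustration only (the shape
/// below is made up), e.g.:
/// ```
///  %1 = vector.type_cast %0 : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// is lowered by building a new memref descriptor whose pointers are bitcast
/// to the target element type, with a zero offset and the target type's static
/// sizes and strides.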
824 class VectorTypeCastOpConversion
825     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
826 public:
827   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
828 
829   LogicalResult
830   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
831                   ConversionPatternRewriter &rewriter) const override {
832     auto loc = castOp->getLoc();
833     MemRefType sourceMemRefType =
834         castOp.getOperand().getType().cast<MemRefType>();
835     MemRefType targetMemRefType = castOp.getType();
836 
837     // Only static shape casts supported atm.
838     if (!sourceMemRefType.hasStaticShape() ||
839         !targetMemRefType.hasStaticShape())
840       return failure();
841 
842     auto llvmSourceDescriptorTy =
843         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
844     if (!llvmSourceDescriptorTy)
845       return failure();
846     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
847 
848     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
849                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
850     if (!llvmTargetDescriptorTy)
851       return failure();
852 
853     // Only contiguous source buffers supported atm.
854     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
855     if (!sourceStrides)
856       return failure();
857     auto targetStrides = computeContiguousStrides(targetMemRefType);
858     if (!targetStrides)
859       return failure();
860     // Only support static strides for now, regardless of contiguity.
861     if (llvm::any_of(*targetStrides, [](int64_t stride) {
862           return ShapedType::isDynamicStrideOrOffset(stride);
863         }))
864       return failure();
865 
866     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
867 
868     // Create descriptor.
869     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
870     Type llvmTargetElementTy = desc.getElementPtrType();
871     // Set allocated ptr.
872     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
873     allocated =
874         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
875     desc.setAllocatedPtr(rewriter, loc, allocated);
876     // Set aligned ptr.
877     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
878     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
879     desc.setAlignedPtr(rewriter, loc, ptr);
880     // Fill offset 0.
881     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
882     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
883     desc.setOffset(rewriter, loc, zero);
884 
885     // Fill size and stride descriptors in memref.
886     for (const auto &indexedSize :
887          llvm::enumerate(targetMemRefType.getShape())) {
888       int64_t index = indexedSize.index();
889       auto sizeAttr =
890           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
891       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
892       desc.setSize(rewriter, loc, index, size);
893       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
894                                                 (*targetStrides)[index]);
895       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
896       desc.setStride(rewriter, loc, index, stride);
897     }
898 
899     rewriter.replaceOp(castOp, {desc});
900     return success();
901   }
902 };
903 
904 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
905 public:
906   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
907 
908   // Proof-of-concept lowering implementation that relies on a small
909   // runtime support library, which only needs to provide a few
910   // printing methods (single value for all data types, opening/closing
911   // bracket, comma, newline). The lowering fully unrolls a vector
912   // in terms of these elementary printing operations. The advantage
913   // of this approach is that the library can remain unaware of all
914   // low-level implementation details of vectors while still supporting
915   // output of any shaped and dimensioned vector. Due to full unrolling,
916   // this approach is less suited for very large vectors though.
917   //
918   // TODO: rely solely on libc in future? something else?
919   //
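  // For illustration only (the shape and element type below are made up),
  // e.g.:
  //
  //   vector.print %v : vector<2xf32>
  //
  // unrolls into a call sequence along the lines of
  //
  //   printOpen(); printF32(...); printComma(); printF32(...); printClose();
  //   printNewline();
  //
  // against the runtime support library mentioned above.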
920   LogicalResult
921   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
922                   ConversionPatternRewriter &rewriter) const override {
923     Type printType = printOp.getPrintType();
924 
925     if (typeConverter->convertType(printType) == nullptr)
926       return failure();
927 
928     // Make sure element type has runtime support.
929     PrintConversion conversion = PrintConversion::None;
930     VectorType vectorType = printType.dyn_cast<VectorType>();
931     Type eltType = vectorType ? vectorType.getElementType() : printType;
932     Operation *printer;
933     if (eltType.isF32()) {
934       printer =
935           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
936     } else if (eltType.isF64()) {
937       printer =
938           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
939     } else if (eltType.isIndex()) {
940       printer =
941           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
942     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
943       // Integers need a zero or sign extension on the operand
944       // (depending on the source type) as well as a signed or
945       // unsigned print method. Up to 64-bit is supported.
946       unsigned width = intTy.getWidth();
947       if (intTy.isUnsigned()) {
948         if (width <= 64) {
949           if (width < 64)
950             conversion = PrintConversion::ZeroExt64;
951           printer = LLVM::lookupOrCreatePrintU64Fn(
952               printOp->getParentOfType<ModuleOp>());
953         } else {
954           return failure();
955         }
956       } else {
957         assert(intTy.isSignless() || intTy.isSigned());
958         if (width <= 64) {
959           // Note that we *always* zero extend booleans (1-bit integers),
960           // so that true/false is printed as 1/0 rather than -1/0.
961           if (width == 1)
962             conversion = PrintConversion::ZeroExt64;
963           else if (width < 64)
964             conversion = PrintConversion::SignExt64;
965           printer = LLVM::lookupOrCreatePrintI64Fn(
966               printOp->getParentOfType<ModuleOp>());
967         } else {
968           return failure();
969         }
970       }
971     } else {
972       return failure();
973     }
974 
975     // Unroll vector into elementary print calls.
976     int64_t rank = vectorType ? vectorType.getRank() : 0;
977     Type type = vectorType ? vectorType : eltType;
978     emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
979               conversion);
980     emitCall(rewriter, printOp->getLoc(),
981              LLVM::lookupOrCreatePrintNewlineFn(
982                  printOp->getParentOfType<ModuleOp>()));
983     rewriter.eraseOp(printOp);
984     return success();
985   }
986 
987 private:
988   enum class PrintConversion {
989     // clang-format off
990     None,
991     ZeroExt64,
992     SignExt64
993     // clang-format on
994   };
995 
996   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
997                  Value value, Type type, Operation *printer, int64_t rank,
998                  PrintConversion conversion) const {
999     VectorType vectorType = type.dyn_cast<VectorType>();
1000     Location loc = op->getLoc();
1001     if (!vectorType) {
1002       assert(rank == 0 && "The scalar case expects rank == 0");
1003       switch (conversion) {
1004       case PrintConversion::ZeroExt64:
1005         value = rewriter.create<arith::ExtUIOp>(
1006             loc, value, IntegerType::get(rewriter.getContext(), 64));
1007         break;
1008       case PrintConversion::SignExt64:
1009         value = rewriter.create<arith::ExtSIOp>(
1010             loc, value, IntegerType::get(rewriter.getContext(), 64));
1011         break;
1012       case PrintConversion::None:
1013         break;
1014       }
1015       emitCall(rewriter, loc, printer, value);
1016       return;
1017     }
1018 
1019     emitCall(rewriter, loc,
1020              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1021     Operation *printComma =
1022         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1023 
1024     if (rank <= 1) {
1025       auto reducedType = vectorType.getElementType();
1026       auto llvmType = typeConverter->convertType(reducedType);
1027       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1028       for (int64_t d = 0; d < dim; ++d) {
1029         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1030                                      llvmType, /*rank=*/0, /*pos=*/d);
1031         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1032                   conversion);
1033         if (d != dim - 1)
1034           emitCall(rewriter, loc, printComma);
1035       }
1036       emitCall(
1037           rewriter, loc,
1038           LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1039       return;
1040     }
1041 
1042     int64_t dim = vectorType.getDimSize(0);
1043     for (int64_t d = 0; d < dim; ++d) {
1044       auto reducedType = reducedVectorTypeFront(vectorType);
1045       auto llvmType = typeConverter->convertType(reducedType);
1046       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1047                                    llvmType, rank, d);
1048       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1049                 conversion);
1050       if (d != dim - 1)
1051         emitCall(rewriter, loc, printComma);
1052     }
1053     emitCall(rewriter, loc,
1054              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1055   }
1056 
1057   // Helper to emit a call.
1058   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1059                        Operation *ref, ValueRange params = ValueRange()) {
1060     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1061                                   params);
1062   }
1063 };
1064 
1065 } // namespace
1066 
1067 /// Populate the given list with patterns that convert from Vector to LLVM.
1068 void mlir::populateVectorToLLVMConversionPatterns(
1069     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1070     bool reassociateFPReductions) {
1071   MLIRContext *ctx = converter.getDialect()->getContext();
1072   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1073   populateVectorInsertExtractStridedSliceTransforms(patterns);
1074   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1075   patterns
1076       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1077            VectorExtractElementOpConversion, VectorExtractOpConversion,
1078            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1079            VectorInsertOpConversion, VectorPrintOpConversion,
1080            VectorTypeCastOpConversion, VectorScaleOpConversion,
1081            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1082            VectorLoadStoreConversion<vector::MaskedLoadOp,
1083                                      vector::MaskedLoadOpAdaptor>,
1084            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1085            VectorLoadStoreConversion<vector::MaskedStoreOp,
1086                                      vector::MaskedStoreOpAdaptor>,
1087            VectorGatherOpConversion, VectorScatterOpConversion,
1088            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1089           converter);
1090   // Transfer ops with rank > 1 are handled by VectorToSCF.
1091   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1092 }
1093 
1094 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1095     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1096   patterns.add<VectorMatmulOpConversion>(converter);
1097   patterns.add<VectorFlatTransposeOpConversion>(converter);
1098 }
1099