xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision e7026aba004934cad5487256601af7690757d09f)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
30 }
31 
32 // Helper to reduce vector type by *all* but one rank at back.
33 static VectorType reducedVectorTypeBack(VectorType tp) {
34   assert((tp.getRank() > 1) && "unlowerable vector type");
35   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
36 }
37 
38 // Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter,
40                        LLVMTypeConverter &typeConverter, Location loc,
41                        Value val1, Value val2, Type llvmType, int64_t rank,
42                        int64_t pos) {
43   assert(rank > 0 && "0-D vector corner case should have been handled already");
44   if (rank == 1) {
45     auto idxType = rewriter.getIndexType();
46     auto constant = rewriter.create<LLVM::ConstantOp>(
47         loc, typeConverter.convertType(idxType),
48         rewriter.getIntegerAttr(idxType, pos));
49     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50                                                   constant);
51   }
52   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
53                                               rewriter.getI64ArrayAttr(pos));
54 }
55 
56 // Helper that picks the proper sequence for extracting.
57 static Value extractOne(ConversionPatternRewriter &rewriter,
58                         LLVMTypeConverter &typeConverter, Location loc,
59                         Value val, Type llvmType, int64_t rank, int64_t pos) {
60   assert(rank > 0 && "0-D vector corner case should have been handled already");
61   if (rank == 1) {
62     auto idxType = rewriter.getIndexType();
63     auto constant = rewriter.create<LLVM::ConstantOp>(
64         loc, typeConverter.convertType(idxType),
65         rewriter.getIntegerAttr(idxType, pos));
66     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
67                                                    constant);
68   }
69   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
70                                                rewriter.getI64ArrayAttr(pos));
71 }
72 
73 // Helper that returns data layout alignment of a memref.
74 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
75                                  MemRefType memrefType, unsigned &align) {
76   Type elementTy = typeConverter.convertType(memrefType.getElementType());
77   if (!elementTy)
78     return failure();
79 
80   // TODO: this should use the MLIR data layout when it becomes available and
81   // stop depending on translation.
82   llvm::LLVMContext llvmContext;
83   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
84               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
85   return success();
86 }
87 
88 // Return the minimal alignment value that satisfies all the AssumeAlignment
89 // uses of `value`. If no such uses exist, return 1.
90 static unsigned getAssumedAlignment(Value value) {
91   unsigned align = 1;
92   for (auto &u : value.getUses()) {
93     Operation *owner = u.getOwner();
94     if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
95       align = mlir::lcm(align, op.alignment());
96   }
97   return align;
98 }
99 
// Helper that returns data layout alignment of a memref associated with a
// load, store, scatter, or gather op, including additional information from
// assume_alignment calls on the source of the transfer.
template <class OpAdaptor>
LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
                                   OpAdaptor op, unsigned &align) {
  // Start from the data-layout preferred alignment of the element type.
  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
    return failure();
  // assume_alignment on the base can only raise the guarantee, never lower it.
  align = std::max(align, getAssumedAlignment(op.base()));
  return success();
}
111 
112 // Add an index vector component to a base pointer. This almost always succeeds
113 // unless the last stride is non-unit or the memory space is not zero.
114 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
115                                     Location loc, Value memref, Value base,
116                                     Value index, MemRefType memRefType,
117                                     VectorType vType, Value &ptrs) {
118   int64_t offset;
119   SmallVector<int64_t, 4> strides;
120   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
121   if (failed(successStrides) || strides.back() != 1 ||
122       memRefType.getMemorySpaceAsInt() != 0)
123     return failure();
124   auto pType = MemRefDescriptor(memref).getElementPtrType();
125   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
126   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
127   return success();
128 }
129 
130 // Casts a strided element pointer to a vector pointer.  The vector pointer
131 // will be in the same address space as the incoming memref type.
132 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
133                          Value ptr, MemRefType memRefType, Type vt) {
134   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
135   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
136 }
137 
138 namespace {
139 
140 /// Conversion pattern for a vector.bitcast.
141 class VectorBitCastOpConversion
142     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
143 public:
144   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
145 
146   LogicalResult
147   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
148                   ConversionPatternRewriter &rewriter) const override {
149     // Only 1-D vectors can be lowered to LLVM.
150     VectorType resultTy = bitCastOp.getType();
151     if (resultTy.getRank() != 1)
152       return failure();
153     Type newResultTy = typeConverter->convertType(resultTy);
154     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
155                                                  adaptor.getOperands()[0]);
156     return success();
157   }
158 };
159 
160 /// Conversion pattern for a vector.matrix_multiply.
161 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
162 class VectorMatmulOpConversion
163     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
164 public:
165   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
166 
167   LogicalResult
168   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
169                   ConversionPatternRewriter &rewriter) const override {
170     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
171         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
172         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
173         matmulOp.lhs_columns(), matmulOp.rhs_columns());
174     return success();
175   }
176 };
177 
178 /// Conversion pattern for a vector.flat_transpose.
179 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
180 class VectorFlatTransposeOpConversion
181     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
182 public:
183   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
184 
185   LogicalResult
186   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
187                   ConversionPatternRewriter &rewriter) const override {
188     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
189         transOp, typeConverter->convertType(transOp.res().getType()),
190         adaptor.matrix(), transOp.rows(), transOp.columns());
191     return success();
192   }
193 };
194 
195 /// Overloaded utility that replaces a vector.load, vector.store,
196 /// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Unmasked load: a plain aligned llvm.load of the whole vector.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}
204 
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Masked load: forwards mask and pass-through vector to the LLVM intrinsic.
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}
212 
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Unmasked store: a plain aligned llvm.store of the whole vector.
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}
220 
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Masked store: forwards the mask to the LLVM masked-store intrinsic.
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}
228 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore. Resolves alignment and address once, then dispatches
/// to the replaceLoadOrStoreOp overload matching the concrete op kind.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment: data-layout preferred alignment, possibly raised by
    // memref.assume_alignment on the base.
    unsigned align;
    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
                                    align)))
      return failure();

    // Resolve address: strided element pointer, cast to a pointer-to-vector
    // in the memref's address space.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};
265 
/// Conversion pattern for a vector.gather, lowered to llvm.intr.masked.gather
/// over a vector of pointers computed from the base and the index vector.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment (data layout + assume_alignment on the base).
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Resolve address: scalar base pointer, then a vector of pointers by
    // adding the index vector (needs unit stride and memory space 0).
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};
299 
/// Conversion pattern for a vector.scatter, lowered to
/// llvm.intr.masked.scatter over a vector of pointers computed from the base
/// and the index vector.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment (data layout + assume_alignment on the base).
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Resolve address: scalar base pointer, then a vector of pointers by
    // adding the index vector (needs unit stride and memory space 0).
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
333 
334 /// Conversion pattern for a vector.expandload.
335 class VectorExpandLoadOpConversion
336     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
337 public:
338   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
339 
340   LogicalResult
341   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
342                   ConversionPatternRewriter &rewriter) const override {
343     auto loc = expand->getLoc();
344     MemRefType memRefType = expand.getMemRefType();
345 
346     // Resolve address.
347     auto vtype = typeConverter->convertType(expand.getVectorType());
348     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
349                                      adaptor.indices(), rewriter);
350 
351     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
352         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
353     return success();
354   }
355 };
356 
357 /// Conversion pattern for a vector.compressstore.
358 class VectorCompressStoreOpConversion
359     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
360 public:
361   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
362 
363   LogicalResult
364   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
365                   ConversionPatternRewriter &rewriter) const override {
366     auto loc = compress->getLoc();
367     MemRefType memRefType = compress.getMemRefType();
368 
369     // Resolve address.
370     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
371                                      adaptor.indices(), rewriter);
372 
373     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
374         compress, adaptor.valueToStore(), ptr, adaptor.mask());
375     return success();
376   }
377 };
378 
/// Conversion pattern for all vector reductions.
/// Dispatches on the reduction `kind` string attribute and the element type:
/// integer/index kinds map to llvm.intr.vector.reduce.* intrinsics, float
/// kinds to their f-variants (fadd/fmul additionally honor the
/// `reassociateFPReductions` flag and an optional accumulator operand).
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    // The result (destination) type is the scalar element type.
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "minui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == "minsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == "maxui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == "maxsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        // Unknown integer reduction kind: leave the op for other patterns.
        return failure();
      return success();
    }

    // Neither integer nor float element type: cannot lower here.
    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf")
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == "maxf")
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  // Whether llvm.intr.vector.reduce.fadd/fmul may reassociate (faster, but
  // not bit-exact with the sequential order).
  const bool reassociateFPReductions;
};
471 
/// Conversion pattern for a vector.shuffle. Uses llvm.shufflevector directly
/// when both rank-1 operands have the same type; otherwise expands into
/// per-position extract/insert pairs driven by the mask attribute.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // The element type is that of an LLVM array for rank > 1, or of the
    // converted LLVM vector otherwise.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask positions >= v1Dim index into the second operand.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
530 
531 class VectorExtractElementOpConversion
532     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
533 public:
534   using ConvertOpToLLVMPattern<
535       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
536 
537   LogicalResult
538   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
539                   ConversionPatternRewriter &rewriter) const override {
540     auto vectorType = extractEltOp.getVectorType();
541     auto llvmType = typeConverter->convertType(vectorType.getElementType());
542 
543     // Bail if result type cannot be lowered.
544     if (!llvmType)
545       return failure();
546 
547     if (vectorType.getRank() == 0) {
548       Location loc = extractEltOp.getLoc();
549       auto idxType = rewriter.getIndexType();
550       auto zero = rewriter.create<LLVM::ConstantOp>(
551           loc, typeConverter->convertType(idxType),
552           rewriter.getIntegerAttr(idxType, 0));
553       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
554           extractEltOp, llvmType, adaptor.vector(), zero);
555       return success();
556     }
557 
558     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
559         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
560     return success();
561   }
562 };
563 
/// Conversion pattern for a vector.extract. Vector-typed results need a
/// single llvm.extractvalue; scalar results from a rank > 1 source first peel
/// the leading dimensions with llvm.extractvalue and then use
/// llvm.extractelement on the remaining 1-D vector.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position index drill into the aggregate, leaving a
    // 1-D vector for the final element extraction below.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
620 
621 /// Conversion pattern that turns a vector.fma on a 1-D vector
622 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
623 /// This does not match vectors of n >= 2 rank.
624 ///
625 /// Example:
626 /// ```
627 ///  vector.fma %a, %a, %a : vector<8xf32>
628 /// ```
629 /// is converted to:
630 /// ```
631 ///  llvm.intr.fmuladd %va, %va, %va:
632 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
633 ///    -> !llvm."<8 x f32>">
634 /// ```
635 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
636 public:
637   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
638 
639   LogicalResult
640   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
641                   ConversionPatternRewriter &rewriter) const override {
642     VectorType vType = fmaOp.getVectorType();
643     if (vType.getRank() != 1)
644       return failure();
645     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
646                                                  adaptor.rhs(), adaptor.acc());
647     return success();
648   }
649 };
650 
651 class VectorInsertElementOpConversion
652     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
653 public:
654   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
655 
656   LogicalResult
657   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
658                   ConversionPatternRewriter &rewriter) const override {
659     auto vectorType = insertEltOp.getDestVectorType();
660     auto llvmType = typeConverter->convertType(vectorType);
661 
662     // Bail if result type cannot be lowered.
663     if (!llvmType)
664       return failure();
665 
666     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
667         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
668         adaptor.position());
669     return success();
670   }
671 };
672 
/// Conversion pattern for a vector.insert. Vector-typed sources need a single
/// llvm.insertvalue; scalar sources into a rank > 1 destination first extract
/// the target 1-D vector, insert the element, and write the 1-D vector back
/// into the aggregate.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // For rank > 1 destinations, pull out the innermost 1-D vector addressed
    // by all but the last position index.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    // Write the updated 1-D vector back into the original aggregate.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
742 
743 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
744 ///
745 /// Example:
746 /// ```
747 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
748 /// ```
749 /// is rewritten into:
750 /// ```
751 ///  %r = splat %f0: vector<2x4xf32>
752 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
753 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
754 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
755 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
756 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
757 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
758 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
759 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
760 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
761 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
762 ///  // %r3 holds the final value.
763 /// ```
764 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
765 public:
766   using OpRewritePattern<FMAOp>::OpRewritePattern;
767 
768   void initialize() {
769     // This pattern recursively unpacks one dimension at a time. The recursion
770     // bounded as the rank is strictly decreasing.
771     setHasBoundedRewriteRecursion();
772   }
773 
774   LogicalResult matchAndRewrite(FMAOp op,
775                                 PatternRewriter &rewriter) const override {
776     auto vType = op.getVectorType();
777     if (vType.getRank() < 2)
778       return failure();
779 
780     auto loc = op.getLoc();
781     auto elemType = vType.getElementType();
782     Value zero = rewriter.create<arith::ConstantOp>(
783         loc, elemType, rewriter.getZeroAttr(elemType));
784     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
785     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
786       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
787       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
788       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
789       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
790       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
791     }
792     rewriter.replaceOp(op, desc);
793     return success();
794   }
795 };
796 
797 /// Returns the strides if the memory underlying `memRefType` has a contiguous
798 /// static layout.
799 static llvm::Optional<SmallVector<int64_t, 4>>
800 computeContiguousStrides(MemRefType memRefType) {
801   int64_t offset;
802   SmallVector<int64_t, 4> strides;
803   if (failed(getStridesAndOffset(memRefType, strides, offset)))
804     return None;
805   if (!strides.empty() && strides.back() != 1)
806     return None;
807   // If no layout or identity layout, this is contiguous by definition.
808   if (memRefType.getLayout().isIdentity())
809     return strides;
810 
811   // Otherwise, we must determine contiguity form shapes. This can only ever
812   // work in static cases because MemRefType is underspecified to represent
813   // contiguous dynamic shapes in other ways than with just empty/identity
814   // layout.
815   auto sizes = memRefType.getShape();
816   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
817     if (ShapedType::isDynamic(sizes[index + 1]) ||
818         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
819         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
820       return None;
821     if (strides[index] != strides[index + 1] * sizes[index + 1])
822       return None;
823   }
824   return strides;
825 }
826 
/// Conversion of vector.type_cast into a memref-descriptor rewrite: a new
/// descriptor for the target memref type is built that reuses the source's
/// allocated/aligned pointers (bitcast to the target element type), a zero
/// offset, and the target type's static sizes and strides. Restricted to
/// static shapes and contiguous, statically strided buffers.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The source operand must already have been lowered to an LLVM memref
    // descriptor (a struct); otherwise this pattern does not apply yet.
    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr, bitcast to the target element pointer type.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr, likewise bitcast.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref, one constant per
    // dimension taken from the target type's static shape and strides.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
905 
/// Lowers vector.print to calls into a small runtime support library,
/// fully unrolling the vector into elementary per-element print calls.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail if the operand type cannot be lowered.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support: pick the runtime print
    // function matching the element type, and record whether the operand
    // needs a zero/sign extension before the call.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Integers wider than 64 bits have no runtime support.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Integers wider than 64 bits have no runtime support.
          return failure();
        }
      }
    } else {
      // Element type has no runtime print support.
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to an integer operand before the runtime call.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  /// Recursively unrolls `value` (of rank `rank`) into elementary print
  /// calls: rank 0 prints a single (possibly extended) scalar, higher ranks
  /// print "( elem, elem, ... )" by recursing one dimension at a time.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      // Scalar case: extend integers as required, then call the printer.
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Extract the d-th sub-vector (or element when rank is 1) and recurse.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1046 
1047 } // namespace
1048 
1049 /// Populate the given list with patterns that convert from Vector to LLVM.
1050 void mlir::populateVectorToLLVMConversionPatterns(
1051     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1052     bool reassociateFPReductions) {
1053   MLIRContext *ctx = converter.getDialect()->getContext();
1054   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1055   populateVectorInsertExtractStridedSliceTransforms(patterns);
1056   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1057   patterns
1058       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1059            VectorExtractElementOpConversion, VectorExtractOpConversion,
1060            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1061            VectorInsertOpConversion, VectorPrintOpConversion,
1062            VectorTypeCastOpConversion,
1063            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1064            VectorLoadStoreConversion<vector::MaskedLoadOp,
1065                                      vector::MaskedLoadOpAdaptor>,
1066            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1067            VectorLoadStoreConversion<vector::MaskedStoreOp,
1068                                      vector::MaskedStoreOpAdaptor>,
1069            VectorGatherOpConversion, VectorScatterOpConversion,
1070            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1071           converter);
1072   // Transfer ops with rank > 1 are handled by VectorToSCF.
1073   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1074 }
1075 
1076 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1077     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1078   patterns.add<VectorMatmulOpConversion>(converter);
1079   patterns.add<VectorFlatTransposeOpConversion>(converter);
1080 }
1081