xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision cc311a155aa9e3d7ba67ec6d65948952a314c307)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Support/MathExtras.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28   assert((tp.getRank() > 1) && "unlowerable vector type");
29   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
30 }
31 
32 // Helper to reduce vector type by *all* but one rank at back.
33 static VectorType reducedVectorTypeBack(VectorType tp) {
34   assert((tp.getRank() > 1) && "unlowerable vector type");
35   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
36 }
37 
38 // Helper that picks the proper sequence for inserting.
39 static Value insertOne(ConversionPatternRewriter &rewriter,
40                        LLVMTypeConverter &typeConverter, Location loc,
41                        Value val1, Value val2, Type llvmType, int64_t rank,
42                        int64_t pos) {
43   assert(rank > 0 && "0-D vector corner case should have been handled already");
44   if (rank == 1) {
45     auto idxType = rewriter.getIndexType();
46     auto constant = rewriter.create<LLVM::ConstantOp>(
47         loc, typeConverter.convertType(idxType),
48         rewriter.getIntegerAttr(idxType, pos));
49     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50                                                   constant);
51   }
52   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
53                                               rewriter.getI64ArrayAttr(pos));
54 }
55 
56 // Helper that picks the proper sequence for extracting.
57 static Value extractOne(ConversionPatternRewriter &rewriter,
58                         LLVMTypeConverter &typeConverter, Location loc,
59                         Value val, Type llvmType, int64_t rank, int64_t pos) {
60   if (rank <= 1) {
61     auto idxType = rewriter.getIndexType();
62     auto constant = rewriter.create<LLVM::ConstantOp>(
63         loc, typeConverter.convertType(idxType),
64         rewriter.getIntegerAttr(idxType, pos));
65     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
66                                                    constant);
67   }
68   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
69                                                rewriter.getI64ArrayAttr(pos));
70 }
71 
// Helper that returns the data-layout preferred alignment (in bytes) of the
// element type of `memrefType`, written into `align`. Fails when the element
// type has no LLVM equivalent.
// NOTE(review): not marked `static`, unlike the other file-local helpers —
// presumably intentional (declared elsewhere?); confirm before tightening.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  // A throwaway llvm::LLVMContext is created just to query the alignment via
  // the LLVM IR translation layer.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}
86 
87 // Return the minimal alignment value that satisfies all the AssumeAlignment
88 // uses of `value`. If no such uses exist, return 1.
89 static unsigned getAssumedAlignment(Value value) {
90   unsigned align = 1;
91   for (auto &u : value.getUses()) {
92     Operation *owner = u.getOwner();
93     if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
94       align = mlir::lcm(align, op.alignment());
95   }
96   return align;
97 }
98 
// Helper that returns data layout alignment of a memref associated with a
// load, store, scatter, or gather op, including additional information from
// assume_alignment calls on the source of the transfer.
//
// `OpAdaptor` must expose getMemRefType() and base() (satisfied by the
// vector load/store/gather/scatter ops handled in this file). The result is
// the max of the data-layout alignment and any assume_alignment promises.
template <class OpAdaptor>
LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
                                   OpAdaptor op, unsigned &align) {
  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
    return failure();
  align = std::max(align, getAssumedAlignment(op.base()));
  return success();
}
110 
// Add an index vector component to a base pointer. This almost always succeeds
// unless the last stride is non-unit or the memory space is not zero.
//
// On success, `ptrs` holds an LLVM GEP producing a vector of pointers (one
// lane per element of `vType`'s leading dimension), suitable as the address
// operand of masked gather/scatter intrinsics.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  // Only unit innermost stride in the default memory space is supported.
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}
128 
129 // Casts a strided element pointer to a vector pointer.  The vector pointer
130 // will be in the same address space as the incoming memref type.
131 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
132                          Value ptr, MemRefType memRefType, Type vt) {
133   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
134   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
135 }
136 
137 namespace {
138 
139 /// Conversion pattern for a vector.bitcast.
140 class VectorBitCastOpConversion
141     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
142 public:
143   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
144 
145   LogicalResult
146   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
147                   ConversionPatternRewriter &rewriter) const override {
148     // Only 1-D vectors can be lowered to LLVM.
149     VectorType resultTy = bitCastOp.getType();
150     if (resultTy.getRank() != 1)
151       return failure();
152     Type newResultTy = typeConverter->convertType(resultTy);
153     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
154                                                  adaptor.getOperands()[0]);
155     return success();
156   }
157 };
158 
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply: the
/// row/column attributes are forwarded verbatim and the result type is
/// converted through the type converter.
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};
176 
/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose: the
/// rows/columns attributes are forwarded verbatim.
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};
193 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts. This overload: vector.load -> llvm.load with explicit
/// alignment.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}
203 
/// Overload: vector.maskedload -> llvm.intr.masked.load, forwarding the mask
/// and pass-through operands with explicit alignment.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}
211 
/// Overload: vector.store -> llvm.store with explicit alignment.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}
219 
/// Overload: vector.maskedstore -> llvm.intr.masked.store, forwarding the
/// mask with explicit alignment.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}
227 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore. Shared logic (alignment + address computation) lives
/// here; the op-specific replacement is dispatched through the overloaded
/// replaceLoadOrStoreOp helpers above.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment: data-layout preferred alignment, possibly raised by
    // memref.assume_alignment ops on the base memref.
    unsigned align;
    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
                                    align)))
      return failure();

    // Resolve address: strided element pointer, then cast to a pointer to
    // the converted vector type.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    // Dispatch to the op-specific LLVM replacement.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};
264 
/// Conversion pattern for a vector.gather, lowered to llvm.intr.masked.gather
/// on a vector of per-lane pointers.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment (data layout, plus assume_alignment promises).
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
      return failure();

    // Resolve address: base element pointer expanded by the index vector
    // into one pointer per gathered lane. Fails for non-unit innermost
    // stride or non-zero memory space.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};
298 
/// Conversion pattern for a vector.scatter, lowered to
/// llvm.intr.masked.scatter on a vector of per-lane pointers (mirrors the
/// gather lowering above).
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment (data layout, plus assume_alignment promises).
    unsigned align;
    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
      return failure();

    // Resolve address: one pointer per scattered lane; fails for non-unit
    // innermost stride or non-zero memory space.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
332 
/// Conversion pattern for a vector.expandload, lowered to
/// llvm.intr.masked.expandload. Note: unlike gather/load above, this
/// intrinsic takes no alignment operand, so none is resolved here.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address of the first element to load from.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};
355 
/// Conversion pattern for a vector.compressstore, lowered to
/// llvm.intr.masked.compressstore (no alignment operand on this intrinsic).
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address of the first element to store to.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};
377 
/// Conversion pattern for all vector reductions, dispatching on the op's
/// string `kind` attribute to the matching llvm.intr.vector.reduce.*
/// intrinsic. Integer kinds: add/mul/minui/minsi/maxui/maxsi/and/or/xor.
/// Float kinds: add/mul (with optional accumulator and a reassociation
/// flag) and minf/maxf. Any other kind fails the match.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  /// `reassociateFPRed` is forwarded as the reassociation flag on fadd/fmul
  /// reductions (trading exact FP semantics for speed when set).
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "minui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == "minsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == "maxui")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == "maxsi")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one, the multiplicative identity).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "minf")
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == "maxf")
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  // Whether fadd/fmul reductions may be reassociated (set at pattern
  // construction time from the lowering options).
  const bool reassociateFPReductions;
};
470 
/// Conversion pattern for a vector.shuffle. Rank-1 shuffles with identical
/// operand types map directly to llvm.shufflevector; all other cases are
/// expanded into per-element extract/insert sequences.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries < v1Dim select from v1; the rest select from v2 (after
      // rebasing the index past v1's leading dimension).
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
529 
/// Conversion pattern for a vector.extractelement, lowered to
/// llvm.extractelement. The 0-D vector case synthesizes a constant index 0
/// since the source op carries no position operand there.
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vector: extract the sole element at index 0.
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.vector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};
562 
/// Conversion pattern for a vector.extract with a static position. Vector
/// results become a single llvm.extractvalue; scalar results peel off the
/// leading positions with extractvalue (if any) and finish with an
/// llvm.extractelement on the remaining 1-D vector.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.vector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all but the last
    // position index into the aggregate, leaving a 1-D vector.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector, using the last
    // position as an i64 constant index.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
619 
/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank (those are expected to be
/// rank-reduced first; see the n-D FMA rewrite below).
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm."<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    // Only 1-D vectors map directly onto the intrinsic.
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};
649 
/// Conversion pattern for a vector.insertelement, lowered to
/// llvm.insertelement. The 0-D vector case synthesizes a constant index 0
/// since the source op carries no position operand there.
class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      // 0-D vector: insert at the sole position, index 0.
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};
682 
/// Conversion pattern for a vector.insert with a static position. Vector
/// sources become a single llvm.insertvalue; scalar sources extract the
/// target 1-D vector (if nested), insertelement into it, then (if nested)
/// insertvalue the updated 1-D vector back into the aggregate.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.source());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array: all but the last
    // position index into the aggregate to reach the target 1-D vector.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector, using the last
    // position as an i64 constant index.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array: put the
    // updated inner vector back at the same n-1 prefix position.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
752 
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
774 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
775 public:
776   using OpRewritePattern<FMAOp>::OpRewritePattern;
777 
778   void initialize() {
779     // This pattern recursively unpacks one dimension at a time. The recursion
780     // bounded as the rank is strictly decreasing.
781     setHasBoundedRewriteRecursion();
782   }
783 
784   LogicalResult matchAndRewrite(FMAOp op,
785                                 PatternRewriter &rewriter) const override {
786     auto vType = op.getVectorType();
787     if (vType.getRank() < 2)
788       return failure();
789 
790     auto loc = op.getLoc();
791     auto elemType = vType.getElementType();
792     Value zero = rewriter.create<arith::ConstantOp>(
793         loc, elemType, rewriter.getZeroAttr(elemType));
794     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
795     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
796       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
797       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
798       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
799       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
800       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
801     }
802     rewriter.replaceOp(op, desc);
803     return success();
804   }
805 };
806 
807 /// Returns the strides if the memory underlying `memRefType` has a contiguous
808 /// static layout.
809 static llvm::Optional<SmallVector<int64_t, 4>>
810 computeContiguousStrides(MemRefType memRefType) {
811   int64_t offset;
812   SmallVector<int64_t, 4> strides;
813   if (failed(getStridesAndOffset(memRefType, strides, offset)))
814     return None;
815   if (!strides.empty() && strides.back() != 1)
816     return None;
817   // If no layout or identity layout, this is contiguous by definition.
818   if (memRefType.getLayout().isIdentity())
819     return strides;
820 
821   // Otherwise, we must determine contiguity form shapes. This can only ever
822   // work in static cases because MemRefType is underspecified to represent
823   // contiguous dynamic shapes in other ways than with just empty/identity
824   // layout.
825   auto sizes = memRefType.getShape();
826   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
827     if (ShapedType::isDynamic(sizes[index + 1]) ||
828         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
829         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
830       return None;
831     if (strides[index] != strides[index + 1] * sizes[index + 1])
832       return None;
833   }
834   return strides;
835 }
836 
/// Conversion of vector.type_cast between statically shaped, contiguous
/// memrefs: rebuilds the target memref descriptor from the source one,
/// bitcasting the pointers and filling in static sizes/strides and a zero
/// offset.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The source must already have been lowered to an LLVM struct descriptor.
    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    // Bail if the target descriptor type cannot be lowered.
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    // Both are fully static here (checked above), so each entry becomes an
    // LLVM constant.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
915 
/// Conversion of vector.print into calls to a small runtime print library
/// (one call per scalar element, plus calls for brackets/commas/newline).
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail if the printed type cannot be lowered to LLVM.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    // Selects the runtime print function and, for integers, the widening
    // conversion needed to feed the 64-bit runtime entry points.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // Index values are printed through the unsigned 64-bit entry point.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      // No runtime support for other element types (e.g. f16, bf16).
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Widening applied to an integer scalar before handing it to the 64-bit
  // runtime print functions.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively emits print calls for `value` of type `type`, peeling one
  // vector dimension per recursion level and emitting open/comma/close
  // punctuation around the elements. A null vector type (rank 0) is the
  // scalar base case: widen per `conversion` and call `printer`.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      // Innermost dimension: print the scalar elements one by one.
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Outer dimension: recurse on each rank-reduced sub-vector.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1076 
1077 } // namespace
1078 
1079 /// Populate the given list with patterns that convert from Vector to LLVM.
1080 void mlir::populateVectorToLLVMConversionPatterns(
1081     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1082     bool reassociateFPReductions) {
1083   MLIRContext *ctx = converter.getDialect()->getContext();
1084   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1085   populateVectorInsertExtractStridedSliceTransforms(patterns);
1086   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1087   patterns
1088       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1089            VectorExtractElementOpConversion, VectorExtractOpConversion,
1090            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1091            VectorInsertOpConversion, VectorPrintOpConversion,
1092            VectorTypeCastOpConversion,
1093            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1094            VectorLoadStoreConversion<vector::MaskedLoadOp,
1095                                      vector::MaskedLoadOpAdaptor>,
1096            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1097            VectorLoadStoreConversion<vector::MaskedStoreOp,
1098                                      vector::MaskedStoreOpAdaptor>,
1099            VectorGatherOpConversion, VectorScatterOpConversion,
1100            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
1101           converter);
1102   // Transfer ops with rank > 1 are handled by VectorToSCF.
1103   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1104 }
1105 
1106 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1107     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1108   patterns.add<VectorMatmulOpConversion>(converter);
1109   patterns.add<VectorFlatTransposeOpConversion>(converter);
1110 }
1111