xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 1f971e23f089c640d5a7df1e78572fe4d8bb1d0b)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/Support/MathExtras.h"
19 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 
25 // Helper to reduce vector type by one rank at front.
26 static VectorType reducedVectorTypeFront(VectorType tp) {
27   assert((tp.getRank() > 1) && "unlowerable vector type");
28   unsigned numScalableDims = tp.getNumScalableDims();
29   if (tp.getShape().size() == numScalableDims)
30     --numScalableDims;
31   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
32                          numScalableDims);
33 }
34 
35 // Helper to reduce vector type by *all* but one rank at back.
36 static VectorType reducedVectorTypeBack(VectorType tp) {
37   assert((tp.getRank() > 1) && "unlowerable vector type");
38   unsigned numScalableDims = tp.getNumScalableDims();
39   if (numScalableDims > 0)
40     --numScalableDims;
41   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
42                          numScalableDims);
43 }
44 
45 // Helper that picks the proper sequence for inserting.
46 static Value insertOne(ConversionPatternRewriter &rewriter,
47                        LLVMTypeConverter &typeConverter, Location loc,
48                        Value val1, Value val2, Type llvmType, int64_t rank,
49                        int64_t pos) {
50   assert(rank > 0 && "0-D vector corner case should have been handled already");
51   if (rank == 1) {
52     auto idxType = rewriter.getIndexType();
53     auto constant = rewriter.create<LLVM::ConstantOp>(
54         loc, typeConverter.convertType(idxType),
55         rewriter.getIntegerAttr(idxType, pos));
56     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
57                                                   constant);
58   }
59   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
60                                               rewriter.getI64ArrayAttr(pos));
61 }
62 
63 // Helper that picks the proper sequence for extracting.
64 static Value extractOne(ConversionPatternRewriter &rewriter,
65                         LLVMTypeConverter &typeConverter, Location loc,
66                         Value val, Type llvmType, int64_t rank, int64_t pos) {
67   if (rank <= 1) {
68     auto idxType = rewriter.getIndexType();
69     auto constant = rewriter.create<LLVM::ConstantOp>(
70         loc, typeConverter.convertType(idxType),
71         rewriter.getIntegerAttr(idxType, pos));
72     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
73                                                    constant);
74   }
75   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
76                                                rewriter.getI64ArrayAttr(pos));
77 }
78 
79 // Helper that returns data layout alignment of a memref.
80 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
81                                  MemRefType memrefType, unsigned &align) {
82   Type elementTy = typeConverter.convertType(memrefType.getElementType());
83   if (!elementTy)
84     return failure();
85 
86   // TODO: this should use the MLIR data layout when it becomes available and
87   // stop depending on translation.
88   llvm::LLVMContext llvmContext;
89   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
90               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
91   return success();
92 }
93 
94 // Add an index vector component to a base pointer. This almost always succeeds
95 // unless the last stride is non-unit or the memory space is not zero.
96 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
97                                     Location loc, Value memref, Value base,
98                                     Value index, MemRefType memRefType,
99                                     VectorType vType, Value &ptrs) {
100   int64_t offset;
101   SmallVector<int64_t, 4> strides;
102   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
103   if (failed(successStrides) || strides.back() != 1 ||
104       memRefType.getMemorySpaceAsInt() != 0)
105     return failure();
106   auto pType = MemRefDescriptor(memref).getElementPtrType();
107   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
108   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
109   return success();
110 }
111 
112 // Casts a strided element pointer to a vector pointer.  The vector pointer
113 // will be in the same address space as the incoming memref type.
114 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
115                          Value ptr, MemRefType memRefType, Type vt) {
116   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
117   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
118 }
119 
120 namespace {
121 
122 /// Trivial Vector to LLVM conversions
123 using VectorScaleOpConversion =
124     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
125 
126 /// Conversion pattern for a vector.bitcast.
127 class VectorBitCastOpConversion
128     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
129 public:
130   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
131 
132   LogicalResult
133   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
134                   ConversionPatternRewriter &rewriter) const override {
135     // Only 0-D and 1-D vectors can be lowered to LLVM.
136     VectorType resultTy = bitCastOp.getResultVectorType();
137     if (resultTy.getRank() > 1)
138       return failure();
139     Type newResultTy = typeConverter->convertType(resultTy);
140     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
141                                                  adaptor.getOperands()[0]);
142     return success();
143   }
144 };
145 
146 /// Conversion pattern for a vector.matrix_multiply.
147 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
148 class VectorMatmulOpConversion
149     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
150 public:
151   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
152 
153   LogicalResult
154   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
155                   ConversionPatternRewriter &rewriter) const override {
156     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
157         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
158         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
159         matmulOp.lhs_columns(), matmulOp.rhs_columns());
160     return success();
161   }
162 };
163 
164 /// Conversion pattern for a vector.flat_transpose.
165 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
166 class VectorFlatTransposeOpConversion
167     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
168 public:
169   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
170 
171   LogicalResult
172   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
173                   ConversionPatternRewriter &rewriter) const override {
174     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
175         transOp, typeConverter->convertType(transOp.res().getType()),
176         adaptor.matrix(), transOp.rows(), transOp.columns());
177     return success();
178   }
179 };
180 
181 /// Overloaded utility that replaces a vector.load, vector.store,
182 /// vector.maskedload and vector.maskedstore with their respective LLVM
183 /// couterparts.
184 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
185                                  vector::LoadOpAdaptor adaptor,
186                                  VectorType vectorTy, Value ptr, unsigned align,
187                                  ConversionPatternRewriter &rewriter) {
188   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
189 }
190 
191 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
192                                  vector::MaskedLoadOpAdaptor adaptor,
193                                  VectorType vectorTy, Value ptr, unsigned align,
194                                  ConversionPatternRewriter &rewriter) {
195   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
196       loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
197 }
198 
199 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
200                                  vector::StoreOpAdaptor adaptor,
201                                  VectorType vectorTy, Value ptr, unsigned align,
202                                  ConversionPatternRewriter &rewriter) {
203   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
204                                              ptr, align);
205 }
206 
207 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
208                                  vector::MaskedStoreOpAdaptor adaptor,
209                                  VectorType vectorTy, Value ptr, unsigned align,
210                                  ConversionPatternRewriter &rewriter) {
211   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
212       storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
213 }
214 
215 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
216 /// vector.maskedstore.
217 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
218 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
219 public:
220   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
221 
222   LogicalResult
223   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
224                   typename LoadOrStoreOp::Adaptor adaptor,
225                   ConversionPatternRewriter &rewriter) const override {
226     // Only 1-D vectors can be lowered to LLVM.
227     VectorType vectorTy = loadOrStoreOp.getVectorType();
228     if (vectorTy.getRank() > 1)
229       return failure();
230 
231     auto loc = loadOrStoreOp->getLoc();
232     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
233 
234     // Resolve alignment.
235     unsigned align;
236     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
237       return failure();
238 
239     // Resolve address.
240     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
241                      .template cast<VectorType>();
242     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
243                                                adaptor.indices(), rewriter);
244     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
245 
246     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
247     return success();
248   }
249 };
250 
251 /// Conversion pattern for a vector.gather.
252 class VectorGatherOpConversion
253     : public ConvertOpToLLVMPattern<vector::GatherOp> {
254 public:
255   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
256 
257   LogicalResult
258   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
259                   ConversionPatternRewriter &rewriter) const override {
260     auto loc = gather->getLoc();
261     MemRefType memRefType = gather.getMemRefType();
262 
263     // Resolve alignment.
264     unsigned align;
265     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
266       return failure();
267 
268     // Resolve address.
269     Value ptrs;
270     VectorType vType = gather.getVectorType();
271     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
272                                      adaptor.indices(), rewriter);
273     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
274                               adaptor.index_vec(), memRefType, vType, ptrs)))
275       return failure();
276 
277     // Replace with the gather intrinsic.
278     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
279         gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
280         adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
281     return success();
282   }
283 };
284 
285 /// Conversion pattern for a vector.scatter.
286 class VectorScatterOpConversion
287     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
288 public:
289   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
290 
291   LogicalResult
292   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
293                   ConversionPatternRewriter &rewriter) const override {
294     auto loc = scatter->getLoc();
295     MemRefType memRefType = scatter.getMemRefType();
296 
297     // Resolve alignment.
298     unsigned align;
299     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
300       return failure();
301 
302     // Resolve address.
303     Value ptrs;
304     VectorType vType = scatter.getVectorType();
305     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
306                                      adaptor.indices(), rewriter);
307     if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
308                               adaptor.index_vec(), memRefType, vType, ptrs)))
309       return failure();
310 
311     // Replace with the scatter intrinsic.
312     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
313         scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
314         rewriter.getI32IntegerAttr(align));
315     return success();
316   }
317 };
318 
319 /// Conversion pattern for a vector.expandload.
320 class VectorExpandLoadOpConversion
321     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
322 public:
323   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
324 
325   LogicalResult
326   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
327                   ConversionPatternRewriter &rewriter) const override {
328     auto loc = expand->getLoc();
329     MemRefType memRefType = expand.getMemRefType();
330 
331     // Resolve address.
332     auto vtype = typeConverter->convertType(expand.getVectorType());
333     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
334                                      adaptor.indices(), rewriter);
335 
336     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
337         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
338     return success();
339   }
340 };
341 
342 /// Conversion pattern for a vector.compressstore.
343 class VectorCompressStoreOpConversion
344     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
345 public:
346   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
347 
348   LogicalResult
349   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
350                   ConversionPatternRewriter &rewriter) const override {
351     auto loc = compress->getLoc();
352     MemRefType memRefType = compress.getMemRefType();
353 
354     // Resolve address.
355     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
356                                      adaptor.indices(), rewriter);
357 
358     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
359         compress, adaptor.valueToStore(), ptr, adaptor.mask());
360     return success();
361   }
362 };
363 
364 /// Conversion pattern for all vector reductions.
365 class VectorReductionOpConversion
366     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
367 public:
368   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
369                                        bool reassociateFPRed)
370       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
371         reassociateFPReductions(reassociateFPRed) {}
372 
373   LogicalResult
374   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
375                   ConversionPatternRewriter &rewriter) const override {
376     auto kind = reductionOp.kind();
377     Type eltType = reductionOp.dest().getType();
378     Type llvmType = typeConverter->convertType(eltType);
379     Value operand = adaptor.getOperands()[0];
380     if (eltType.isIntOrIndex()) {
381       // Integer reductions: add/mul/min/max/and/or/xor.
382       if (kind == vector::CombiningKind::ADD)
383         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
384                                                              llvmType, operand);
385       else if (kind == vector::CombiningKind::MUL)
386         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
387                                                              llvmType, operand);
388       else if (kind == vector::CombiningKind::MINUI)
389         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
390             reductionOp, llvmType, operand);
391       else if (kind == vector::CombiningKind::MINSI)
392         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
393             reductionOp, llvmType, operand);
394       else if (kind == vector::CombiningKind::MAXUI)
395         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
396             reductionOp, llvmType, operand);
397       else if (kind == vector::CombiningKind::MAXSI)
398         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
399             reductionOp, llvmType, operand);
400       else if (kind == vector::CombiningKind::AND)
401         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
402                                                              llvmType, operand);
403       else if (kind == vector::CombiningKind::OR)
404         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
405                                                             llvmType, operand);
406       else if (kind == vector::CombiningKind::XOR)
407         rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
408                                                              llvmType, operand);
409       else
410         return failure();
411       return success();
412     }
413 
414     if (!eltType.isa<FloatType>())
415       return failure();
416 
417     // Floating-point reductions: add/mul/min/max
418     if (kind == vector::CombiningKind::ADD) {
419       // Optional accumulator (or zero).
420       Value acc = adaptor.getOperands().size() > 1
421                       ? adaptor.getOperands()[1]
422                       : rewriter.create<LLVM::ConstantOp>(
423                             reductionOp->getLoc(), llvmType,
424                             rewriter.getZeroAttr(eltType));
425       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
426           reductionOp, llvmType, acc, operand,
427           rewriter.getBoolAttr(reassociateFPReductions));
428     } else if (kind == vector::CombiningKind::MUL) {
429       // Optional accumulator (or one).
430       Value acc = adaptor.getOperands().size() > 1
431                       ? adaptor.getOperands()[1]
432                       : rewriter.create<LLVM::ConstantOp>(
433                             reductionOp->getLoc(), llvmType,
434                             rewriter.getFloatAttr(eltType, 1.0));
435       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
436           reductionOp, llvmType, acc, operand,
437           rewriter.getBoolAttr(reassociateFPReductions));
438     } else if (kind == vector::CombiningKind::MINF)
439       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
440       // NaNs/-0.0/+0.0 in the same way.
441       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
442                                                             llvmType, operand);
443     else if (kind == vector::CombiningKind::MAXF)
444       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
445       // NaNs/-0.0/+0.0 in the same way.
446       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
447                                                             llvmType, operand);
448     else
449       return failure();
450     return success();
451   }
452 
453 private:
454   const bool reassociateFPReductions;
455 };
456 
457 class VectorShuffleOpConversion
458     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
459 public:
460   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
461 
462   LogicalResult
463   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
464                   ConversionPatternRewriter &rewriter) const override {
465     auto loc = shuffleOp->getLoc();
466     auto v1Type = shuffleOp.getV1VectorType();
467     auto v2Type = shuffleOp.getV2VectorType();
468     auto vectorType = shuffleOp.getVectorType();
469     Type llvmType = typeConverter->convertType(vectorType);
470     auto maskArrayAttr = shuffleOp.mask();
471 
472     // Bail if result type cannot be lowered.
473     if (!llvmType)
474       return failure();
475 
476     // Get rank and dimension sizes.
477     int64_t rank = vectorType.getRank();
478     assert(v1Type.getRank() == rank);
479     assert(v2Type.getRank() == rank);
480     int64_t v1Dim = v1Type.getDimSize(0);
481 
482     // For rank 1, where both operands have *exactly* the same vector type,
483     // there is direct shuffle support in LLVM. Use it!
484     if (rank == 1 && v1Type == v2Type) {
485       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
486           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
487       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
488       return success();
489     }
490 
491     // For all other cases, insert the individual values individually.
492     Type eltType;
493     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
494       eltType = arrayType.getElementType();
495     else
496       eltType = llvmType.cast<VectorType>().getElementType();
497     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
498     int64_t insPos = 0;
499     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
500       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
501       Value value = adaptor.v1();
502       if (extPos >= v1Dim) {
503         extPos -= v1Dim;
504         value = adaptor.v2();
505       }
506       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
507                                  eltType, rank, extPos);
508       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
509                          llvmType, rank, insPos++);
510     }
511     rewriter.replaceOp(shuffleOp, insert);
512     return success();
513   }
514 };
515 
516 class VectorExtractElementOpConversion
517     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
518 public:
519   using ConvertOpToLLVMPattern<
520       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
521 
522   LogicalResult
523   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
524                   ConversionPatternRewriter &rewriter) const override {
525     auto vectorType = extractEltOp.getVectorType();
526     auto llvmType = typeConverter->convertType(vectorType.getElementType());
527 
528     // Bail if result type cannot be lowered.
529     if (!llvmType)
530       return failure();
531 
532     if (vectorType.getRank() == 0) {
533       Location loc = extractEltOp.getLoc();
534       auto idxType = rewriter.getIndexType();
535       auto zero = rewriter.create<LLVM::ConstantOp>(
536           loc, typeConverter->convertType(idxType),
537           rewriter.getIntegerAttr(idxType, 0));
538       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
539           extractEltOp, llvmType, adaptor.vector(), zero);
540       return success();
541     }
542 
543     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
544         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
545     return success();
546   }
547 };
548 
549 class VectorExtractOpConversion
550     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
551 public:
552   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
553 
554   LogicalResult
555   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
556                   ConversionPatternRewriter &rewriter) const override {
557     auto loc = extractOp->getLoc();
558     auto vectorType = extractOp.getVectorType();
559     auto resultType = extractOp.getResult().getType();
560     auto llvmResultType = typeConverter->convertType(resultType);
561     auto positionArrayAttr = extractOp.position();
562 
563     // Bail if result type cannot be lowered.
564     if (!llvmResultType)
565       return failure();
566 
567     // Extract entire vector. Should be handled by folder, but just to be safe.
568     if (positionArrayAttr.empty()) {
569       rewriter.replaceOp(extractOp, adaptor.vector());
570       return success();
571     }
572 
573     // One-shot extraction of vector from array (only requires extractvalue).
574     if (resultType.isa<VectorType>()) {
575       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
576           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
577       rewriter.replaceOp(extractOp, extracted);
578       return success();
579     }
580 
581     // Potential extraction of 1-D vector from array.
582     auto *context = extractOp->getContext();
583     Value extracted = adaptor.vector();
584     auto positionAttrs = positionArrayAttr.getValue();
585     if (positionAttrs.size() > 1) {
586       auto oneDVectorType = reducedVectorTypeBack(vectorType);
587       auto nMinusOnePositionAttrs =
588           ArrayAttr::get(context, positionAttrs.drop_back());
589       extracted = rewriter.create<LLVM::ExtractValueOp>(
590           loc, typeConverter->convertType(oneDVectorType), extracted,
591           nMinusOnePositionAttrs);
592     }
593 
594     // Remaining extraction of element from 1-D LLVM vector
595     auto position = positionAttrs.back().cast<IntegerAttr>();
596     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
597     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
598     extracted =
599         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
600     rewriter.replaceOp(extractOp, extracted);
601 
602     return success();
603   }
604 };
605 
606 /// Conversion pattern that turns a vector.fma on a 1-D vector
607 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
608 /// This does not match vectors of n >= 2 rank.
609 ///
610 /// Example:
611 /// ```
612 ///  vector.fma %a, %a, %a : vector<8xf32>
613 /// ```
614 /// is converted to:
615 /// ```
616 ///  llvm.intr.fmuladd %va, %va, %va:
617 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
618 ///    -> !llvm."<8 x f32>">
619 /// ```
620 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
621 public:
622   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
623 
624   LogicalResult
625   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
626                   ConversionPatternRewriter &rewriter) const override {
627     VectorType vType = fmaOp.getVectorType();
628     if (vType.getRank() != 1)
629       return failure();
630     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
631                                                  adaptor.rhs(), adaptor.acc());
632     return success();
633   }
634 };
635 
636 class VectorInsertElementOpConversion
637     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
638 public:
639   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
640 
641   LogicalResult
642   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
643                   ConversionPatternRewriter &rewriter) const override {
644     auto vectorType = insertEltOp.getDestVectorType();
645     auto llvmType = typeConverter->convertType(vectorType);
646 
647     // Bail if result type cannot be lowered.
648     if (!llvmType)
649       return failure();
650 
651     if (vectorType.getRank() == 0) {
652       Location loc = insertEltOp.getLoc();
653       auto idxType = rewriter.getIndexType();
654       auto zero = rewriter.create<LLVM::ConstantOp>(
655           loc, typeConverter->convertType(idxType),
656           rewriter.getIntegerAttr(idxType, 0));
657       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
658           insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
659       return success();
660     }
661 
662     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
663         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
664         adaptor.position());
665     return success();
666   }
667 };
668 
669 class VectorInsertOpConversion
670     : public ConvertOpToLLVMPattern<vector::InsertOp> {
671 public:
672   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
673 
674   LogicalResult
675   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
676                   ConversionPatternRewriter &rewriter) const override {
677     auto loc = insertOp->getLoc();
678     auto sourceType = insertOp.getSourceType();
679     auto destVectorType = insertOp.getDestVectorType();
680     auto llvmResultType = typeConverter->convertType(destVectorType);
681     auto positionArrayAttr = insertOp.position();
682 
683     // Bail if result type cannot be lowered.
684     if (!llvmResultType)
685       return failure();
686 
687     // Overwrite entire vector with value. Should be handled by folder, but
688     // just to be safe.
689     if (positionArrayAttr.empty()) {
690       rewriter.replaceOp(insertOp, adaptor.source());
691       return success();
692     }
693 
694     // One-shot insertion of a vector into an array (only requires insertvalue).
695     if (sourceType.isa<VectorType>()) {
696       Value inserted = rewriter.create<LLVM::InsertValueOp>(
697           loc, llvmResultType, adaptor.dest(), adaptor.source(),
698           positionArrayAttr);
699       rewriter.replaceOp(insertOp, inserted);
700       return success();
701     }
702 
703     // Potential extraction of 1-D vector from array.
704     auto *context = insertOp->getContext();
705     Value extracted = adaptor.dest();
706     auto positionAttrs = positionArrayAttr.getValue();
707     auto position = positionAttrs.back().cast<IntegerAttr>();
708     auto oneDVectorType = destVectorType;
709     if (positionAttrs.size() > 1) {
710       oneDVectorType = reducedVectorTypeBack(destVectorType);
711       auto nMinusOnePositionAttrs =
712           ArrayAttr::get(context, positionAttrs.drop_back());
713       extracted = rewriter.create<LLVM::ExtractValueOp>(
714           loc, typeConverter->convertType(oneDVectorType), extracted,
715           nMinusOnePositionAttrs);
716     }
717 
718     // Insertion of an element into a 1-D LLVM vector.
719     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
720     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
721     Value inserted = rewriter.create<LLVM::InsertElementOp>(
722         loc, typeConverter->convertType(oneDVectorType), extracted,
723         adaptor.source(), constant);
724 
725     // Potential insertion of resulting 1-D vector into array.
726     if (positionAttrs.size() > 1) {
727       auto nMinusOnePositionAttrs =
728           ArrayAttr::get(context, positionAttrs.drop_back());
729       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
730                                                       adaptor.dest(), inserted,
731                                                       nMinusOnePositionAttrs);
732     }
733 
734     rewriter.replaceOp(insertOp, inserted);
735     return success();
736   }
737 };
738 
739 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
740 ///
741 /// Example:
742 /// ```
743 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
744 /// ```
745 /// is rewritten into:
746 /// ```
747 ///  %r = splat %f0: vector<2x4xf32>
748 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
749 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
750 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
751 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
752 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
753 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
754 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
755 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
756 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
757 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
758 ///  // %r3 holds the final value.
759 /// ```
760 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
761 public:
762   using OpRewritePattern<FMAOp>::OpRewritePattern;
763 
764   void initialize() {
765     // This pattern recursively unpacks one dimension at a time. The recursion
766     // bounded as the rank is strictly decreasing.
767     setHasBoundedRewriteRecursion();
768   }
769 
770   LogicalResult matchAndRewrite(FMAOp op,
771                                 PatternRewriter &rewriter) const override {
772     auto vType = op.getVectorType();
773     if (vType.getRank() < 2)
774       return failure();
775 
776     auto loc = op.getLoc();
777     auto elemType = vType.getElementType();
778     Value zero = rewriter.create<arith::ConstantOp>(
779         loc, elemType, rewriter.getZeroAttr(elemType));
780     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
781     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
782       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
783       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
784       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
785       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
786       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
787     }
788     rewriter.replaceOp(op, desc);
789     return success();
790   }
791 };
792 
793 /// Returns the strides if the memory underlying `memRefType` has a contiguous
794 /// static layout.
795 static llvm::Optional<SmallVector<int64_t, 4>>
796 computeContiguousStrides(MemRefType memRefType) {
797   int64_t offset;
798   SmallVector<int64_t, 4> strides;
799   if (failed(getStridesAndOffset(memRefType, strides, offset)))
800     return None;
801   if (!strides.empty() && strides.back() != 1)
802     return None;
803   // If no layout or identity layout, this is contiguous by definition.
804   if (memRefType.getLayout().isIdentity())
805     return strides;
806 
807   // Otherwise, we must determine contiguity form shapes. This can only ever
808   // work in static cases because MemRefType is underspecified to represent
809   // contiguous dynamic shapes in other ways than with just empty/identity
810   // layout.
811   auto sizes = memRefType.getShape();
812   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
813     if (ShapedType::isDynamic(sizes[index + 1]) ||
814         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
815         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
816       return None;
817     if (strides[index] != strides[index + 1] * sizes[index + 1])
818       return None;
819   }
820   return strides;
821 }
822 
823 class VectorTypeCastOpConversion
824     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
825 public:
826   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
827 
828   LogicalResult
829   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
830                   ConversionPatternRewriter &rewriter) const override {
831     auto loc = castOp->getLoc();
832     MemRefType sourceMemRefType =
833         castOp.getOperand().getType().cast<MemRefType>();
834     MemRefType targetMemRefType = castOp.getType();
835 
836     // Only static shape casts supported atm.
837     if (!sourceMemRefType.hasStaticShape() ||
838         !targetMemRefType.hasStaticShape())
839       return failure();
840 
841     auto llvmSourceDescriptorTy =
842         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
843     if (!llvmSourceDescriptorTy)
844       return failure();
845     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
846 
847     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
848                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
849     if (!llvmTargetDescriptorTy)
850       return failure();
851 
852     // Only contiguous source buffers supported atm.
853     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
854     if (!sourceStrides)
855       return failure();
856     auto targetStrides = computeContiguousStrides(targetMemRefType);
857     if (!targetStrides)
858       return failure();
859     // Only support static strides for now, regardless of contiguity.
860     if (llvm::any_of(*targetStrides, [](int64_t stride) {
861           return ShapedType::isDynamicStrideOrOffset(stride);
862         }))
863       return failure();
864 
865     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
866 
867     // Create descriptor.
868     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
869     Type llvmTargetElementTy = desc.getElementPtrType();
870     // Set allocated ptr.
871     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
872     allocated =
873         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
874     desc.setAllocatedPtr(rewriter, loc, allocated);
875     // Set aligned ptr.
876     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
877     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
878     desc.setAlignedPtr(rewriter, loc, ptr);
879     // Fill offset 0.
880     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
881     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
882     desc.setOffset(rewriter, loc, zero);
883 
884     // Fill size and stride descriptors in memref.
885     for (const auto &indexedSize :
886          llvm::enumerate(targetMemRefType.getShape())) {
887       int64_t index = indexedSize.index();
888       auto sizeAttr =
889           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
890       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
891       desc.setSize(rewriter, loc, index, size);
892       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
893                                                 (*targetStrides)[index]);
894       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
895       desc.setStride(rewriter, loc, index, stride);
896     }
897 
898     rewriter.replaceOp(castOp, {desc});
899     return success();
900   }
901 };
902 
903 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
904 public:
905   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
906 
907   // Proof-of-concept lowering implementation that relies on a small
908   // runtime support library, which only needs to provide a few
909   // printing methods (single value for all data types, opening/closing
910   // bracket, comma, newline). The lowering fully unrolls a vector
911   // in terms of these elementary printing operations. The advantage
912   // of this approach is that the library can remain unaware of all
913   // low-level implementation details of vectors while still supporting
914   // output of any shaped and dimensioned vector. Due to full unrolling,
915   // this approach is less suited for very large vectors though.
916   //
917   // TODO: rely solely on libc in future? something else?
918   //
919   LogicalResult
920   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
921                   ConversionPatternRewriter &rewriter) const override {
922     Type printType = printOp.getPrintType();
923 
924     if (typeConverter->convertType(printType) == nullptr)
925       return failure();
926 
927     // Make sure element type has runtime support.
928     PrintConversion conversion = PrintConversion::None;
929     VectorType vectorType = printType.dyn_cast<VectorType>();
930     Type eltType = vectorType ? vectorType.getElementType() : printType;
931     Operation *printer;
932     if (eltType.isF32()) {
933       printer =
934           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
935     } else if (eltType.isF64()) {
936       printer =
937           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
938     } else if (eltType.isIndex()) {
939       printer =
940           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
941     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
942       // Integers need a zero or sign extension on the operand
943       // (depending on the source type) as well as a signed or
944       // unsigned print method. Up to 64-bit is supported.
945       unsigned width = intTy.getWidth();
946       if (intTy.isUnsigned()) {
947         if (width <= 64) {
948           if (width < 64)
949             conversion = PrintConversion::ZeroExt64;
950           printer = LLVM::lookupOrCreatePrintU64Fn(
951               printOp->getParentOfType<ModuleOp>());
952         } else {
953           return failure();
954         }
955       } else {
956         assert(intTy.isSignless() || intTy.isSigned());
957         if (width <= 64) {
958           // Note that we *always* zero extend booleans (1-bit integers),
959           // so that true/false is printed as 1/0 rather than -1/0.
960           if (width == 1)
961             conversion = PrintConversion::ZeroExt64;
962           else if (width < 64)
963             conversion = PrintConversion::SignExt64;
964           printer = LLVM::lookupOrCreatePrintI64Fn(
965               printOp->getParentOfType<ModuleOp>());
966         } else {
967           return failure();
968         }
969       }
970     } else {
971       return failure();
972     }
973 
974     // Unroll vector into elementary print calls.
975     int64_t rank = vectorType ? vectorType.getRank() : 0;
976     Type type = vectorType ? vectorType : eltType;
977     emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
978               conversion);
979     emitCall(rewriter, printOp->getLoc(),
980              LLVM::lookupOrCreatePrintNewlineFn(
981                  printOp->getParentOfType<ModuleOp>()));
982     rewriter.eraseOp(printOp);
983     return success();
984   }
985 
986 private:
987   enum class PrintConversion {
988     // clang-format off
989     None,
990     ZeroExt64,
991     SignExt64
992     // clang-format on
993   };
994 
995   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
996                  Value value, Type type, Operation *printer, int64_t rank,
997                  PrintConversion conversion) const {
998     VectorType vectorType = type.dyn_cast<VectorType>();
999     Location loc = op->getLoc();
1000     if (!vectorType) {
1001       assert(rank == 0 && "The scalar case expects rank == 0");
1002       switch (conversion) {
1003       case PrintConversion::ZeroExt64:
1004         value = rewriter.create<arith::ExtUIOp>(
1005             loc, IntegerType::get(rewriter.getContext(), 64), value);
1006         break;
1007       case PrintConversion::SignExt64:
1008         value = rewriter.create<arith::ExtSIOp>(
1009             loc, IntegerType::get(rewriter.getContext(), 64), value);
1010         break;
1011       case PrintConversion::None:
1012         break;
1013       }
1014       emitCall(rewriter, loc, printer, value);
1015       return;
1016     }
1017 
1018     emitCall(rewriter, loc,
1019              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1020     Operation *printComma =
1021         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1022 
1023     if (rank <= 1) {
1024       auto reducedType = vectorType.getElementType();
1025       auto llvmType = typeConverter->convertType(reducedType);
1026       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1027       for (int64_t d = 0; d < dim; ++d) {
1028         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1029                                      llvmType, /*rank=*/0, /*pos=*/d);
1030         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1031                   conversion);
1032         if (d != dim - 1)
1033           emitCall(rewriter, loc, printComma);
1034       }
1035       emitCall(
1036           rewriter, loc,
1037           LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1038       return;
1039     }
1040 
1041     int64_t dim = vectorType.getDimSize(0);
1042     for (int64_t d = 0; d < dim; ++d) {
1043       auto reducedType = reducedVectorTypeFront(vectorType);
1044       auto llvmType = typeConverter->convertType(reducedType);
1045       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1046                                    llvmType, rank, d);
1047       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1048                 conversion);
1049       if (d != dim - 1)
1050         emitCall(rewriter, loc, printComma);
1051     }
1052     emitCall(rewriter, loc,
1053              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1054   }
1055 
1056   // Helper to emit a call.
1057   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1058                        Operation *ref, ValueRange params = ValueRange()) {
1059     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1060                                   params);
1061   }
1062 };
1063 
1064 /// The Splat operation is lowered to an insertelement + a shufflevector
1065 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1066 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1067   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1068 
1069   LogicalResult
1070   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1071                   ConversionPatternRewriter &rewriter) const override {
1072     VectorType resultType = splatOp.getType().cast<VectorType>();
1073     if (resultType.getRank() > 1)
1074       return failure();
1075 
1076     // First insert it into an undef vector so we can shuffle it.
1077     auto vectorType = typeConverter->convertType(splatOp.getType());
1078     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1079     auto zero = rewriter.create<LLVM::ConstantOp>(
1080         splatOp.getLoc(),
1081         typeConverter->convertType(rewriter.getIntegerType(32)),
1082         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1083 
1084     // For 0-d vector, we simply do `insertelement`.
1085     if (resultType.getRank() == 0) {
1086       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1087           splatOp, vectorType, undef, adaptor.input(), zero);
1088       return success();
1089     }
1090 
1091     // For 1-d vector, we additionally do a `vectorshuffle`.
1092     auto v = rewriter.create<LLVM::InsertElementOp>(
1093         splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
1094 
1095     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1096     SmallVector<int32_t, 4> zeroValues(width, 0);
1097 
1098     // Shuffle the value across the desired number of elements.
1099     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1100     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1101                                                        zeroAttrs);
1102     return success();
1103   }
1104 };
1105 
1106 /// The Splat operation is lowered to an insertelement + a shufflevector
1107 /// operation. Splat to only 2+-d vector result types are lowered by the
1108 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1109 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1110   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1111 
1112   LogicalResult
1113   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1114                   ConversionPatternRewriter &rewriter) const override {
1115     VectorType resultType = splatOp.getType();
1116     if (resultType.getRank() <= 1)
1117       return failure();
1118 
1119     // First insert it into an undef vector so we can shuffle it.
1120     auto loc = splatOp.getLoc();
1121     auto vectorTypeInfo =
1122         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1123     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1124     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1125     if (!llvmNDVectorTy || !llvm1DVectorTy)
1126       return failure();
1127 
1128     // Construct returned value.
1129     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1130 
1131     // Construct a 1-D vector with the splatted value that we insert in all the
1132     // places within the returned descriptor.
1133     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1134     auto zero = rewriter.create<LLVM::ConstantOp>(
1135         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1136         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1137     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1138                                                      adaptor.input(), zero);
1139 
1140     // Shuffle the value across the desired number of elements.
1141     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1142     SmallVector<int32_t, 4> zeroValues(width, 0);
1143     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1144     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
1145 
1146     // Iterate of linear index, convert to coords space and insert splatted 1-D
1147     // vector in each position.
1148     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1149       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
1150                                                   position);
1151     });
1152     rewriter.replaceOp(splatOp, desc);
1153     return success();
1154   }
1155 };
1156 
1157 } // namespace
1158 
1159 /// Populate the given list with patterns that convert from Vector to LLVM.
1160 void mlir::populateVectorToLLVMConversionPatterns(
1161     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1162     bool reassociateFPReductions) {
1163   MLIRContext *ctx = converter.getDialect()->getContext();
1164   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1165   populateVectorInsertExtractStridedSliceTransforms(patterns);
1166   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1167   patterns
1168       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1169            VectorExtractElementOpConversion, VectorExtractOpConversion,
1170            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1171            VectorInsertOpConversion, VectorPrintOpConversion,
1172            VectorTypeCastOpConversion, VectorScaleOpConversion,
1173            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1174            VectorLoadStoreConversion<vector::MaskedLoadOp,
1175                                      vector::MaskedLoadOpAdaptor>,
1176            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1177            VectorLoadStoreConversion<vector::MaskedStoreOp,
1178                                      vector::MaskedStoreOpAdaptor>,
1179            VectorGatherOpConversion, VectorScatterOpConversion,
1180            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1181            VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
1182   // Transfer ops with rank > 1 are handled by VectorToSCF.
1183   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1184 }
1185 
1186 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1187     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1188   patterns.add<VectorMatmulOpConversion>(converter);
1189   patterns.add<VectorFlatTransposeOpConversion>(converter);
1190 }
1191