//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (tp.getShape().size() == numScalableDims)
    --numScalableDims;
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
                         numScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (numScalableDims > 0)
    --numScalableDims;
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         numScalableDims);
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Add an index vector component to a base pointer. This succeeds unless the
// last stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer.  The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
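/// As an illustration (value names and shapes are arbitrary, not taken from
/// this file), a 1-D bitcast such as
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```
/// lowers to a single LLVM dialect bitcast:
/// ```
///   %1 = llvm.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```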
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
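/// As a rough sketch (operand names and sizes are illustrative only), a
/// flattened 2x4 * 4x3 multiply written roughly as
/// ```
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 2 : i32, lhs_columns = 4 : i32, rhs_columns = 3 : i32}
///        : (vector<8xf32>, vector<12xf32>) -> vector<6xf32>
/// ```
/// maps 1-1 onto llvm.intr.matrix.multiply with the same operands and
/// rows/columns attributes.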
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
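/// For illustration (names and sizes are arbitrary), a flat 4x4 transpose
/// written roughly as
/// ```
///   %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// maps 1-1 onto llvm.intr.matrix.transpose with the same rows/columns
/// attributes.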
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
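/// As a rough sketch (types, names, and the exact GEP form are illustrative
/// only), a 1-D load such as
/// ```
///   %v = vector.load %base[%i] : memref<200xf32>, vector<8xf32>
/// ```
/// becomes a strided-element-pointer computation, a bitcast of that pointer to
/// a vector pointer, and an aligned llvm.load, roughly:
/// ```
///   %p = llvm.getelementptr %b[%i] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
///   %q = llvm.bitcast %p : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
///   %v = llvm.load %q {alignment = 4 : i64} : !llvm.ptr<vector<8xf32>>
/// ```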
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
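/// For illustration (names and shapes are arbitrary), a gather written
/// roughly as
/// ```
///   %g = vector.gather %base[%c0][%idxs], %mask, %pass
///        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///          into vector<16xf32>
/// ```
/// is lowered to a GEP that materializes a vector of pointers followed by the
/// llvm.intr.masked.gather intrinsic.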
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
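/// For illustration (names and shapes are arbitrary), an expanding load
/// written roughly as
/// ```
///   %e = vector.expandload %base[%c0], %mask, %pass
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// is lowered to the llvm.intr.masked.expandload intrinsic on the strided
/// element pointer; vector.compressstore below is handled symmetrically with
/// llvm.intr.masked.compressstore.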
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
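/// As a rough sketch (the vector.reduction assembly syntax is abridged and may
/// differ slightly across MLIR versions), an integer add reduction such as
/// ```
///   %r = vector.reduction <add>, %v : vector<16xi32> into i32
/// ```
/// maps to llvm.intr.vector.reduce.add, while a floating-point add reduction
/// with an optional accumulator maps to llvm.intr.vector.reduce.fadd, whose
/// 'reassoc' flag is taken from the reassociateFPReductions option passed to
/// this pattern.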
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getOperands()[0];
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == vector::CombiningKind::ADD)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::MUL)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::MINUI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MINSI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MAXUI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::MAXSI)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operand);
      else if (kind == vector::CombiningKind::AND)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
                                                             llvmType, operand);
      else if (kind == vector::CombiningKind::OR)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
                                                            llvmType, operand);
      else if (kind == vector::CombiningKind::XOR)
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
                                                             llvmType, operand);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == vector::CombiningKind::ADD) {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MUL) {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MINF)
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
                                                            llvmType, operand);
    else if (kind == vector::CombiningKind::MAXF)
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
                                                            llvmType, operand);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

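/// Conversion pattern for a vector.shuffle. For illustration (names and mask
/// values are arbitrary), a rank-1 shuffle of identically typed operands such
/// as
/// ```
///   %s = vector.shuffle %a, %b [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
/// ```
/// maps directly onto llvm.shufflevector; all other cases are unrolled into
/// per-element extract/insert sequences.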
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values one by one.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

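/// Conversion pattern for a vector.extract. As a rough sketch (names, shapes,
/// and positions are arbitrary), extracting a scalar from a 2-D vector such as
/// ```
///   %e = vector.extract %v[3, 2] : vector<4x8xf32>
/// ```
/// becomes an llvm.extractvalue that peels the leading array dimensions,
/// followed by an llvm.extractelement on the remaining 1-D vector.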
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>)
///    -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

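/// Conversion pattern for a vector.insert. As a rough sketch (names, shapes,
/// and positions are arbitrary), inserting a scalar into a 2-D vector such as
/// ```
///   %r = vector.insert %s, %v[3, 2] : f32 into vector<4x8xf32>
/// ```
/// becomes an llvm.extractvalue of the targeted 1-D vector, an
/// llvm.insertelement into it, and an llvm.insertvalue that puts the updated
/// 1-D vector back into the enclosing array.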
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.getDest(), inserted,
          nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = vector.splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shape. This can only ever
  // work in static cases because MemRefType is underspecified and cannot
  // represent contiguous dynamic shapes in ways other than an empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

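/// Conversion pattern for a vector.type_cast. For illustration (names and
/// shapes are arbitrary), a cast such as
/// ```
///   %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// does not touch the underlying data; it only rebuilds the memref descriptor
/// with bitcast pointers, a zero offset, and the target's static sizes and
/// strides.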
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
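/// As a rough sketch (names are arbitrary and the step-vector intrinsic
/// spelling may differ across LLVM versions), a scalable mask such as
/// ```
///   %m = vector.create_mask %n : vector<[8]xi1>
/// ```
/// is rewritten into a step vector, a splat of the bound %n, and an
/// arith.cmpi slt comparing the two, yielding the i1 mask vector.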
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        indexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool indexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
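  // For illustration (values are arbitrary), printing a vector<2xf32> holding
  // 1.0 and 2.0 emits a call sequence along the lines of printOpen(),
  // printF32(1.0), printComma(), printF32(2.0), printClose(), printNewline(),
  // which produces output such as "( 1, 2 )".
  //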
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with 0-D and 1-D vector result types are lowered
/// here.
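/// For illustration (names are arbitrary), a 1-D splat such as
/// ```
///   %v = vector.splat %x : vector<4xf32>
/// ```
/// becomes an llvm.insertelement of %x at position 0 of an undef vector,
/// followed by an llvm.shufflevector with an all-zero mask that broadcasts
/// that lane across all four elements.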
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast<VectorType>();
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `shufflevector`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t, 4> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroAttrs);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with vector result types of rank 2 or higher are
/// lowered by VectorSplatNdOpLowering; the 0-D and 1-D cases are handled by
/// VectorSplatOpLowering above.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert it to an n-D coordinate, and
    // insert the splatted 1-D vector at each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns,
                                                  bool reassociateFPReductions,
                                                  bool indexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, indexOptimizations);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}