//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
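// For example (illustrative), vector<2x4xf32> reduces to vector<4xf32>. If all
// dimensions are scalable, the scalable-dimension count is reduced by one as
// well, since the dropped leading dimension was itself scalable.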
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (tp.getShape().size() == numScalableDims)
    --numScalableDims;
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
                         numScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
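// For example (illustrative), vector<2x3x4xf32> reduces to vector<4xf32>.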
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (numScalableDims > 0)
    --numScalableDims;
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         numScalableDims);
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Add an index vector component to a base pointer. This succeeds unless the
// strides cannot be computed, the last stride is non-unit, or the memory
// space is not zero.
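// The result is a vector of pointers (one lane per index), produced by a
// single llvm.getelementptr over the strided base pointer; it feeds the
// masked gather/scatter intrinsics below.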
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpaceAsInt() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
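/// For example (illustrative; only 0-D/1-D shapes reach this pattern):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```
/// becomes a single `llvm.bitcast` on the converted vector type.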
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
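/// For example (illustrative syntax, a 4x16 by 16x3 multiply on flattened
/// 1-D vectors):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
///          : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```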
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
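/// For example (illustrative syntax, transposing a flattened 4x4 matrix):
/// ```
///   %1 = vector.flat_transpose %0 {rows = 4 : i32, columns = 4 : i32}
///          : vector<16xf32> -> vector<16xf32>
/// ```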
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
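/// The mapping is: vector.load -> llvm.load, vector.maskedload ->
/// llvm.intr.masked.load, vector.store -> llvm.store, and
/// vector.maskedstore -> llvm.intr.masked.store, each with the alignment
/// resolved from the data layout.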
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
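/// For example (illustrative), a 1-D load such as
/// ```
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an `llvm.load` from the strided element pointer bitcast to a
/// pointer-to-vector type.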
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
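/// The gather is lowered to `llvm.intr.masked.gather` over a vector of
/// pointers: the strided base pointer is offset by the index vector with a
/// single GEP (see getIndexedPtrs above), so roughly each lane reads
/// base[indices[i]] under the mask, taking pass-through values otherwise.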
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
                              adaptor.getIndexVec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `VectorOp` is the LLVM vector intrinsic to
/// use and `ScalarOp` is the scalar operation used to combine the accumulator
/// value, if non-null, with the reduction result.
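/// For example (illustrative), an i32 add reduction with an accumulator lowers
/// to `llvm.intr.vector.reduce.add` on the vector operand followed by an
/// `llvm.add` of the accumulator with that result.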
template <class VectorOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {
  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate used to compare and
/// combine with the accumulator value, if non-null.
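/// For example (illustrative), a signed min reduction with an accumulator
/// lowers to `llvm.intr.vector.reduce.smin`, then an `llvm.icmp` against the
/// accumulator and an `llvm.select` that picks the smaller of the two.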
template <class VectorOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
/// with vector types.
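/// The emitted sequence (illustrative) is: an ordered `olt`/`ogt` compare plus
/// select to pick the min/max, then an unordered `uno` compare plus select
/// that substitutes a quiet NaN whenever either operand is NaN.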
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
  Type i1Type = builder.getI1Type();
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

/// Conversion pattern for all vector reductions.
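/// For example (illustrative syntax):
/// ```
///   %0 = vector.reduction <add>, %v : vector<16xf32> into f32
/// ```
/// becomes `llvm.intr.vector.reduce.fadd` with a zero (or user-provided)
/// accumulator; integer kinds map to the corresponding integer reduction
/// intrinsics via the helpers above.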
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == vector::CombiningKind::ADD) {
      // Optional accumulator (or zero).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MUL) {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
      rewriter.replaceOp(reductionOp, result);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
      rewriter.replaceOp(reductionOp, result);
    } else
      return failure();

    return success();
  }

private:
  const bool reassociateFPReductions;
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the values one by one.
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>)
///    -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = vector.splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
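/// For example (illustrative), memref<4x8xf32> has strides [8, 1] and is
/// contiguous; a memref of the same shape whose rows are padded to a stride
/// of 16 is not.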
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getLayout().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is too underspecified to
  // represent contiguous dynamic shapes in any way other than with an
  // empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamicStrideOrOffset))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
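/// The lowering (illustrative) materializes a step vector to get per-lane
/// indices, splats the runtime bound, and emits an `arith.cmpi slt` of the
/// two, yielding the scalable i1 mask directly.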
class VectorCreateMaskOpRewritePattern
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
                                            bool enableIndexOpt)
      : OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
      return failure();
    IntegerType idxType =
        force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
    auto loc = op->getLoc();
    Value indices = rewriter.create<LLVM::StepVectorOp>(
        loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                 /*isScalable=*/true));
    auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
                                                 op.getOperand(0));
    Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
    Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                indices, bounds);
    rewriter.replaceOp(op, comp);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
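  // For example (illustrative), printing a vector<2x2xi32> expands to calls
  // to the printOpen/printI64/printComma/printClose/printNewline runtime
  // functions, producing output of the form "( ( 1, 2 ), ( 3, 4 ) )".
  //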
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats of 0-D and 1-D vector result types are lowered here.
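/// For example (illustrative):
/// ```
///   %v = vector.splat %s : vector<4xf32>
/// ```
/// becomes an `llvm.insertelement` of %s into lane 0 of an undef vector,
/// followed by an `llvm.shufflevector` with an all-zero mask that broadcasts
/// that lane to every position.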
struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
  using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType().cast<VectorType>();
    if (resultType.getRank() > 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter->convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        splatOp.getLoc(),
        typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    // For 0-d vector, we simply do `insertelement`.
    if (resultType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          splatOp, vectorType, undef, adaptor.getInput(), zero);
      return success();
    }

    // For 1-d vector, we additionally do a `shufflevector`.
    auto v = rewriter.create<LLVM::InsertElementOp>(
        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
                                                       zeroValues);
    return success();
  }
};

/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats of vector result types with rank 2 or higher are
/// lowered here; the 0-D and 1-D cases are handled by VectorSplatOpLowering.
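/// The n-D case (illustrative) builds the broadcast 1-D vector once, as in
/// the 1-D lowering above, and then inserts it into every position of the
/// LLVM array-of-vectors aggregate that represents the n-D result.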
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate over the linear index, convert to coordinate space, and insert
    // the splatted 1-D vector at each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
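///
/// Typical usage (an illustrative sketch, not the only way to drive it):
/// ```
///   LLVMTypeConverter converter(ctx);
///   RewritePatternSet patterns(ctx);
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(*ctx);
///   (void)applyPartialConversion(op, target, std::move(patterns));
/// ```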
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<VectorMatmulOpConversion>(converter);
  patterns.add<VectorFlatTransposeOpConversion>(converter);
}