xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 8fe076ffe09028bef761b0d0ebdd5842c595ca87)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Support/MathExtras.h"
21 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 
27 // Helper to reduce vector type by one rank at front.
28 static VectorType reducedVectorTypeFront(VectorType tp) {
29   assert((tp.getRank() > 1) && "unlowerable vector type");
30   unsigned numScalableDims = tp.getNumScalableDims();
31   if (tp.getShape().size() == numScalableDims)
32     --numScalableDims;
33   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
34                          numScalableDims);
35 }
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   unsigned numScalableDims = tp.getNumScalableDims();
41   if (numScalableDims > 0)
42     --numScalableDims;
43   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
44                          numScalableDims);
45 }
46 
47 // Helper that picks the proper sequence for inserting.
48 static Value insertOne(ConversionPatternRewriter &rewriter,
49                        LLVMTypeConverter &typeConverter, Location loc,
50                        Value val1, Value val2, Type llvmType, int64_t rank,
51                        int64_t pos) {
52   assert(rank > 0 && "0-D vector corner case should have been handled already");
53   if (rank == 1) {
54     auto idxType = rewriter.getIndexType();
55     auto constant = rewriter.create<LLVM::ConstantOp>(
56         loc, typeConverter.convertType(idxType),
57         rewriter.getIntegerAttr(idxType, pos));
58     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
59                                                   constant);
60   }
61   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
62                                               rewriter.getI64ArrayAttr(pos));
63 }
64 
65 // Helper that picks the proper sequence for extracting.
66 static Value extractOne(ConversionPatternRewriter &rewriter,
67                         LLVMTypeConverter &typeConverter, Location loc,
68                         Value val, Type llvmType, int64_t rank, int64_t pos) {
69   if (rank <= 1) {
70     auto idxType = rewriter.getIndexType();
71     auto constant = rewriter.create<LLVM::ConstantOp>(
72         loc, typeConverter.convertType(idxType),
73         rewriter.getIntegerAttr(idxType, pos));
74     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
75                                                    constant);
76   }
77   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
78                                                rewriter.getI64ArrayAttr(pos));
79 }
80 
81 // Helper that returns data layout alignment of a memref.
82 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
83                                  MemRefType memrefType, unsigned &align) {
84   Type elementTy = typeConverter.convertType(memrefType.getElementType());
85   if (!elementTy)
86     return failure();
87 
88   // TODO: this should use the MLIR data layout when it becomes available and
89   // stop depending on translation.
90   llvm::LLVMContext llvmContext;
91   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
92               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
93   return success();
94 }
95 
96 // Add an index vector component to a base pointer. This almost always succeeds
97 // unless the last stride is non-unit or the memory space is not zero.
98 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
99                                     Location loc, Value memref, Value base,
100                                     Value index, MemRefType memRefType,
101                                     VectorType vType, Value &ptrs) {
102   int64_t offset;
103   SmallVector<int64_t, 4> strides;
104   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
105   if (failed(successStrides) || strides.back() != 1 ||
106       memRefType.getMemorySpaceAsInt() != 0)
107     return failure();
108   auto pType = MemRefDescriptor(memref).getElementPtrType();
109   auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
110   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
111   return success();
112 }
113 
114 // Casts a strided element pointer to a vector pointer.  The vector pointer
115 // will be in the same address space as the incoming memref type.
116 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
117                          Value ptr, MemRefType memRefType, Type vt) {
118   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
119   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
120 }
121 
122 namespace {
123 
124 /// Trivial Vector to LLVM conversions
125 using VectorScaleOpConversion =
126     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
127 
128 /// Conversion pattern for a vector.bitcast.
129 class VectorBitCastOpConversion
130     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
131 public:
132   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
133 
134   LogicalResult
135   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
136                   ConversionPatternRewriter &rewriter) const override {
137     // Only 0-D and 1-D vectors can be lowered to LLVM.
138     VectorType resultTy = bitCastOp.getResultVectorType();
139     if (resultTy.getRank() > 1)
140       return failure();
141     Type newResultTy = typeConverter->convertType(resultTy);
142     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
143                                                  adaptor.getOperands()[0]);
144     return success();
145   }
146 };
147 
148 /// Conversion pattern for a vector.matrix_multiply.
149 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
150 class VectorMatmulOpConversion
151     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
152 public:
153   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
154 
155   LogicalResult
156   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
157                   ConversionPatternRewriter &rewriter) const override {
158     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
159         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
160         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
161         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
162     return success();
163   }
164 };
165 
166 /// Conversion pattern for a vector.flat_transpose.
167 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
168 class VectorFlatTransposeOpConversion
169     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
170 public:
171   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
172 
173   LogicalResult
174   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override {
176     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
177         transOp, typeConverter->convertType(transOp.getRes().getType()),
178         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
179     return success();
180   }
181 };
182 
183 /// Overloaded utility that replaces a vector.load, vector.store,
184 /// vector.maskedload and vector.maskedstore with their respective LLVM
185 /// couterparts.
186 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
187                                  vector::LoadOpAdaptor adaptor,
188                                  VectorType vectorTy, Value ptr, unsigned align,
189                                  ConversionPatternRewriter &rewriter) {
190   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
191 }
192 
193 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
194                                  vector::MaskedLoadOpAdaptor adaptor,
195                                  VectorType vectorTy, Value ptr, unsigned align,
196                                  ConversionPatternRewriter &rewriter) {
197   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
198       loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
199 }
200 
201 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
202                                  vector::StoreOpAdaptor adaptor,
203                                  VectorType vectorTy, Value ptr, unsigned align,
204                                  ConversionPatternRewriter &rewriter) {
205   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
206                                              ptr, align);
207 }
208 
209 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
210                                  vector::MaskedStoreOpAdaptor adaptor,
211                                  VectorType vectorTy, Value ptr, unsigned align,
212                                  ConversionPatternRewriter &rewriter) {
213   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
214       storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
215 }
216 
217 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
218 /// vector.maskedstore.
219 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
220 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
221 public:
222   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
223 
224   LogicalResult
225   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
226                   typename LoadOrStoreOp::Adaptor adaptor,
227                   ConversionPatternRewriter &rewriter) const override {
228     // Only 1-D vectors can be lowered to LLVM.
229     VectorType vectorTy = loadOrStoreOp.getVectorType();
230     if (vectorTy.getRank() > 1)
231       return failure();
232 
233     auto loc = loadOrStoreOp->getLoc();
234     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
235 
236     // Resolve alignment.
237     unsigned align;
238     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
239       return failure();
240 
241     // Resolve address.
242     auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
243                      .template cast<VectorType>();
244     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
245                                                adaptor.getIndices(), rewriter);
246     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
247 
248     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
249     return success();
250   }
251 };
252 
253 /// Conversion pattern for a vector.gather.
254 class VectorGatherOpConversion
255     : public ConvertOpToLLVMPattern<vector::GatherOp> {
256 public:
257   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
258 
259   LogicalResult
260   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
261                   ConversionPatternRewriter &rewriter) const override {
262     auto loc = gather->getLoc();
263     MemRefType memRefType = gather.getMemRefType();
264 
265     // Resolve alignment.
266     unsigned align;
267     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
268       return failure();
269 
270     // Resolve address.
271     Value ptrs;
272     VectorType vType = gather.getVectorType();
273     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
274                                      adaptor.getIndices(), rewriter);
275     if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
276                               adaptor.getIndexVec(), memRefType, vType, ptrs)))
277       return failure();
278 
279     // Replace with the gather intrinsic.
280     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
281         gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
282         adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
283     return success();
284   }
285 };
286 
287 /// Conversion pattern for a vector.scatter.
288 class VectorScatterOpConversion
289     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
290 public:
291   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
292 
293   LogicalResult
294   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
295                   ConversionPatternRewriter &rewriter) const override {
296     auto loc = scatter->getLoc();
297     MemRefType memRefType = scatter.getMemRefType();
298 
299     // Resolve alignment.
300     unsigned align;
301     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
302       return failure();
303 
304     // Resolve address.
305     Value ptrs;
306     VectorType vType = scatter.getVectorType();
307     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
308                                      adaptor.getIndices(), rewriter);
309     if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
310                               adaptor.getIndexVec(), memRefType, vType, ptrs)))
311       return failure();
312 
313     // Replace with the scatter intrinsic.
314     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
315         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
316         rewriter.getI32IntegerAttr(align));
317     return success();
318   }
319 };
320 
321 /// Conversion pattern for a vector.expandload.
322 class VectorExpandLoadOpConversion
323     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
324 public:
325   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
326 
327   LogicalResult
328   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
329                   ConversionPatternRewriter &rewriter) const override {
330     auto loc = expand->getLoc();
331     MemRefType memRefType = expand.getMemRefType();
332 
333     // Resolve address.
334     auto vtype = typeConverter->convertType(expand.getVectorType());
335     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
336                                      adaptor.getIndices(), rewriter);
337 
338     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
339         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
340     return success();
341   }
342 };
343 
344 /// Conversion pattern for a vector.compressstore.
345 class VectorCompressStoreOpConversion
346     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
347 public:
348   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
349 
350   LogicalResult
351   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
352                   ConversionPatternRewriter &rewriter) const override {
353     auto loc = compress->getLoc();
354     MemRefType memRefType = compress.getMemRefType();
355 
356     // Resolve address.
357     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
358                                      adaptor.getIndices(), rewriter);
359 
360     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
361         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
362     return success();
363   }
364 };
365 
366 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
367 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
368 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
369 /// non-null.
370 template <class VectorOp, class ScalarOp>
371 static Value createIntegerReductionArithmeticOpLowering(
372     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
373     Value vectorOperand, Value accumulator) {
374   Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
375   if (accumulator)
376     result = rewriter.create<ScalarOp>(loc, accumulator, result);
377   return result;
378 }
379 
380 /// Helper method to lower a `vector.reduction` operation that performs
381 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
382 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
383 /// the accumulator value if non-null.
384 template <class VectorOp>
385 static Value createIntegerReductionComparisonOpLowering(
386     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
387     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
388   Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
389   if (accumulator) {
390     Value cmp =
391         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
392     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
393   }
394   return result;
395 }
396 
397 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
398 /// with vector types.
399 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
400                            Value rhs, bool isMin) {
401   auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
402   Type i1Type = builder.getI1Type();
403   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
404     i1Type = VectorType::get(vecType.getShape(), i1Type);
405   Value cmp = builder.create<LLVM::FCmpOp>(
406       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
407       lhs, rhs);
408   Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
409   Value isNan = builder.create<LLVM::FCmpOp>(
410       loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
411   Value nan = builder.create<LLVM::ConstantOp>(
412       loc, lhs.getType(),
413       builder.getFloatAttr(floatType,
414                            APFloat::getQNaN(floatType.getFloatSemantics())));
415   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
416 }
417 
418 /// Conversion pattern for all vector reductions.
419 class VectorReductionOpConversion
420     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
421 public:
422   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
423                                        bool reassociateFPRed)
424       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
425         reassociateFPReductions(reassociateFPRed) {}
426 
427   LogicalResult
428   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
429                   ConversionPatternRewriter &rewriter) const override {
430     auto kind = reductionOp.getKind();
431     Type eltType = reductionOp.getDest().getType();
432     Type llvmType = typeConverter->convertType(eltType);
433     Value operand = adaptor.getVector();
434     Value acc = adaptor.getAcc();
435     Location loc = reductionOp.getLoc();
436     if (eltType.isIntOrIndex()) {
437       // Integer reductions: add/mul/min/max/and/or/xor.
438       Value result;
439       switch (kind) {
440       case vector::CombiningKind::ADD:
441         result =
442             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
443                                                        LLVM::AddOp>(
444                 rewriter, loc, llvmType, operand, acc);
445         break;
446       case vector::CombiningKind::MUL:
447         result =
448             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
449                                                        LLVM::MulOp>(
450                 rewriter, loc, llvmType, operand, acc);
451         break;
452       case vector::CombiningKind::MINUI:
453         result = createIntegerReductionComparisonOpLowering<
454             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
455                                       LLVM::ICmpPredicate::ule);
456         break;
457       case vector::CombiningKind::MINSI:
458         result = createIntegerReductionComparisonOpLowering<
459             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
460                                       LLVM::ICmpPredicate::sle);
461         break;
462       case vector::CombiningKind::MAXUI:
463         result = createIntegerReductionComparisonOpLowering<
464             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
465                                       LLVM::ICmpPredicate::uge);
466         break;
467       case vector::CombiningKind::MAXSI:
468         result = createIntegerReductionComparisonOpLowering<
469             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
470                                       LLVM::ICmpPredicate::sge);
471         break;
472       case vector::CombiningKind::AND:
473         result =
474             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
475                                                        LLVM::AndOp>(
476                 rewriter, loc, llvmType, operand, acc);
477         break;
478       case vector::CombiningKind::OR:
479         result =
480             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
481                                                        LLVM::OrOp>(
482                 rewriter, loc, llvmType, operand, acc);
483         break;
484       case vector::CombiningKind::XOR:
485         result =
486             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
487                                                        LLVM::XOrOp>(
488                 rewriter, loc, llvmType, operand, acc);
489         break;
490       default:
491         return failure();
492       }
493       rewriter.replaceOp(reductionOp, result);
494 
495       return success();
496     }
497 
498     if (!eltType.isa<FloatType>())
499       return failure();
500 
501     // Floating-point reductions: add/mul/min/max
502     if (kind == vector::CombiningKind::ADD) {
503       // Optional accumulator (or zero).
504       Value acc = adaptor.getOperands().size() > 1
505                       ? adaptor.getOperands()[1]
506                       : rewriter.create<LLVM::ConstantOp>(
507                             reductionOp->getLoc(), llvmType,
508                             rewriter.getZeroAttr(eltType));
509       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
510           reductionOp, llvmType, acc, operand,
511           rewriter.getBoolAttr(reassociateFPReductions));
512     } else if (kind == vector::CombiningKind::MUL) {
513       // Optional accumulator (or one).
514       Value acc = adaptor.getOperands().size() > 1
515                       ? adaptor.getOperands()[1]
516                       : rewriter.create<LLVM::ConstantOp>(
517                             reductionOp->getLoc(), llvmType,
518                             rewriter.getFloatAttr(eltType, 1.0));
519       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
520           reductionOp, llvmType, acc, operand,
521           rewriter.getBoolAttr(reassociateFPReductions));
522     } else if (kind == vector::CombiningKind::MINF) {
523       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
524       // NaNs/-0.0/+0.0 in the same way.
525       Value result =
526           rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
527       if (acc)
528         result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
529       rewriter.replaceOp(reductionOp, result);
530     } else if (kind == vector::CombiningKind::MAXF) {
531       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
532       // NaNs/-0.0/+0.0 in the same way.
533       Value result =
534           rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
535       if (acc)
536         result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
537       rewriter.replaceOp(reductionOp, result);
538     } else
539       return failure();
540 
541     return success();
542   }
543 
544 private:
545   const bool reassociateFPReductions;
546 };
547 
548 class VectorShuffleOpConversion
549     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
550 public:
551   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
552 
553   LogicalResult
554   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
555                   ConversionPatternRewriter &rewriter) const override {
556     auto loc = shuffleOp->getLoc();
557     auto v1Type = shuffleOp.getV1VectorType();
558     auto v2Type = shuffleOp.getV2VectorType();
559     auto vectorType = shuffleOp.getVectorType();
560     Type llvmType = typeConverter->convertType(vectorType);
561     auto maskArrayAttr = shuffleOp.getMask();
562 
563     // Bail if result type cannot be lowered.
564     if (!llvmType)
565       return failure();
566 
567     // Get rank and dimension sizes.
568     int64_t rank = vectorType.getRank();
569     assert(v1Type.getRank() == rank);
570     assert(v2Type.getRank() == rank);
571     int64_t v1Dim = v1Type.getDimSize(0);
572 
573     // For rank 1, where both operands have *exactly* the same vector type,
574     // there is direct shuffle support in LLVM. Use it!
575     if (rank == 1 && v1Type == v2Type) {
576       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
577           loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
578       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
579       return success();
580     }
581 
582     // For all other cases, insert the individual values individually.
583     Type eltType;
584     if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
585       eltType = arrayType.getElementType();
586     else
587       eltType = llvmType.cast<VectorType>().getElementType();
588     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
589     int64_t insPos = 0;
590     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
591       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
592       Value value = adaptor.getV1();
593       if (extPos >= v1Dim) {
594         extPos -= v1Dim;
595         value = adaptor.getV2();
596       }
597       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
598                                  eltType, rank, extPos);
599       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
600                          llvmType, rank, insPos++);
601     }
602     rewriter.replaceOp(shuffleOp, insert);
603     return success();
604   }
605 };
606 
607 class VectorExtractElementOpConversion
608     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
609 public:
610   using ConvertOpToLLVMPattern<
611       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
612 
613   LogicalResult
614   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
615                   ConversionPatternRewriter &rewriter) const override {
616     auto vectorType = extractEltOp.getVectorType();
617     auto llvmType = typeConverter->convertType(vectorType.getElementType());
618 
619     // Bail if result type cannot be lowered.
620     if (!llvmType)
621       return failure();
622 
623     if (vectorType.getRank() == 0) {
624       Location loc = extractEltOp.getLoc();
625       auto idxType = rewriter.getIndexType();
626       auto zero = rewriter.create<LLVM::ConstantOp>(
627           loc, typeConverter->convertType(idxType),
628           rewriter.getIntegerAttr(idxType, 0));
629       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
630           extractEltOp, llvmType, adaptor.getVector(), zero);
631       return success();
632     }
633 
634     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
635         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
636     return success();
637   }
638 };
639 
640 class VectorExtractOpConversion
641     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
642 public:
643   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
644 
645   LogicalResult
646   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
647                   ConversionPatternRewriter &rewriter) const override {
648     auto loc = extractOp->getLoc();
649     auto vectorType = extractOp.getVectorType();
650     auto resultType = extractOp.getResult().getType();
651     auto llvmResultType = typeConverter->convertType(resultType);
652     auto positionArrayAttr = extractOp.getPosition();
653 
654     // Bail if result type cannot be lowered.
655     if (!llvmResultType)
656       return failure();
657 
658     // Extract entire vector. Should be handled by folder, but just to be safe.
659     if (positionArrayAttr.empty()) {
660       rewriter.replaceOp(extractOp, adaptor.getVector());
661       return success();
662     }
663 
664     // One-shot extraction of vector from array (only requires extractvalue).
665     if (resultType.isa<VectorType>()) {
666       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
667           loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
668       rewriter.replaceOp(extractOp, extracted);
669       return success();
670     }
671 
672     // Potential extraction of 1-D vector from array.
673     auto *context = extractOp->getContext();
674     Value extracted = adaptor.getVector();
675     auto positionAttrs = positionArrayAttr.getValue();
676     if (positionAttrs.size() > 1) {
677       auto oneDVectorType = reducedVectorTypeBack(vectorType);
678       auto nMinusOnePositionAttrs =
679           ArrayAttr::get(context, positionAttrs.drop_back());
680       extracted = rewriter.create<LLVM::ExtractValueOp>(
681           loc, typeConverter->convertType(oneDVectorType), extracted,
682           nMinusOnePositionAttrs);
683     }
684 
685     // Remaining extraction of element from 1-D LLVM vector
686     auto position = positionAttrs.back().cast<IntegerAttr>();
687     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
688     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
689     extracted =
690         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
691     rewriter.replaceOp(extractOp, extracted);
692 
693     return success();
694   }
695 };
696 
697 /// Conversion pattern that turns a vector.fma on a 1-D vector
698 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
699 /// This does not match vectors of n >= 2 rank.
700 ///
701 /// Example:
702 /// ```
703 ///  vector.fma %a, %a, %a : vector<8xf32>
704 /// ```
705 /// is converted to:
706 /// ```
707 ///  llvm.intr.fmuladd %va, %va, %va:
708 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
709 ///    -> !llvm."<8 x f32>">
710 /// ```
711 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
712 public:
713   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
714 
715   LogicalResult
716   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
717                   ConversionPatternRewriter &rewriter) const override {
718     VectorType vType = fmaOp.getVectorType();
719     if (vType.getRank() != 1)
720       return failure();
721     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
722         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
723     return success();
724   }
725 };
726 
727 class VectorInsertElementOpConversion
728     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
729 public:
730   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
731 
732   LogicalResult
733   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
734                   ConversionPatternRewriter &rewriter) const override {
735     auto vectorType = insertEltOp.getDestVectorType();
736     auto llvmType = typeConverter->convertType(vectorType);
737 
738     // Bail if result type cannot be lowered.
739     if (!llvmType)
740       return failure();
741 
742     if (vectorType.getRank() == 0) {
743       Location loc = insertEltOp.getLoc();
744       auto idxType = rewriter.getIndexType();
745       auto zero = rewriter.create<LLVM::ConstantOp>(
746           loc, typeConverter->convertType(idxType),
747           rewriter.getIntegerAttr(idxType, 0));
748       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
749           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
750       return success();
751     }
752 
753     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
754         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
755         adaptor.getPosition());
756     return success();
757   }
758 };
759 
760 class VectorInsertOpConversion
761     : public ConvertOpToLLVMPattern<vector::InsertOp> {
762 public:
763   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
764 
765   LogicalResult
766   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
767                   ConversionPatternRewriter &rewriter) const override {
768     auto loc = insertOp->getLoc();
769     auto sourceType = insertOp.getSourceType();
770     auto destVectorType = insertOp.getDestVectorType();
771     auto llvmResultType = typeConverter->convertType(destVectorType);
772     auto positionArrayAttr = insertOp.getPosition();
773 
774     // Bail if result type cannot be lowered.
775     if (!llvmResultType)
776       return failure();
777 
778     // Overwrite entire vector with value. Should be handled by folder, but
779     // just to be safe.
780     if (positionArrayAttr.empty()) {
781       rewriter.replaceOp(insertOp, adaptor.getSource());
782       return success();
783     }
784 
785     // One-shot insertion of a vector into an array (only requires insertvalue).
786     if (sourceType.isa<VectorType>()) {
787       Value inserted = rewriter.create<LLVM::InsertValueOp>(
788           loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
789           positionArrayAttr);
790       rewriter.replaceOp(insertOp, inserted);
791       return success();
792     }
793 
794     // Potential extraction of 1-D vector from array.
795     auto *context = insertOp->getContext();
796     Value extracted = adaptor.getDest();
797     auto positionAttrs = positionArrayAttr.getValue();
798     auto position = positionAttrs.back().cast<IntegerAttr>();
799     auto oneDVectorType = destVectorType;
800     if (positionAttrs.size() > 1) {
801       oneDVectorType = reducedVectorTypeBack(destVectorType);
802       auto nMinusOnePositionAttrs =
803           ArrayAttr::get(context, positionAttrs.drop_back());
804       extracted = rewriter.create<LLVM::ExtractValueOp>(
805           loc, typeConverter->convertType(oneDVectorType), extracted,
806           nMinusOnePositionAttrs);
807     }
808 
809     // Insertion of an element into a 1-D LLVM vector.
810     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
811     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
812     Value inserted = rewriter.create<LLVM::InsertElementOp>(
813         loc, typeConverter->convertType(oneDVectorType), extracted,
814         adaptor.getSource(), constant);
815 
816     // Potential insertion of resulting 1-D vector into array.
817     if (positionAttrs.size() > 1) {
818       auto nMinusOnePositionAttrs =
819           ArrayAttr::get(context, positionAttrs.drop_back());
820       inserted = rewriter.create<LLVM::InsertValueOp>(
821           loc, llvmResultType, adaptor.getDest(), inserted,
822           nMinusOnePositionAttrs);
823     }
824 
825     rewriter.replaceOp(insertOp, inserted);
826     return success();
827   }
828 };
829 
830 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
831 ///
832 /// Example:
833 /// ```
834 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
835 /// ```
836 /// is rewritten into:
837 /// ```
838 ///  %r = splat %f0: vector<2x4xf32>
839 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
840 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
841 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
842 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
843 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
844 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
845 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
846 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
847 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
848 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
849 ///  // %r3 holds the final value.
850 /// ```
851 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
852 public:
853   using OpRewritePattern<FMAOp>::OpRewritePattern;
854 
855   void initialize() {
856     // This pattern recursively unpacks one dimension at a time. The recursion
857     // bounded as the rank is strictly decreasing.
858     setHasBoundedRewriteRecursion();
859   }
860 
861   LogicalResult matchAndRewrite(FMAOp op,
862                                 PatternRewriter &rewriter) const override {
863     auto vType = op.getVectorType();
864     if (vType.getRank() < 2)
865       return failure();
866 
867     auto loc = op.getLoc();
868     auto elemType = vType.getElementType();
869     Value zero = rewriter.create<arith::ConstantOp>(
870         loc, elemType, rewriter.getZeroAttr(elemType));
871     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
872     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
873       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
874       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
875       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
876       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
877       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
878     }
879     rewriter.replaceOp(op, desc);
880     return success();
881   }
882 };
883 
884 /// Returns the strides if the memory underlying `memRefType` has a contiguous
885 /// static layout.
886 static llvm::Optional<SmallVector<int64_t, 4>>
887 computeContiguousStrides(MemRefType memRefType) {
888   int64_t offset;
889   SmallVector<int64_t, 4> strides;
890   if (failed(getStridesAndOffset(memRefType, strides, offset)))
891     return None;
892   if (!strides.empty() && strides.back() != 1)
893     return None;
894   // If no layout or identity layout, this is contiguous by definition.
895   if (memRefType.getLayout().isIdentity())
896     return strides;
897 
898   // Otherwise, we must determine contiguity form shapes. This can only ever
899   // work in static cases because MemRefType is underspecified to represent
900   // contiguous dynamic shapes in other ways than with just empty/identity
901   // layout.
902   auto sizes = memRefType.getShape();
903   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
904     if (ShapedType::isDynamic(sizes[index + 1]) ||
905         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
906         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
907       return None;
908     if (strides[index] != strides[index + 1] * sizes[index + 1])
909       return None;
910   }
911   return strides;
912 }
913 
914 class VectorTypeCastOpConversion
915     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
916 public:
917   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
918 
919   LogicalResult
920   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
921                   ConversionPatternRewriter &rewriter) const override {
922     auto loc = castOp->getLoc();
923     MemRefType sourceMemRefType =
924         castOp.getOperand().getType().cast<MemRefType>();
925     MemRefType targetMemRefType = castOp.getType();
926 
927     // Only static shape casts supported atm.
928     if (!sourceMemRefType.hasStaticShape() ||
929         !targetMemRefType.hasStaticShape())
930       return failure();
931 
932     auto llvmSourceDescriptorTy =
933         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
934     if (!llvmSourceDescriptorTy)
935       return failure();
936     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
937 
938     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
939                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
940     if (!llvmTargetDescriptorTy)
941       return failure();
942 
943     // Only contiguous source buffers supported atm.
944     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
945     if (!sourceStrides)
946       return failure();
947     auto targetStrides = computeContiguousStrides(targetMemRefType);
948     if (!targetStrides)
949       return failure();
950     // Only support static strides for now, regardless of contiguity.
951     if (llvm::any_of(*targetStrides, [](int64_t stride) {
952           return ShapedType::isDynamicStrideOrOffset(stride);
953         }))
954       return failure();
955 
956     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
957 
958     // Create descriptor.
959     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
960     Type llvmTargetElementTy = desc.getElementPtrType();
961     // Set allocated ptr.
962     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
963     allocated =
964         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
965     desc.setAllocatedPtr(rewriter, loc, allocated);
966     // Set aligned ptr.
967     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
968     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
969     desc.setAlignedPtr(rewriter, loc, ptr);
970     // Fill offset 0.
971     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
972     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
973     desc.setOffset(rewriter, loc, zero);
974 
975     // Fill size and stride descriptors in memref.
976     for (const auto &indexedSize :
977          llvm::enumerate(targetMemRefType.getShape())) {
978       int64_t index = indexedSize.index();
979       auto sizeAttr =
980           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
981       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
982       desc.setSize(rewriter, loc, index, size);
983       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
984                                                 (*targetStrides)[index]);
985       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
986       desc.setStride(rewriter, loc, index, stride);
987     }
988 
989     rewriter.replaceOp(castOp, {desc});
990     return success();
991   }
992 };
993 
994 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
995 /// Non-scalable versions of this operation are handled in Vector Transforms.
996 class VectorCreateMaskOpRewritePattern
997     : public OpRewritePattern<vector::CreateMaskOp> {
998 public:
999   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1000                                             bool enableIndexOpt)
1001       : OpRewritePattern<vector::CreateMaskOp>(context),
1002         force32BitVectorIndices(enableIndexOpt) {}
1003 
1004   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1005                                 PatternRewriter &rewriter) const override {
1006     auto dstType = op.getType();
1007     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
1008       return failure();
1009     IntegerType idxType =
1010         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1011     auto loc = op->getLoc();
1012     Value indices = rewriter.create<LLVM::StepVectorOp>(
1013         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1014                                  /*isScalable=*/true));
1015     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1016                                                  op.getOperand(0));
1017     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1018     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1019                                                 indices, bounds);
1020     rewriter.replaceOp(op, comp);
1021     return success();
1022   }
1023 
1024 private:
1025   const bool force32BitVectorIndices;
1026 };
1027 
1028 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1029 public:
1030   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1031 
1032   // Proof-of-concept lowering implementation that relies on a small
1033   // runtime support library, which only needs to provide a few
1034   // printing methods (single value for all data types, opening/closing
1035   // bracket, comma, newline). The lowering fully unrolls a vector
1036   // in terms of these elementary printing operations. The advantage
1037   // of this approach is that the library can remain unaware of all
1038   // low-level implementation details of vectors while still supporting
1039   // output of any shaped and dimensioned vector. Due to full unrolling,
1040   // this approach is less suited for very large vectors though.
1041   //
1042   // TODO: rely solely on libc in future? something else?
1043   //
1044   LogicalResult
1045   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1046                   ConversionPatternRewriter &rewriter) const override {
1047     Type printType = printOp.getPrintType();
1048 
1049     if (typeConverter->convertType(printType) == nullptr)
1050       return failure();
1051 
1052     // Make sure element type has runtime support.
1053     PrintConversion conversion = PrintConversion::None;
1054     VectorType vectorType = printType.dyn_cast<VectorType>();
1055     Type eltType = vectorType ? vectorType.getElementType() : printType;
1056     Operation *printer;
1057     if (eltType.isF32()) {
1058       printer =
1059           LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
1060     } else if (eltType.isF64()) {
1061       printer =
1062           LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
1063     } else if (eltType.isIndex()) {
1064       printer =
1065           LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
1066     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1067       // Integers need a zero or sign extension on the operand
1068       // (depending on the source type) as well as a signed or
1069       // unsigned print method. Up to 64-bit is supported.
1070       unsigned width = intTy.getWidth();
1071       if (intTy.isUnsigned()) {
1072         if (width <= 64) {
1073           if (width < 64)
1074             conversion = PrintConversion::ZeroExt64;
1075           printer = LLVM::lookupOrCreatePrintU64Fn(
1076               printOp->getParentOfType<ModuleOp>());
1077         } else {
1078           return failure();
1079         }
1080       } else {
1081         assert(intTy.isSignless() || intTy.isSigned());
1082         if (width <= 64) {
1083           // Note that we *always* zero extend booleans (1-bit integers),
1084           // so that true/false is printed as 1/0 rather than -1/0.
1085           if (width == 1)
1086             conversion = PrintConversion::ZeroExt64;
1087           else if (width < 64)
1088             conversion = PrintConversion::SignExt64;
1089           printer = LLVM::lookupOrCreatePrintI64Fn(
1090               printOp->getParentOfType<ModuleOp>());
1091         } else {
1092           return failure();
1093         }
1094       }
1095     } else {
1096       return failure();
1097     }
1098 
1099     // Unroll vector into elementary print calls.
1100     int64_t rank = vectorType ? vectorType.getRank() : 0;
1101     Type type = vectorType ? vectorType : eltType;
1102     emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
1103               conversion);
1104     emitCall(rewriter, printOp->getLoc(),
1105              LLVM::lookupOrCreatePrintNewlineFn(
1106                  printOp->getParentOfType<ModuleOp>()));
1107     rewriter.eraseOp(printOp);
1108     return success();
1109   }
1110 
1111 private:
1112   enum class PrintConversion {
1113     // clang-format off
1114     None,
1115     ZeroExt64,
1116     SignExt64
1117     // clang-format on
1118   };
1119 
1120   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1121                  Value value, Type type, Operation *printer, int64_t rank,
1122                  PrintConversion conversion) const {
1123     VectorType vectorType = type.dyn_cast<VectorType>();
1124     Location loc = op->getLoc();
1125     if (!vectorType) {
1126       assert(rank == 0 && "The scalar case expects rank == 0");
1127       switch (conversion) {
1128       case PrintConversion::ZeroExt64:
1129         value = rewriter.create<arith::ExtUIOp>(
1130             loc, IntegerType::get(rewriter.getContext(), 64), value);
1131         break;
1132       case PrintConversion::SignExt64:
1133         value = rewriter.create<arith::ExtSIOp>(
1134             loc, IntegerType::get(rewriter.getContext(), 64), value);
1135         break;
1136       case PrintConversion::None:
1137         break;
1138       }
1139       emitCall(rewriter, loc, printer, value);
1140       return;
1141     }
1142 
1143     emitCall(rewriter, loc,
1144              LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
1145     Operation *printComma =
1146         LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
1147 
1148     if (rank <= 1) {
1149       auto reducedType = vectorType.getElementType();
1150       auto llvmType = typeConverter->convertType(reducedType);
1151       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1152       for (int64_t d = 0; d < dim; ++d) {
1153         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1154                                      llvmType, /*rank=*/0, /*pos=*/d);
1155         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1156                   conversion);
1157         if (d != dim - 1)
1158           emitCall(rewriter, loc, printComma);
1159       }
1160       emitCall(
1161           rewriter, loc,
1162           LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1163       return;
1164     }
1165 
1166     int64_t dim = vectorType.getDimSize(0);
1167     for (int64_t d = 0; d < dim; ++d) {
1168       auto reducedType = reducedVectorTypeFront(vectorType);
1169       auto llvmType = typeConverter->convertType(reducedType);
1170       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1171                                    llvmType, rank, d);
1172       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1173                 conversion);
1174       if (d != dim - 1)
1175         emitCall(rewriter, loc, printComma);
1176     }
1177     emitCall(rewriter, loc,
1178              LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
1179   }
1180 
1181   // Helper to emit a call.
1182   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1183                        Operation *ref, ValueRange params = ValueRange()) {
1184     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1185                                   params);
1186   }
1187 };
1188 
1189 /// The Splat operation is lowered to an insertelement + a shufflevector
1190 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1191 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1192   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1193 
1194   LogicalResult
1195   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1196                   ConversionPatternRewriter &rewriter) const override {
1197     VectorType resultType = splatOp.getType().cast<VectorType>();
1198     if (resultType.getRank() > 1)
1199       return failure();
1200 
1201     // First insert it into an undef vector so we can shuffle it.
1202     auto vectorType = typeConverter->convertType(splatOp.getType());
1203     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1204     auto zero = rewriter.create<LLVM::ConstantOp>(
1205         splatOp.getLoc(),
1206         typeConverter->convertType(rewriter.getIntegerType(32)),
1207         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1208 
1209     // For 0-d vector, we simply do `insertelement`.
1210     if (resultType.getRank() == 0) {
1211       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1212           splatOp, vectorType, undef, adaptor.getInput(), zero);
1213       return success();
1214     }
1215 
1216     // For 1-d vector, we additionally do a `vectorshuffle`.
1217     auto v = rewriter.create<LLVM::InsertElementOp>(
1218         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1219 
1220     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1221     SmallVector<int32_t, 4> zeroValues(width, 0);
1222 
1223     // Shuffle the value across the desired number of elements.
1224     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1225     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1226                                                        zeroAttrs);
1227     return success();
1228   }
1229 };
1230 
1231 /// The Splat operation is lowered to an insertelement + a shufflevector
1232 /// operation. Splat to only 2+-d vector result types are lowered by the
1233 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1234 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1235   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1236 
1237   LogicalResult
1238   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1239                   ConversionPatternRewriter &rewriter) const override {
1240     VectorType resultType = splatOp.getType();
1241     if (resultType.getRank() <= 1)
1242       return failure();
1243 
1244     // First insert it into an undef vector so we can shuffle it.
1245     auto loc = splatOp.getLoc();
1246     auto vectorTypeInfo =
1247         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1248     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1249     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1250     if (!llvmNDVectorTy || !llvm1DVectorTy)
1251       return failure();
1252 
1253     // Construct returned value.
1254     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1255 
1256     // Construct a 1-D vector with the splatted value that we insert in all the
1257     // places within the returned descriptor.
1258     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1259     auto zero = rewriter.create<LLVM::ConstantOp>(
1260         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1261         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1262     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1263                                                      adaptor.getInput(), zero);
1264 
1265     // Shuffle the value across the desired number of elements.
1266     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1267     SmallVector<int32_t, 4> zeroValues(width, 0);
1268     ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1269     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
1270 
1271     // Iterate of linear index, convert to coords space and insert splatted 1-D
1272     // vector in each position.
1273     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1274       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
1275                                                   position);
1276     });
1277     rewriter.replaceOp(splatOp, desc);
1278     return success();
1279   }
1280 };
1281 
1282 } // namespace
1283 
1284 /// Populate the given list with patterns that convert from Vector to LLVM.
1285 void mlir::populateVectorToLLVMConversionPatterns(
1286     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1287     bool reassociateFPReductions, bool force32BitVectorIndices) {
1288   MLIRContext *ctx = converter.getDialect()->getContext();
1289   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1290   populateVectorInsertExtractStridedSliceTransforms(patterns);
1291   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1292   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1293   patterns
1294       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1295            VectorExtractElementOpConversion, VectorExtractOpConversion,
1296            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1297            VectorInsertOpConversion, VectorPrintOpConversion,
1298            VectorTypeCastOpConversion, VectorScaleOpConversion,
1299            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1300            VectorLoadStoreConversion<vector::MaskedLoadOp,
1301                                      vector::MaskedLoadOpAdaptor>,
1302            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1303            VectorLoadStoreConversion<vector::MaskedStoreOp,
1304                                      vector::MaskedStoreOpAdaptor>,
1305            VectorGatherOpConversion, VectorScatterOpConversion,
1306            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1307            VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
1308   // Transfer ops with rank > 1 are handled by VectorToSCF.
1309   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1310 }
1311 
1312 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1313     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1314   patterns.add<VectorMatmulOpConversion>(converter);
1315   patterns.add<VectorFlatTransposeOpConversion>(converter);
1316 }
1317