xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision a1aad28d297abb14e812cacf97947a0f857a2f54)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
19 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include <optional>
25 
26 using namespace mlir;
27 using namespace mlir::vector;
28 
29 // Helper to reduce vector type by one rank at front.
30 static VectorType reducedVectorTypeFront(VectorType tp) {
31   assert((tp.getRank() > 1) && "unlowerable vector type");
32   unsigned numScalableDims = tp.getNumScalableDims();
33   if (tp.getShape().size() == numScalableDims)
34     --numScalableDims;
35   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
36                          numScalableDims);
37 }
38 
39 // Helper to reduce vector type by *all* but one rank at back.
40 static VectorType reducedVectorTypeBack(VectorType tp) {
41   assert((tp.getRank() > 1) && "unlowerable vector type");
42   unsigned numScalableDims = tp.getNumScalableDims();
43   if (numScalableDims > 0)
44     --numScalableDims;
45   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
46                          numScalableDims);
47 }
48 
49 // Helper that picks the proper sequence for inserting.
50 static Value insertOne(ConversionPatternRewriter &rewriter,
51                        LLVMTypeConverter &typeConverter, Location loc,
52                        Value val1, Value val2, Type llvmType, int64_t rank,
53                        int64_t pos) {
54   assert(rank > 0 && "0-D vector corner case should have been handled already");
55   if (rank == 1) {
56     auto idxType = rewriter.getIndexType();
57     auto constant = rewriter.create<LLVM::ConstantOp>(
58         loc, typeConverter.convertType(idxType),
59         rewriter.getIntegerAttr(idxType, pos));
60     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
61                                                   constant);
62   }
63   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
64 }
65 
66 // Helper that picks the proper sequence for extracting.
67 static Value extractOne(ConversionPatternRewriter &rewriter,
68                         LLVMTypeConverter &typeConverter, Location loc,
69                         Value val, Type llvmType, int64_t rank, int64_t pos) {
70   if (rank <= 1) {
71     auto idxType = rewriter.getIndexType();
72     auto constant = rewriter.create<LLVM::ConstantOp>(
73         loc, typeConverter.convertType(idxType),
74         rewriter.getIntegerAttr(idxType, pos));
75     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
76                                                    constant);
77   }
78   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
79 }
80 
81 // Helper that returns data layout alignment of a memref.
82 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
83                                  MemRefType memrefType, unsigned &align) {
84   Type elementTy = typeConverter.convertType(memrefType.getElementType());
85   if (!elementTy)
86     return failure();
87 
88   // TODO: this should use the MLIR data layout when it becomes available and
89   // stop depending on translation.
90   llvm::LLVMContext llvmContext;
91   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
92               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
93   return success();
94 }
95 
96 // Check if the last stride is non-unit or the memory space is not zero.
97 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
98                                            LLVMTypeConverter &converter) {
99   int64_t offset;
100   SmallVector<int64_t, 4> strides;
101   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
102   FailureOr<unsigned> addressSpace =
103       converter.getMemRefAddressSpace(memRefType);
104   if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
105       *addressSpace != 0)
106     return failure();
107   return success();
108 }
109 
// Add an index vector component to a base pointer.
// Returns a vector of `vLen` pointers computed with a vector GEP from the
// scalar `base` pointer and the per-lane offsets in `index`, scaled by the
// converted element type. `llvmMemref` is the converted memref descriptor the
// base pointer was extracted from (used only to recover the element pointer
// type).
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}
123 
124 // Casts a strided element pointer to a vector pointer.  The vector pointer
125 // will be in the same address space as the incoming memref type.
126 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
127                          Value ptr, MemRefType memRefType, Type vt,
128                          LLVMTypeConverter &converter) {
129   if (converter.useOpaquePointers())
130     return ptr;
131 
132   unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
133   auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
134   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
135 }
136 
137 namespace {
138 
/// Trivial Vector to LLVM conversions: vector.vscale lowers one-to-one to the
/// llvm.vscale intrinsic.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
142 
143 /// Conversion pattern for a vector.bitcast.
144 class VectorBitCastOpConversion
145     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
146 public:
147   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
148 
149   LogicalResult
150   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
151                   ConversionPatternRewriter &rewriter) const override {
152     // Only 0-D and 1-D vectors can be lowered to LLVM.
153     VectorType resultTy = bitCastOp.getResultVectorType();
154     if (resultTy.getRank() > 1)
155       return failure();
156     Type newResultTy = typeConverter->convertType(resultTy);
157     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
158                                                  adaptor.getOperands()[0]);
159     return success();
160   }
161 };
162 
163 /// Conversion pattern for a vector.matrix_multiply.
164 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
165 class VectorMatmulOpConversion
166     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
167 public:
168   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
169 
170   LogicalResult
171   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
172                   ConversionPatternRewriter &rewriter) const override {
173     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
174         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
175         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
176         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
177     return success();
178   }
179 };
180 
181 /// Conversion pattern for a vector.flat_transpose.
182 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
183 class VectorFlatTransposeOpConversion
184     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
185 public:
186   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
187 
188   LogicalResult
189   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
190                   ConversionPatternRewriter &rewriter) const override {
191     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
192         transOp, typeConverter->convertType(transOp.getRes().getType()),
193         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
194     return success();
195   }
196 };
197 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // vector.load -> llvm.load of the vector type with explicit alignment.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}
207 
/// vector.maskedload -> llvm.intr.masked.load, carrying mask, pass-through
/// value, and alignment.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
215 
/// vector.store -> llvm.store of the value with explicit alignment.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}
223 
/// vector.maskedstore -> llvm.intr.masked.store, carrying mask and alignment.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
231 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
/// Resolves the strided element address and dispatches to the
/// replaceLoadOrStoreOp overload set above.
// NOTE(review): the second template parameter is unused; the adaptor type is
// taken from `LoadOrStoreOp::Adaptor`. It is kept for compatibility with
// existing instantiations.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // With typed pointers, the element pointer must be recast to a vector
    // pointer before emitting the LLVM memory op (no-op for opaque pointers).
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};
268 
/// Conversion pattern for a vector.gather.
/// 1-D gathers lower to a single llvm.intr.masked.gather; n-D gathers are
/// unrolled into 1-D slices, one gather per slice.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    // Only unit-innermost-stride, default-address-space memrefs are supported.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Scalar base address the per-lane indices are applied to.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: emit one gather per unrolled 1-D slice of the index, mask,
    // and pass-through vectors.
    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
330 
/// Conversion pattern for a vector.scatter.
/// Lowers to llvm.intr.masked.scatter on a vector of per-lane pointers.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Only unit-innermost-stride, default-address-space memrefs are supported.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: scalar base plus a vector of per-lane pointers.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
366 
367 /// Conversion pattern for a vector.expandload.
368 class VectorExpandLoadOpConversion
369     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
370 public:
371   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
372 
373   LogicalResult
374   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
375                   ConversionPatternRewriter &rewriter) const override {
376     auto loc = expand->getLoc();
377     MemRefType memRefType = expand.getMemRefType();
378 
379     // Resolve address.
380     auto vtype = typeConverter->convertType(expand.getVectorType());
381     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
382                                      adaptor.getIndices(), rewriter);
383 
384     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
385         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
386     return success();
387   }
388 };
389 
390 /// Conversion pattern for a vector.compressstore.
391 class VectorCompressStoreOpConversion
392     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
393 public:
394   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
395 
396   LogicalResult
397   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
398                   ConversionPatternRewriter &rewriter) const override {
399     auto loc = compress->getLoc();
400     MemRefType memRefType = compress.getMemRefType();
401 
402     // Resolve address.
403     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
404                                      adaptor.getIndices(), rewriter);
405 
406     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
407         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
408     return success();
409   }
410 };
411 
/// Reduction neutral classes for overloading. Each empty tag type selects,
/// via overload resolution on createReductionNeutralValue below, which
/// neutral/start element is materialized for a given reduction kind.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
423 
424 /// Create the reduction neutral zero value.
425 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
426                                          ConversionPatternRewriter &rewriter,
427                                          Location loc, Type llvmType) {
428   return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
429                                            rewriter.getZeroAttr(llvmType));
430 }
431 
432 /// Create the reduction neutral integer one value.
433 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
434                                          ConversionPatternRewriter &rewriter,
435                                          Location loc, Type llvmType) {
436   return rewriter.create<LLVM::ConstantOp>(
437       loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
438 }
439 
440 /// Create the reduction neutral fp one value.
441 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
442                                          ConversionPatternRewriter &rewriter,
443                                          Location loc, Type llvmType) {
444   return rewriter.create<LLVM::ConstantOp>(
445       loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
446 }
447 
448 /// Create the reduction neutral all-ones value.
449 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
450                                          ConversionPatternRewriter &rewriter,
451                                          Location loc, Type llvmType) {
452   return rewriter.create<LLVM::ConstantOp>(
453       loc, llvmType,
454       rewriter.getIntegerAttr(
455           llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
456 }
457 
458 /// Create the reduction neutral signed int minimum value.
459 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
460                                          ConversionPatternRewriter &rewriter,
461                                          Location loc, Type llvmType) {
462   return rewriter.create<LLVM::ConstantOp>(
463       loc, llvmType,
464       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
465                                             llvmType.getIntOrFloatBitWidth())));
466 }
467 
/// Create the reduction neutral unsigned int minimum value (zero).
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}
477 
/// Create the reduction neutral signed int maximum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}
487 
/// Create the reduction neutral unsigned int maximum value (all ones).
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}
497 
/// Create the reduction neutral fp minimum value.
// Materializes a positive quiet NaN. NOTE(review): this assumes the consuming
// reduction intrinsic treats NaN as the neutral start value — confirm against
// the target intrinsic's NaN semantics.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}
509 
/// Create the reduction neutral fp maximum value.
// Materializes a negative quiet NaN; see the note on the FPMin overload above
// regarding NaN-as-neutral semantics.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}
521 
522 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
523 /// returns a new accumulator value using `ReductionNeutral`.
524 template <class ReductionNeutral>
525 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
526                                     Location loc, Type llvmType,
527                                     Value accumulator) {
528   if (accumulator)
529     return accumulator;
530 
531   return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
532                                      llvmType);
533 }
534 
535 /// Creates a constant value with the 1-D vector shape provided in `llvmType`.
536 /// This is used as effective vector length by some intrinsics supporting
537 /// dynamic vector lengths at runtime.
538 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
539                                      Location loc, Type llvmType) {
540   VectorType vType = cast<VectorType>(llvmType);
541   auto vShape = vType.getShape();
542   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
543 
544   return rewriter.create<LLVM::ConstantOp>(
545       loc, rewriter.getI32Type(),
546       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
547 }
548 
549 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
550 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
551 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
552 /// non-null.
553 template <class LLVMRedIntrinOp, class ScalarOp>
554 static Value createIntegerReductionArithmeticOpLowering(
555     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
556     Value vectorOperand, Value accumulator) {
557 
558   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
559 
560   if (accumulator)
561     result = rewriter.create<ScalarOp>(loc, accumulator, result);
562   return result;
563 }
564 
565 /// Helper method to lower a `vector.reduction` operation that performs
566 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
567 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
568 /// the accumulator value if non-null.
569 template <class LLVMRedIntrinOp>
570 static Value createIntegerReductionComparisonOpLowering(
571     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
572     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
573   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
574   if (accumulator) {
575     Value cmp =
576         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
577     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
578   }
579   return result;
580 }
581 
582 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
583 /// with vector types.
584 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
585                            Value rhs, bool isMin) {
586   auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
587   Type i1Type = builder.getI1Type();
588   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
589     i1Type = VectorType::get(vecType.getShape(), i1Type);
590   Value cmp = builder.create<LLVM::FCmpOp>(
591       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
592       lhs, rhs);
593   Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
594   Value isNan = builder.create<LLVM::FCmpOp>(
595       loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
596   Value nan = builder.create<LLVM::ConstantOp>(
597       loc, lhs.getType(),
598       builder.getFloatAttr(floatType,
599                            APFloat::getQNaN(floatType.getFloatSemantics())));
600   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
601 }
602 
603 template <class LLVMRedIntrinOp>
604 static Value createFPReductionComparisonOpLowering(
605     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
606     Value vectorOperand, Value accumulator, bool isMin) {
607   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
608 
609   if (accumulator)
610     result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
611 
612   return result;
613 }
614 
615 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires
616 /// a start value. This start value format spans across fp reductions without
617 /// mask and all the masked reduction intrinsics.
618 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
619 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
620                                           Location loc, Type llvmType,
621                                           Value vectorOperand,
622                                           Value accumulator) {
623   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
624                                                          llvmType, accumulator);
625   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
626                                             /*startValue=*/accumulator,
627                                             vectorOperand);
628 }
629 
630 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
631 static Value
632 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
633                              Type llvmType, Value vectorOperand,
634                              Value accumulator, bool reassociateFPReds) {
635   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
636                                                          llvmType, accumulator);
637   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
638                                             /*startValue=*/accumulator,
639                                             vectorOperand, reassociateFPReds);
640 }
641 
642 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
643 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
644                                           Location loc, Type llvmType,
645                                           Value vectorOperand,
646                                           Value accumulator, Value mask) {
647   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
648                                                          llvmType, accumulator);
649   Value vectorLength =
650       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
651   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
652                                             /*startValue=*/accumulator,
653                                             vectorOperand, mask, vectorLength);
654 }
655 
656 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
657 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
658                                           Location loc, Type llvmType,
659                                           Value vectorOperand,
660                                           Value accumulator, Value mask,
661                                           bool reassociateFPReds) {
662   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
663                                                          llvmType, accumulator);
664   Value vectorLength =
665       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
666   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
667                                             /*startValue=*/accumulator,
668                                             vectorOperand, mask, vectorLength,
669                                             reassociateFPReds);
670 }
671 
672 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
673           class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
674 static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
675                                           Location loc, Type llvmType,
676                                           Value vectorOperand,
677                                           Value accumulator, Value mask) {
678   if (llvmType.isIntOrIndex())
679     return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
680                                         IntReductionNeutral>(
681         rewriter, loc, llvmType, vectorOperand, accumulator, mask);
682 
683   // FP dispatch.
684   return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
685       rewriter, loc, llvmType, vectorOperand, accumulator, mask);
686 }
687 
/// Conversion pattern for all vector reductions.
/// Lowers unmasked `vector.reduction` to the matching
/// `llvm.intr.vector.reduce.*` intrinsic; masked reductions bail out here
/// and are handled by MaskedReductionOpConversion instead.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    // The scalar result type determines whether this is an integer or a
    // floating-point reduction.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();

    // Masked reductions are lowered separately.
    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
    if (maskableOp.isMasked())
      return failure();

    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (e.g. FP-only ones) are not valid on integers.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    Value result;
    if (kind == vector::CombiningKind::ADD) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MUL) {
      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
                                            ReductionNeutralFPOne>(
          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/true);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
          rewriter, loc, llvmType, operand, acc,
          /*isMin=*/false);
    } else
      return failure();

    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  // Whether FP add/mul reductions may be reassociated (fast-math), as
  // configured when the pattern is registered.
  const bool reassociateFPReductions;
};
809 
810 /// Base class to convert a `vector.mask` operation while matching traits
811 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
812 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
813 /// method performs a second match against the maskable operation `MaskedOp`.
814 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
815 /// implemented by the concrete conversion classes. This method can match
816 /// against specific traits of the `vector.mask` and the maskable operation. It
817 /// must replace the `vector.mask` operation.
818 template <class MaskedOp>
819 class VectorMaskOpConversionBase
820     : public ConvertOpToLLVMPattern<vector::MaskOp> {
821 public:
822   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
823 
824   LogicalResult
825   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
826                   ConversionPatternRewriter &rewriter) const override final {
827     // Match against the maskable operation kind.
828     Operation *maskableOp = maskOp.getMaskableOp();
829     if (!isa<MaskedOp>(maskableOp))
830       return failure();
831     return matchAndRewriteMaskableOp(
832         maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
833   }
834 
835 protected:
836   virtual LogicalResult
837   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
838                             vector::MaskableOpInterface maskableOp,
839                             ConversionPatternRewriter &rewriter) const = 0;
840 };
841 
/// Converts a masked `vector.reduction` (a `vector.mask` wrapping a
/// `vector.reduction`) to the matching LLVM VP (vector-predicated)
/// reduction intrinsic, seeding the start value from the accumulator or
/// the combining kind's neutral element.
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  virtual LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    // Scalar result type of the reduction.
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Every CombiningKind is covered below; the 4-template-parameter
    // helper dispatches on int vs. FP element type where both exist.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
930 
/// Converts `vector.shuffle` to LLVM. Rank-0/1 shuffles with identical
/// operand types map directly onto `llvm.shufflevector`; all other cases
/// are expanded into per-element extract/insert pairs driven by the mask.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    // Well-formedness: either two 0-D operands yielding a 1-D result, or
    // all three vectors share the same rank.
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      // Mask entries >= v1Dim select from the second operand.
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
995 
996 class VectorExtractElementOpConversion
997     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
998 public:
999   using ConvertOpToLLVMPattern<
1000       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1001 
1002   LogicalResult
1003   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1004                   ConversionPatternRewriter &rewriter) const override {
1005     auto vectorType = extractEltOp.getSourceVectorType();
1006     auto llvmType = typeConverter->convertType(vectorType.getElementType());
1007 
1008     // Bail if result type cannot be lowered.
1009     if (!llvmType)
1010       return failure();
1011 
1012     if (vectorType.getRank() == 0) {
1013       Location loc = extractEltOp.getLoc();
1014       auto idxType = rewriter.getIndexType();
1015       auto zero = rewriter.create<LLVM::ConstantOp>(
1016           loc, typeConverter->convertType(idxType),
1017           rewriter.getIntegerAttr(idxType, 0));
1018       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1019           extractEltOp, llvmType, adaptor.getVector(), zero);
1020       return success();
1021     }
1022 
1023     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1024         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1025     return success();
1026   }
1027 };
1028 
/// Converts `vector.extract` to LLVM. A vector result is handled with a
/// single `llvm.extractvalue`; a scalar result requires descending into
/// the innermost 1-D vector first and then extracting the element.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position index peel the enclosing array nesting.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
1085 
1086 /// Conversion pattern that turns a vector.fma on a 1-D vector
1087 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1088 /// This does not match vectors of n >= 2 rank.
1089 ///
1090 /// Example:
1091 /// ```
1092 ///  vector.fma %a, %a, %a : vector<8xf32>
1093 /// ```
1094 /// is converted to:
1095 /// ```
1096 ///  llvm.intr.fmuladd %va, %va, %va:
1097 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1098 ///    -> !llvm."<8 x f32>">
1099 /// ```
1100 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1101 public:
1102   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1103 
1104   LogicalResult
1105   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1106                   ConversionPatternRewriter &rewriter) const override {
1107     VectorType vType = fmaOp.getVectorType();
1108     if (vType.getRank() > 1)
1109       return failure();
1110 
1111     // Masked fmas are lowered separately.
1112     auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
1113     if (maskableOp.isMasked())
1114       return failure();
1115 
1116     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1117         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1118     return success();
1119   }
1120 };
1121 
1122 /// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
1123 /// LLVM counterpart representation. Non side effecting VP intrinsics are not
1124 /// fully supported by some backends, including x86, and they don't support
1125 /// pass-through values either. For these reasons, we generate an unmasked
1126 /// fma followed by a select instrution to emulate the masking behavior.
1127 /// This pattern is peepholed by some backends with support for masked fma
1128 /// instructions. This pattern does not match vectors of n >= 2 rank.
1129 class MaskedFMAOp1DConversion
1130     : public VectorMaskOpConversionBase<vector::FMAOp> {
1131 public:
1132   using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
1133 
1134   MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
1135       : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
1136 
1137   virtual LogicalResult matchAndRewriteMaskableOp(
1138       vector::MaskOp maskOp, MaskableOpInterface maskableOp,
1139       ConversionPatternRewriter &rewriter) const override {
1140     auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
1141     Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
1142 
1143     Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
1144         fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
1145         fmaOp.getAcc());
1146     rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
1147         maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
1148     return success();
1149   }
1150 };
1151 
1152 class VectorInsertElementOpConversion
1153     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1154 public:
1155   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1156 
1157   LogicalResult
1158   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1159                   ConversionPatternRewriter &rewriter) const override {
1160     auto vectorType = insertEltOp.getDestVectorType();
1161     auto llvmType = typeConverter->convertType(vectorType);
1162 
1163     // Bail if result type cannot be lowered.
1164     if (!llvmType)
1165       return failure();
1166 
1167     if (vectorType.getRank() == 0) {
1168       Location loc = insertEltOp.getLoc();
1169       auto idxType = rewriter.getIndexType();
1170       auto zero = rewriter.create<LLVM::ConstantOp>(
1171           loc, typeConverter->convertType(idxType),
1172           rewriter.getIntegerAttr(idxType, 0));
1173       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1174           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1175       return success();
1176     }
1177 
1178     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1179         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1180         adaptor.getPosition());
1181     return success();
1182   }
1183 };
1184 
/// Converts `vector.insert` to LLVM. Inserting a vector into an array needs
/// only an `llvm.insertvalue`; inserting a scalar requires extracting the
/// innermost 1-D vector, inserting the element, and writing the vector back.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    // All but the last position index locate the innermost 1-D vector.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
1249 
1250 /// Lower vector.scalable.insert ops to LLVM vector.insert
1251 struct VectorScalableInsertOpLowering
1252     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1253   using ConvertOpToLLVMPattern<
1254       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1255 
1256   LogicalResult
1257   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1258                   ConversionPatternRewriter &rewriter) const override {
1259     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1260         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
1261     return success();
1262   }
1263 };
1264 
1265 /// Lower vector.scalable.extract ops to LLVM vector.extract
1266 struct VectorScalableExtractOpLowering
1267     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1268   using ConvertOpToLLVMPattern<
1269       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1270 
1271   LogicalResult
1272   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1273                   ConversionPatternRewriter &rewriter) const override {
1274     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1275         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1276         adaptor.getSource(), adaptor.getPos());
1277     return success();
1278   }
1279 };
1280 
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    // Rank <= 1 fmas are handled by the 1-D conversion pattern.
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    // Masked fmas are lowered separately.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    // Seed the result with a zero splat; each slice is overwritten below.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
    // Unroll the leading dimension: extract rank-(n-1) slices of each
    // operand, fma them, and insert the result back.
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};
1339 
1340 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1341 /// static layout.
1342 static std::optional<SmallVector<int64_t, 4>>
1343 computeContiguousStrides(MemRefType memRefType) {
1344   int64_t offset;
1345   SmallVector<int64_t, 4> strides;
1346   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1347     return std::nullopt;
1348   if (!strides.empty() && strides.back() != 1)
1349     return std::nullopt;
1350   // If no layout or identity layout, this is contiguous by definition.
1351   if (memRefType.getLayout().isIdentity())
1352     return strides;
1353 
1354   // Otherwise, we must determine contiguity form shapes. This can only ever
1355   // work in static cases because MemRefType is underspecified to represent
1356   // contiguous dynamic shapes in other ways than with just empty/identity
1357   // layout.
1358   auto sizes = memRefType.getShape();
1359   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1360     if (ShapedType::isDynamic(sizes[index + 1]) ||
1361         ShapedType::isDynamic(strides[index]) ||
1362         ShapedType::isDynamic(strides[index + 1]))
1363       return std::nullopt;
1364     if (strides[index] != strides[index + 1] * sizes[index + 1])
1365       return std::nullopt;
1366   }
1367   return strides;
1368 }
1369 
/// Lowers `vector.type_cast` by rebuilding the target memref descriptor in
/// place over the source buffer: the allocated/aligned pointers are reused
/// (bitcast when typed pointers are in use), the offset is reset to 0, and
/// the sizes/strides are taken from the static target memref type. Only
/// static shapes with contiguous, statically-known strides are supported.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The converted source must already be a memref descriptor (LLVM struct).
    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    // The target memref type must also convert to a descriptor struct.
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr. With typed pointers, bitcast to the target element
    // pointer type; with opaque pointers no cast is needed.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr, with the same typed-pointer bitcast caveat.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref, materializing each static
    // value as an i64 constant.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1451 
1452 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1453 /// Non-scalable versions of this operation are handled in Vector Transforms.
1454 class VectorCreateMaskOpRewritePattern
1455     : public OpRewritePattern<vector::CreateMaskOp> {
1456 public:
1457   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1458                                             bool enableIndexOpt)
1459       : OpRewritePattern<vector::CreateMaskOp>(context),
1460         force32BitVectorIndices(enableIndexOpt) {}
1461 
1462   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1463                                 PatternRewriter &rewriter) const override {
1464     auto dstType = op.getType();
1465     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
1466       return failure();
1467     IntegerType idxType =
1468         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1469     auto loc = op->getLoc();
1470     Value indices = rewriter.create<LLVM::StepVectorOp>(
1471         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1472                                  /*isScalable=*/true));
1473     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1474                                                  op.getOperand(0));
1475     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1476     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1477                                                 indices, bounds);
1478     rewriter.replaceOp(op, comp);
1479     return success();
1480   }
1481 
1482 private:
1483   const bool force32BitVectorIndices;
1484 };
1485 
/// Lowers `vector.print` to calls into a small runtime support library by
/// fully unrolling the printed vector into scalar print calls plus
/// bracket/comma/newline calls.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail out if the printed type has no LLVM equivalent.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support. Select the runtime print
    // function for the element type and record which widening conversion
    // (if any) each scalar needs before the call.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // `index` lowers to i64; print it with the unsigned 64-bit printer.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Integers wider than 64 bits have no runtime support.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Integers wider than 64 bits have no runtime support.
          return failure();
        }
      }
    } else {
      // Unsupported element type (e.g. f16, bf16).
      return failure();
    }

    // Unroll vector into elementary print calls, then finish with a newline.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Widening conversion applied to each scalar before the print call.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  /// Recursively unrolls `value` of vector type `type` (rank `rank`) into
  /// scalar calls to `printer`, separated by comma calls and enclosed in
  /// bracket calls. Scalars (rank 0, non-vector `type`) are first widened
  /// according to `conversion` and then printed directly.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    // Open a bracket for this vector dimension.
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    if (rank <= 1) {
      // Innermost dimension: extract and print each scalar element.
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      // A 0-D vector is printed as a single-element 1-D vector.
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Outer dimension: recurse into each rank-reduced subvector.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1646 
1647 /// The Splat operation is lowered to an insertelement + a shufflevector
1648 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1649 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1650   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1651 
1652   LogicalResult
1653   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1654                   ConversionPatternRewriter &rewriter) const override {
1655     VectorType resultType = splatOp.getType().cast<VectorType>();
1656     if (resultType.getRank() > 1)
1657       return failure();
1658 
1659     // First insert it into an undef vector so we can shuffle it.
1660     auto vectorType = typeConverter->convertType(splatOp.getType());
1661     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1662     auto zero = rewriter.create<LLVM::ConstantOp>(
1663         splatOp.getLoc(),
1664         typeConverter->convertType(rewriter.getIntegerType(32)),
1665         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1666 
1667     // For 0-d vector, we simply do `insertelement`.
1668     if (resultType.getRank() == 0) {
1669       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1670           splatOp, vectorType, undef, adaptor.getInput(), zero);
1671       return success();
1672     }
1673 
1674     // For 1-d vector, we additionally do a `vectorshuffle`.
1675     auto v = rewriter.create<LLVM::InsertElementOp>(
1676         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1677 
1678     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1679     SmallVector<int32_t> zeroValues(width, 0);
1680 
1681     // Shuffle the value across the desired number of elements.
1682     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1683                                                        zeroValues);
1684     return success();
1685   }
1686 };
1687 
1688 /// The Splat operation is lowered to an insertelement + a shufflevector
1689 /// operation. Splat to only 2+-d vector result types are lowered by the
1690 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1691 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1692   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1693 
1694   LogicalResult
1695   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1696                   ConversionPatternRewriter &rewriter) const override {
1697     VectorType resultType = splatOp.getType();
1698     if (resultType.getRank() <= 1)
1699       return failure();
1700 
1701     // First insert it into an undef vector so we can shuffle it.
1702     auto loc = splatOp.getLoc();
1703     auto vectorTypeInfo =
1704         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1705     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1706     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1707     if (!llvmNDVectorTy || !llvm1DVectorTy)
1708       return failure();
1709 
1710     // Construct returned value.
1711     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1712 
1713     // Construct a 1-D vector with the splatted value that we insert in all the
1714     // places within the returned descriptor.
1715     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1716     auto zero = rewriter.create<LLVM::ConstantOp>(
1717         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1718         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1719     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1720                                                      adaptor.getInput(), zero);
1721 
1722     // Shuffle the value across the desired number of elements.
1723     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1724     SmallVector<int32_t> zeroValues(width, 0);
1725     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1726 
1727     // Iterate of linear index, convert to coords space and insert splatted 1-D
1728     // vector in each position.
1729     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1730       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1731     });
1732     rewriter.replaceOp(splatOp, desc);
1733     return success();
1734   }
1735 };
1736 
1737 } // namespace
1738 
1739 /// Populate the given list with patterns that convert from Vector to LLVM.
1740 void mlir::populateVectorToLLVMConversionPatterns(
1741     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1742     bool reassociateFPReductions, bool force32BitVectorIndices) {
1743   MLIRContext *ctx = converter.getDialect()->getContext();
1744   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1745   populateVectorInsertExtractStridedSliceTransforms(patterns);
1746   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1747   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1748   patterns
1749       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1750            VectorExtractElementOpConversion, VectorExtractOpConversion,
1751            VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
1752            VectorInsertElementOpConversion, VectorInsertOpConversion,
1753            VectorPrintOpConversion, VectorTypeCastOpConversion,
1754            VectorScaleOpConversion,
1755            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1756            VectorLoadStoreConversion<vector::MaskedLoadOp,
1757                                      vector::MaskedLoadOpAdaptor>,
1758            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1759            VectorLoadStoreConversion<vector::MaskedStoreOp,
1760                                      vector::MaskedStoreOpAdaptor>,
1761            VectorGatherOpConversion, VectorScatterOpConversion,
1762            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1763            VectorSplatOpLowering, VectorSplatNdOpLowering,
1764            VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1765            MaskedReductionOpConversion>(converter);
1766   // Transfer ops with rank > 1 are handled by VectorToSCF.
1767   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1768 }
1769 
1770 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1771     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1772   patterns.add<VectorMatmulOpConversion>(converter);
1773   patterns.add<VectorFlatTransposeOpConversion>(converter);
1774 }
1775