xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision a9a1b45b3c2ef0a44a56cb6e677e4daa4b29fbc6)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/Arith/Utils/Utils.h"
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 #include <optional>
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 
27 // Helper to reduce vector type by one rank at front.
28 static VectorType reducedVectorTypeFront(VectorType tp) {
29   assert((tp.getRank() > 1) && "unlowerable vector type");
30   unsigned numScalableDims = tp.getNumScalableDims();
31   if (tp.getShape().size() == numScalableDims)
32     --numScalableDims;
33   return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
34                          numScalableDims);
35 }
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   unsigned numScalableDims = tp.getNumScalableDims();
41   if (numScalableDims > 0)
42     --numScalableDims;
43   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
44                          numScalableDims);
45 }
46 
47 // Helper that picks the proper sequence for inserting.
48 static Value insertOne(ConversionPatternRewriter &rewriter,
49                        LLVMTypeConverter &typeConverter, Location loc,
50                        Value val1, Value val2, Type llvmType, int64_t rank,
51                        int64_t pos) {
52   assert(rank > 0 && "0-D vector corner case should have been handled already");
53   if (rank == 1) {
54     auto idxType = rewriter.getIndexType();
55     auto constant = rewriter.create<LLVM::ConstantOp>(
56         loc, typeConverter.convertType(idxType),
57         rewriter.getIntegerAttr(idxType, pos));
58     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
59                                                   constant);
60   }
61   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
62 }
63 
64 // Helper that picks the proper sequence for extracting.
65 static Value extractOne(ConversionPatternRewriter &rewriter,
66                         LLVMTypeConverter &typeConverter, Location loc,
67                         Value val, Type llvmType, int64_t rank, int64_t pos) {
68   if (rank <= 1) {
69     auto idxType = rewriter.getIndexType();
70     auto constant = rewriter.create<LLVM::ConstantOp>(
71         loc, typeConverter.convertType(idxType),
72         rewriter.getIntegerAttr(idxType, pos));
73     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
74                                                    constant);
75   }
76   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
77 }
78 
79 // Helper that returns data layout alignment of a memref.
80 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
81                                  MemRefType memrefType, unsigned &align) {
82   Type elementTy = typeConverter.convertType(memrefType.getElementType());
83   if (!elementTy)
84     return failure();
85 
86   // TODO: this should use the MLIR data layout when it becomes available and
87   // stop depending on translation.
88   llvm::LLVMContext llvmContext;
89   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
90               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
91   return success();
92 }
93 
94 // Check if the last stride is non-unit or the memory space is not zero.
95 static LogicalResult isMemRefTypeSupported(MemRefType memRefType) {
96   int64_t offset;
97   SmallVector<int64_t, 4> strides;
98   auto successStrides = getStridesAndOffset(memRefType, strides, offset);
99   if (failed(successStrides) || strides.back() != 1 ||
100       memRefType.getMemorySpaceAsInt() != 0)
101     return failure();
102   return success();
103 }
104 
105 // Add an index vector component to a base pointer.
106 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
107                             MemRefType memRefType, Value llvmMemref, Value base,
108                             Value index, uint64_t vLen) {
109   assert(succeeded(isMemRefTypeSupported(memRefType)) &&
110          "unsupported memref type");
111   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
112   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
113   return rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
114 }
115 
116 // Casts a strided element pointer to a vector pointer.  The vector pointer
117 // will be in the same address space as the incoming memref type.
118 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
119                          Value ptr, MemRefType memRefType, Type vt) {
120   auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
121   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
122 }
123 
124 namespace {
125 
/// Trivial Vector to LLVM conversions: vector.vscale maps one-to-one onto the
/// llvm.vscale intrinsic, so a generic 1-1 pattern suffices.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
129 
130 /// Conversion pattern for a vector.bitcast.
131 class VectorBitCastOpConversion
132     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
133 public:
134   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
135 
136   LogicalResult
137   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
138                   ConversionPatternRewriter &rewriter) const override {
139     // Only 0-D and 1-D vectors can be lowered to LLVM.
140     VectorType resultTy = bitCastOp.getResultVectorType();
141     if (resultTy.getRank() > 1)
142       return failure();
143     Type newResultTy = typeConverter->convertType(resultTy);
144     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
145                                                  adaptor.getOperands()[0]);
146     return success();
147   }
148 };
149 
150 /// Conversion pattern for a vector.matrix_multiply.
151 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
152 class VectorMatmulOpConversion
153     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
154 public:
155   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
156 
157   LogicalResult
158   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
159                   ConversionPatternRewriter &rewriter) const override {
160     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
161         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
162         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
163         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
164     return success();
165   }
166 };
167 
168 /// Conversion pattern for a vector.flat_transpose.
169 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
170 class VectorFlatTransposeOpConversion
171     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
172 public:
173   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
174 
175   LogicalResult
176   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
177                   ConversionPatternRewriter &rewriter) const override {
178     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
179         transOp, typeConverter->convertType(transOp.getRes().getType()),
180         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
181     return success();
182   }
183 };
184 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
// vector.load -> llvm.load.  The adaptor and vectorTy parameters are unused
// here; the pointer is already a vector pointer (see castDataPtr in the
// caller), so a plain aligned load suffices.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}
194 
// vector.maskedload -> llvm.intr.masked.load, forwarding the mask and the
// pass-through value for the disabled lanes.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
202 
// vector.store -> llvm.store of the (already converted) vector value.
// vectorTy is unused; the pointer is already a vector pointer.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}
210 
// vector.maskedstore -> llvm.intr.masked.store, forwarding the lane mask.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
218 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.  The common work — rank check, alignment resolution,
/// and address computation — is shared here; the op-specific rewrite is
/// dispatched through the replaceLoadOrStoreOp overload set above.
/// NOTE(review): the second template parameter `LoadOrStoreOpAdaptor` is
/// unused (the op's own `Adaptor` typedef is used instead); kept for
/// compatibility with existing instantiations.
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address: compute the strided element pointer, then bitcast it
    // to a pointer-to-vector in the memref's address space.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    // Dispatch to the op-specific replacement.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};
254 
/// Conversion pattern for a vector.gather.
/// A 1-D gather maps directly onto llvm.intr.masked.gather; an n-D gather is
/// unrolled into a sequence of 1-D gathers via
/// LLVM::detail::handleMultidimensionalVectors.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    // Only unit-innermost-stride, default-memory-space memrefs are supported.
    if (failed(isMemRefTypeSupported(memRefType)))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Scalar base address of the gather region; per-lane addresses are
    // derived from it below.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, memRefType, base, ptr,
                                  adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: the callback is invoked once per innermost 1-D slice with
    // the matching slices of (indexVec, mask, passThru) as vectorOperands.
    auto callback = [align, memRefType, base, ptr, loc, &rewriter](
                        Type llvm1DVectorTy, ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, memRefType, base, ptr, /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
313 
314 /// Conversion pattern for a vector.scatter.
315 class VectorScatterOpConversion
316     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
317 public:
318   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
319 
320   LogicalResult
321   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
322                   ConversionPatternRewriter &rewriter) const override {
323     auto loc = scatter->getLoc();
324     MemRefType memRefType = scatter.getMemRefType();
325 
326     if (failed(isMemRefTypeSupported(memRefType)))
327       return failure();
328 
329     // Resolve alignment.
330     unsigned align;
331     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
332       return failure();
333 
334     // Resolve address.
335     VectorType vType = scatter.getVectorType();
336     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
337                                      adaptor.getIndices(), rewriter);
338     Value ptrs =
339         getIndexedPtrs(rewriter, loc, memRefType, adaptor.getBase(), ptr,
340                        adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
341 
342     // Replace with the scatter intrinsic.
343     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
344         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
345         rewriter.getI32IntegerAttr(align));
346     return success();
347   }
348 };
349 
350 /// Conversion pattern for a vector.expandload.
351 class VectorExpandLoadOpConversion
352     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
353 public:
354   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
355 
356   LogicalResult
357   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
358                   ConversionPatternRewriter &rewriter) const override {
359     auto loc = expand->getLoc();
360     MemRefType memRefType = expand.getMemRefType();
361 
362     // Resolve address.
363     auto vtype = typeConverter->convertType(expand.getVectorType());
364     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
365                                      adaptor.getIndices(), rewriter);
366 
367     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
368         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
369     return success();
370   }
371 };
372 
373 /// Conversion pattern for a vector.compressstore.
374 class VectorCompressStoreOpConversion
375     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
376 public:
377   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
378 
379   LogicalResult
380   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
381                   ConversionPatternRewriter &rewriter) const override {
382     auto loc = compress->getLoc();
383     MemRefType memRefType = compress.getMemRefType();
384 
385     // Resolve address.
386     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
387                                      adaptor.getIndices(), rewriter);
388 
389     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
390         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
391     return success();
392   }
393 };
394 
395 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
396 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
397 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
398 /// non-null.
399 template <class VectorOp, class ScalarOp>
400 static Value createIntegerReductionArithmeticOpLowering(
401     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
402     Value vectorOperand, Value accumulator) {
403   Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
404   if (accumulator)
405     result = rewriter.create<ScalarOp>(loc, accumulator, result);
406   return result;
407 }
408 
409 /// Helper method to lower a `vector.reduction` operation that performs
410 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
411 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
412 /// the accumulator value if non-null.
413 template <class VectorOp>
414 static Value createIntegerReductionComparisonOpLowering(
415     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
416     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
417   Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
418   if (accumulator) {
419     Value cmp =
420         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
421     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
422   }
423   return result;
424 }
425 
426 /// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
427 /// with vector types.
428 static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
429                            Value rhs, bool isMin) {
430   auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
431   Type i1Type = builder.getI1Type();
432   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
433     i1Type = VectorType::get(vecType.getShape(), i1Type);
434   Value cmp = builder.create<LLVM::FCmpOp>(
435       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
436       lhs, rhs);
437   Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
438   Value isNan = builder.create<LLVM::FCmpOp>(
439       loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
440   Value nan = builder.create<LLVM::ConstantOp>(
441       loc, lhs.getType(),
442       builder.getFloatAttr(floatType,
443                            APFloat::getQNaN(floatType.getFloatSemantics())));
444   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
445 }
446 
/// Conversion pattern for all vector reductions.
/// Integer/index reductions dispatch to the helper templates above; the
/// floating-point add/mul cases lower to the LLVM fadd/fmul reduction
/// intrinsics (honoring the reassociation flag), and minf/maxf go through
/// createMinMaxF to get NaN-propagating semantics.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  /// `reassociateFPRed` is forwarded to the fadd/fmul reduction intrinsics
  /// as their reassociation attribute.
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = adaptor.getVector();
    Value acc = adaptor.getAcc();
    Location loc = reductionOp.getLoc();
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      Value result;
      switch (kind) {
      case vector::CombiningKind::ADD:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
                                                       LLVM::AddOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::MUL:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
                                                       LLVM::MulOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      // Min/max kinds combine with the accumulator via the matching
      // (un)signed icmp predicate.
      case vector::CombiningKind::MINUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::ule);
        break;
      case vector::CombiningKind::MINSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sle);
        break;
      case vector::CombiningKind::MAXUI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::uge);
        break;
      case vector::CombiningKind::MAXSI:
        result = createIntegerReductionComparisonOpLowering<
            LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
                                      LLVM::ICmpPredicate::sge);
        break;
      case vector::CombiningKind::AND:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
                                                       LLVM::AndOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::OR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
                                                       LLVM::OrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      case vector::CombiningKind::XOR:
        result =
            createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
                                                       LLVM::XOrOp>(
                rewriter, loc, llvmType, operand, acc);
        break;
      default:
        // Remaining kinds (e.g. FP-only ones) are not valid on integers.
        return failure();
      }
      rewriter.replaceOp(reductionOp, result);

      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == vector::CombiningKind::ADD) {
      // Optional accumulator (or zero).
      // NOTE(review): this re-derives the accumulator from the operand list;
      // the `acc` captured above should be equivalent when present — consider
      // unifying.
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MUL) {
      // Optional accumulator (or one).
      Value acc = adaptor.getOperands().size() > 1
                      ? adaptor.getOperands()[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operand,
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == vector::CombiningKind::MINF) {
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
      rewriter.replaceOp(reductionOp, result);
    } else if (kind == vector::CombiningKind::MAXF) {
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      Value result =
          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
      if (acc)
        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
      rewriter.replaceOp(reductionOp, result);
    } else
      return failure();

    return success();
  }

private:
  // Whether FP add/mul reductions may be emitted with reassociation allowed.
  const bool reassociateFPReductions;
};
576 
/// Conversion pattern for a vector.shuffle.
/// Rank <= 1 shuffles with identical operand types map directly onto LLVM's
/// shufflevector; everything else is scalarized into extract/insert chains.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Mask entries index into the concatenation of v1 and v2: entries below
    // v1Dim select from v1, the rest (rebased by -v1Dim) select from v2.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      // Extract from the selected source, insert at the next result slot.
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
641 
642 class VectorExtractElementOpConversion
643     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
644 public:
645   using ConvertOpToLLVMPattern<
646       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
647 
648   LogicalResult
649   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
650                   ConversionPatternRewriter &rewriter) const override {
651     auto vectorType = extractEltOp.getVectorType();
652     auto llvmType = typeConverter->convertType(vectorType.getElementType());
653 
654     // Bail if result type cannot be lowered.
655     if (!llvmType)
656       return failure();
657 
658     if (vectorType.getRank() == 0) {
659       Location loc = extractEltOp.getLoc();
660       auto idxType = rewriter.getIndexType();
661       auto zero = rewriter.create<LLVM::ConstantOp>(
662           loc, typeConverter->convertType(idxType),
663           rewriter.getIntegerAttr(idxType, 0));
664       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
665           extractEltOp, llvmType, adaptor.getVector(), zero);
666       return success();
667     }
668 
669     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
670         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
671     return success();
672   }
673 };
674 
/// Conversion pattern for a vector.extract.
/// Depending on the position depth and result type this becomes: a no-op
/// (empty position), a single extractvalue (vector result from an array),
/// or extractvalue(s) down to a 1-D vector followed by an extractelement.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array: peel all but the last
    // position index through extractvalue first.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector, using the last
    // position index as an i64 lane constant.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
731 
732 /// Conversion pattern that turns a vector.fma on a 1-D vector
733 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
734 /// This does not match vectors of n >= 2 rank.
735 ///
736 /// Example:
737 /// ```
738 ///  vector.fma %a, %a, %a : vector<8xf32>
739 /// ```
740 /// is converted to:
741 /// ```
742 ///  llvm.intr.fmuladd %va, %va, %va:
743 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
744 ///    -> !llvm."<8 x f32>">
745 /// ```
746 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
747 public:
748   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
749 
750   LogicalResult
751   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
752                   ConversionPatternRewriter &rewriter) const override {
753     VectorType vType = fmaOp.getVectorType();
754     if (vType.getRank() > 1)
755       return failure();
756     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
757         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
758     return success();
759   }
760 };
761 
762 class VectorInsertElementOpConversion
763     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
764 public:
765   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
766 
767   LogicalResult
768   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
769                   ConversionPatternRewriter &rewriter) const override {
770     auto vectorType = insertEltOp.getDestVectorType();
771     auto llvmType = typeConverter->convertType(vectorType);
772 
773     // Bail if result type cannot be lowered.
774     if (!llvmType)
775       return failure();
776 
777     if (vectorType.getRank() == 0) {
778       Location loc = insertEltOp.getLoc();
779       auto idxType = rewriter.getIndexType();
780       auto zero = rewriter.create<LLVM::ConstantOp>(
781           loc, typeConverter->convertType(idxType),
782           rewriter.getIntegerAttr(idxType, 0));
783       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
784           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
785       return success();
786     }
787 
788     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
789         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
790         adaptor.getPosition());
791     return success();
792   }
793 };
794 
/// Conversion pattern that lowers `vector.insert` to LLVM. Depending on the
/// position and source type this becomes a no-op, a single `llvm.insertvalue`
/// (vector into array-of-vectors), or an extract/insertelement/insert
/// sequence to update one scalar element of an innermost 1-D vector.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array. All but the last
    // position index address the enclosing array; the last one selects the
    // element within the innermost 1-D vector.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array. This writes the
    // updated 1-D vector back into its slot in the destination array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
859 
860 /// Lower vector.scalable.insert ops to LLVM vector.insert
861 struct VectorScalableInsertOpLowering
862     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
863   using ConvertOpToLLVMPattern<
864       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
865 
866   LogicalResult
867   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
868                   ConversionPatternRewriter &rewriter) const override {
869     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
870         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
871     return success();
872   }
873 };
874 
875 /// Lower vector.scalable.extract ops to LLVM vector.extract
876 struct VectorScalableExtractOpLowering
877     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
878   using ConvertOpToLLVMPattern<
879       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
880 
881   LogicalResult
882   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
883                   ConversionPatternRewriter &rewriter) const override {
884     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
885         extOp, typeConverter->convertType(extOp.getResultVectorType()),
886         adaptor.getSource(), adaptor.getPos());
887     return success();
888   }
889 };
890 
891 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
892 ///
893 /// Example:
894 /// ```
895 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
896 /// ```
897 /// is rewritten into:
898 /// ```
899 ///  %r = splat %f0: vector<2x4xf32>
900 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
901 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
902 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
903 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
904 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
905 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
906 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
907 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
908 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
909 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
910 ///  // %r3 holds the final value.
911 /// ```
912 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
913 public:
914   using OpRewritePattern<FMAOp>::OpRewritePattern;
915 
916   void initialize() {
917     // This pattern recursively unpacks one dimension at a time. The recursion
918     // bounded as the rank is strictly decreasing.
919     setHasBoundedRewriteRecursion();
920   }
921 
922   LogicalResult matchAndRewrite(FMAOp op,
923                                 PatternRewriter &rewriter) const override {
924     auto vType = op.getVectorType();
925     if (vType.getRank() < 2)
926       return failure();
927 
928     auto loc = op.getLoc();
929     auto elemType = vType.getElementType();
930     Value zero = rewriter.create<arith::ConstantOp>(
931         loc, elemType, rewriter.getZeroAttr(elemType));
932     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
933     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
934       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
935       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
936       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
937       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
938       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
939     }
940     rewriter.replaceOp(op, desc);
941     return success();
942   }
943 };
944 
945 /// Returns the strides if the memory underlying `memRefType` has a contiguous
946 /// static layout.
947 static std::optional<SmallVector<int64_t, 4>>
948 computeContiguousStrides(MemRefType memRefType) {
949   int64_t offset;
950   SmallVector<int64_t, 4> strides;
951   if (failed(getStridesAndOffset(memRefType, strides, offset)))
952     return std::nullopt;
953   if (!strides.empty() && strides.back() != 1)
954     return std::nullopt;
955   // If no layout or identity layout, this is contiguous by definition.
956   if (memRefType.getLayout().isIdentity())
957     return strides;
958 
959   // Otherwise, we must determine contiguity form shapes. This can only ever
960   // work in static cases because MemRefType is underspecified to represent
961   // contiguous dynamic shapes in other ways than with just empty/identity
962   // layout.
963   auto sizes = memRefType.getShape();
964   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
965     if (ShapedType::isDynamic(sizes[index + 1]) ||
966         ShapedType::isDynamic(strides[index]) ||
967         ShapedType::isDynamic(strides[index + 1]))
968       return std::nullopt;
969     if (strides[index] != strides[index + 1] * sizes[index + 1])
970       return std::nullopt;
971   }
972   return strides;
973 }
974 
/// Conversion pattern that lowers `vector.type_cast` between statically
/// shaped memrefs by rebuilding the target memref descriptor: the data
/// pointers of the source descriptor are bitcast to the target element type,
/// the offset is set to 0, and the (static) sizes and strides of the target
/// type are materialized as constants.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The source operand must already be lowered to an LLVM struct-based
    // memref descriptor.
    auto llvmSourceDescriptorTy =
        adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    // The target memref type must likewise convert to a struct descriptor.
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref. Both are static thanks to
    // the checks above, so they become i64 constants.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1052 
1053 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1054 /// Non-scalable versions of this operation are handled in Vector Transforms.
1055 class VectorCreateMaskOpRewritePattern
1056     : public OpRewritePattern<vector::CreateMaskOp> {
1057 public:
1058   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1059                                             bool enableIndexOpt)
1060       : OpRewritePattern<vector::CreateMaskOp>(context),
1061         force32BitVectorIndices(enableIndexOpt) {}
1062 
1063   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1064                                 PatternRewriter &rewriter) const override {
1065     auto dstType = op.getType();
1066     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
1067       return failure();
1068     IntegerType idxType =
1069         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1070     auto loc = op->getLoc();
1071     Value indices = rewriter.create<LLVM::StepVectorOp>(
1072         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1073                                  /*isScalable=*/true));
1074     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1075                                                  op.getOperand(0));
1076     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1077     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1078                                                 indices, bounds);
1079     rewriter.replaceOp(op, comp);
1080     return success();
1081   }
1082 
1083 private:
1084   const bool force32BitVectorIndices;
1085 };
1086 
/// Lowers `vector.print` to a sequence of calls into a small runtime support
/// library (one call per scalar element plus punctuation calls), followed by
/// a newline call, then erases the op.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type printType = printOp.getPrintType();

    // Bail if the printed type cannot be lowered.
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support. The element type picks the
    // runtime print function; integers may additionally need an extension
    // (`conversion`) before the call.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Wider than 64-bit integers are unsupported.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          // Wider than 64-bit integers are unsupported.
          return failure();
        }
      }
    } else {
      // No runtime support for this element type.
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Type type = vectorType ? vectorType : eltType;
    emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // How an integer operand must be widened before the runtime call.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively unrolls `value` of (vector) `type` into print calls: scalars
  // are (optionally extended and) passed to `printer`; vectors print an
  // opening bracket, their elements separated by commas, then a closing
  // bracket, recursing one rank at a time.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, Type type, Operation *printer, int64_t rank,
                 PrintConversion conversion) const {
    VectorType vectorType = type.dyn_cast<VectorType>();
    Location loc = op->getLoc();
    if (!vectorType) {
      assert(rank == 0 && "The scalar case expects rank == 0");
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<arith::ExtUIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<arith::ExtSIOp>(
            loc, IntegerType::get(rewriter.getContext(), 64), value);
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());

    // Innermost (or 0-D) vector: print each scalar element.
    if (rank <= 1) {
      auto reducedType = vectorType.getElementType();
      auto llvmType = typeConverter->convertType(reducedType);
      int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
      for (int64_t d = 0; d < dim; ++d) {
        Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                     llvmType, /*rank=*/0, /*pos=*/d);
        emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
                  conversion);
        if (d != dim - 1)
          emitCall(rewriter, loc, printComma);
      }
      emitCall(
          rewriter, loc,
          LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
      return;
    }

    // Higher rank: recurse into each subvector along the leading dimension.
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType = reducedVectorTypeFront(vectorType);
      auto llvmType = typeConverter->convertType(reducedType);
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1247 
1248 /// The Splat operation is lowered to an insertelement + a shufflevector
1249 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1250 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1251   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1252 
1253   LogicalResult
1254   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1255                   ConversionPatternRewriter &rewriter) const override {
1256     VectorType resultType = splatOp.getType().cast<VectorType>();
1257     if (resultType.getRank() > 1)
1258       return failure();
1259 
1260     // First insert it into an undef vector so we can shuffle it.
1261     auto vectorType = typeConverter->convertType(splatOp.getType());
1262     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1263     auto zero = rewriter.create<LLVM::ConstantOp>(
1264         splatOp.getLoc(),
1265         typeConverter->convertType(rewriter.getIntegerType(32)),
1266         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1267 
1268     // For 0-d vector, we simply do `insertelement`.
1269     if (resultType.getRank() == 0) {
1270       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1271           splatOp, vectorType, undef, adaptor.getInput(), zero);
1272       return success();
1273     }
1274 
1275     // For 1-d vector, we additionally do a `vectorshuffle`.
1276     auto v = rewriter.create<LLVM::InsertElementOp>(
1277         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1278 
1279     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1280     SmallVector<int32_t> zeroValues(width, 0);
1281 
1282     // Shuffle the value across the desired number of elements.
1283     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1284                                                        zeroValues);
1285     return success();
1286   }
1287 };
1288 
/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Splat to only 2+-d vector result types are lowered by the
/// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = splatOp.getType();
    // Rank 0 and 1 are handled by VectorSplatOpLowering.
    if (resultType.getRank() <= 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
    // Bail if the n-D type could not be decomposed into LLVM types.
    if (!llvmNDVectorTy || !llvm1DVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
                                                     adaptor.getInput(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t> zeroValues(width, 0);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);

    // Iterate of linear index, convert to coords space and insert splatted 1-D
    // vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};
1337 
1338 } // namespace
1339 
/// Populate the given list with patterns that convert from Vector to LLVM.
/// `reassociateFPReductions` allows reassociation in FP reductions and
/// `force32BitVectorIndices` makes mask/index computations use i32.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    bool reassociateFPReductions, bool force32BitVectorIndices) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // Context-based (non-conversion) rewrite patterns.
  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
  populateVectorInsertExtractStridedSliceTransforms(patterns);
  // Patterns that take extra configuration flags.
  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
  // Remaining conversion patterns, all driven by the type converter.
  patterns
      .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
           VectorExtractElementOpConversion, VectorExtractOpConversion,
           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
           VectorInsertOpConversion, VectorPrintOpConversion,
           VectorTypeCastOpConversion, VectorScaleOpConversion,
           VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedLoadOp,
                                     vector::MaskedLoadOpAdaptor>,
           VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
           VectorLoadStoreConversion<vector::MaskedStoreOp,
                                     vector::MaskedStoreOpAdaptor>,
           VectorGatherOpConversion, VectorScatterOpConversion,
           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
           VectorSplatOpLowering, VectorSplatNdOpLowering,
           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
          converter);
  // Transfer ops with rank > 1 are handled by VectorToSCF.
  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
1369 
1370 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1371     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1372   patterns.add<VectorMatmulOpConversion>(converter);
1373   patterns.add<VectorFlatTransposeOpConversion>(converter);
1374 }
1375