//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include <optional>

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
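// For instance (illustrative sketch): a vector<2x3x4xf32> input is reduced to
// vector<3x4xf32>.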
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (tp.getShape().size() == numScalableDims)
    --numScalableDims;
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
                         numScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
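// For instance (illustrative sketch): a vector<2x3x4xf32> input is reduced to
// vector<4xf32>.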
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  unsigned numScalableDims = tp.getNumScalableDims();
  if (numScalableDims > 0)
    --numScalableDims;
  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
                         numScalableDims);
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  assert(rank > 0 && "0-D vector corner case should have been handled already");
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank <= 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Check that the memref has a unit last stride and memory space zero; the
// lowerings below do not support anything else.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           LLVMTypeConverter &converter) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  FailureOr<unsigned> addressSpace =
      converter.getMemRefAddressSpace(memRefType);
  if (failed(successStrides) || strides.back() != 1 || failed(addressSpace) ||
      *addressSpace != 0)
    return failure();
  return success();
}

// Add an index vector component to a base pointer.
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                            LLVMTypeConverter &typeConverter,
                            MemRefType memRefType, Value llvmMemref, Value base,
                            Value index, uint64_t vLen) {
  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
         "unsupported memref type");
  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
  return rewriter.create<LLVM::GEPOp>(
      loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
      base, index);
}

// Casts a strided element pointer to a vector pointer.  The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt,
                         LLVMTypeConverter &converter) {
  if (converter.useOpaquePointers())
    return ptr;

  unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
  auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

namespace {

/// Trivial Vector to LLVM conversions
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;

/// Conversion pattern for a vector.bitcast.
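/// For example (illustrative only; the exact printed form may vary):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```
/// becomes a single `llvm.bitcast` on the converted vector type.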
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
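/// For example (illustrative; a 2x3 times 3x4 multiply on flattened operands):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          { lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32 }
///        : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```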
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
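/// For example (illustrative; a 4x4 transpose on a flattened operand):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4 : i32, columns = 4 : i32 }
///        : vector<16xf32> -> vector<16xf32>
/// ```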
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.getRes().getType()),
        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload, and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
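/// For example (illustrative):
/// ```
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an `llvm.load` of the converted vector type from the strided
/// element pointer, with the alignment derived from the element type's
/// preferred alignment in the data layout.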
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
                            *this->getTypeConverter());

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
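/// The 1-D case maps onto `llvm.intr.masked.gather` over a GEP-computed vector
/// of pointers; higher-rank cases are unrolled into 1-D gathers. For example
/// (illustrative):
/// ```
///   %0 = vector.gather %base[%c0][%idxs], %mask, %pass
///        : memref<?xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
///          into vector<8xf32>
/// ```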
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};

/// Conversion pattern for a vector.scatter.
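/// For example (illustrative):
/// ```
///   vector.scatter %base[%c0][%idxs], %mask, %value
///     : memref<?xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
/// ```
/// becomes an `llvm.intr.masked.scatter` on the GEP-computed vector of
/// pointers.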
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs = getIndexedPtrs(
        rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
        ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
    return success();
  }
};

/// Reduction neutral classes for overloading.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};

/// Create the reduction neutral zero value.
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value.
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
}

/// Create the reduction neutral fp one value.
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
}

/// Create the reduction neutral all-ones value.
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(
          llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int minimum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int minimum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int maximum value.
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int maximum value.
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral fp minimum value.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value.
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = llvmType.cast<FloatType>();
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}

/// Returns `accumulator` if it has a valid value. Otherwise, creates and
/// returns a new accumulator value using `ReductionNeutral`.
template <class ReductionNeutral>
static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Value accumulator) {
  if (accumulator)
    return accumulator;

  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
                                     llvmType);
}

/// Creates an i32 constant holding the number of elements of the 1-D vector
/// type `llvmType`. This is used as the effective vector length by intrinsics
/// that support dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  return rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector reduction
/// intrinsic to use and `ScalarOp` is the scalar operation used to combine the
/// result with the accumulator value, if the latter is non-null.
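/// Roughly (sketch), the emitted IR for an add reduction with an accumulator
/// looks like:
/// ```
///   %r   = "llvm.intr.vector.reduce.add"(%v) : (vector<8xi32>) -> i32
///   %res = llvm.add %acc, %r : i32
/// ```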
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}

/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
/// vector reduction intrinsic to use and `predicate` is the predicate used to
/// compare and combine the result with the accumulator value, if the latter is
/// non-null.
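/// Roughly (sketch), the emitted IR for a signed-max reduction with an
/// accumulator looks like:
/// ```
///   %r   = "llvm.intr.vector.reduce.smax"(%v) : (vector<8xi32>) -> i32
///   %c   = llvm.icmp "sge" %acc, %r : i32
///   %res = llvm.select %c, %acc, %r : i1, i32
/// ```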
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}

/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
/// with vector types.
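/// Roughly (sketch), the emitted sequence for the min case is:
/// ```
///   %cmp   = llvm.fcmp "olt" %lhs, %rhs : f32
///   %sel   = llvm.select %cmp, %lhs, %rhs : i1, f32
///   %isnan = llvm.fcmp "uno" %lhs, %rhs : f32
///   %res   = llvm.select %isnan, %nan, %sel : i1, f32  // propagate NaN
/// ```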
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
  Type i1Type = builder.getI1Type();
  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, lhs.getType(),
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}

template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, bool isMin) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  if (accumulator)
    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);

  return result;
}

/// Overloaded methods to lower a reduction to an LLVM intrinsic that requires
/// a start value. This start-value form is used both by unmasked fp reductions
/// and by all the masked reduction intrinsics.
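/// Roughly (sketch), the masked fp-add overload emits something like:
/// ```
///   %start = llvm.mlir.constant(0.0 : f32) : f32  // neutral, if no acc given
///   %evl   = llvm.mlir.constant(8 : i32) : i32
///   %r     = "llvm.intr.vp.reduce.fadd"(%start, %v, %mask, %evl)
///            : (f32, vector<8xf32>, vector<8xi1>, i32) -> f32
/// ```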
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, reassociateFPReds);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}

template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask,
                                          bool reassociateFPReds) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength,
                                            reassociateFPReds);
}

template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                          Location loc, Type llvmType,
                                          Value vectorOperand,
                                          Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                        IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}

/// Conversion pattern for all vector reductions.
692     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
693 public:
694   explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
695                                        bool reassociateFPRed)
696       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
697         reassociateFPReductions(reassociateFPRed) {}
698 
699   LogicalResult
700   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
701                   ConversionPatternRewriter &rewriter) const override {
702     auto kind = reductionOp.getKind();
703     Type eltType = reductionOp.getDest().getType();
704     Type llvmType = typeConverter->convertType(eltType);
705     Value operand = adaptor.getVector();
706     Value acc = adaptor.getAcc();
707     Location loc = reductionOp.getLoc();
708 
709     if (eltType.isIntOrIndex()) {
710       // Integer reductions: add/mul/min/max/and/or/xor.
711       Value result;
712       switch (kind) {
713       case vector::CombiningKind::ADD:
714         result =
715             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
716                                                        LLVM::AddOp>(
717                 rewriter, loc, llvmType, operand, acc);
718         break;
719       case vector::CombiningKind::MUL:
720         result =
721             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
722                                                        LLVM::MulOp>(
723                 rewriter, loc, llvmType, operand, acc);
724         break;
725       case vector::CombiningKind::MINUI:
726         result = createIntegerReductionComparisonOpLowering<
727             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
728                                       LLVM::ICmpPredicate::ule);
729         break;
730       case vector::CombiningKind::MINSI:
731         result = createIntegerReductionComparisonOpLowering<
732             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
733                                       LLVM::ICmpPredicate::sle);
734         break;
735       case vector::CombiningKind::MAXUI:
736         result = createIntegerReductionComparisonOpLowering<
737             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
738                                       LLVM::ICmpPredicate::uge);
739         break;
740       case vector::CombiningKind::MAXSI:
741         result = createIntegerReductionComparisonOpLowering<
742             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
743                                       LLVM::ICmpPredicate::sge);
744         break;
745       case vector::CombiningKind::AND:
746         result =
747             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
748                                                        LLVM::AndOp>(
749                 rewriter, loc, llvmType, operand, acc);
750         break;
751       case vector::CombiningKind::OR:
752         result =
753             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
754                                                        LLVM::OrOp>(
755                 rewriter, loc, llvmType, operand, acc);
756         break;
757       case vector::CombiningKind::XOR:
758         result =
759             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
760                                                        LLVM::XOrOp>(
761                 rewriter, loc, llvmType, operand, acc);
762         break;
763       default:
764         return failure();
765       }
766       rewriter.replaceOp(reductionOp, result);
767 
768       return success();
769     }
770 
771     if (!eltType.isa<FloatType>())
772       return failure();
773 
774     // Floating-point reductions: add/mul/min/max
775     Value result;
776     if (kind == vector::CombiningKind::ADD) {
777       result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
778                                             ReductionNeutralZero>(
779           rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
780     } else if (kind == vector::CombiningKind::MUL) {
781       result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
782                                             ReductionNeutralFPOne>(
783           rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
784     } else if (kind == vector::CombiningKind::MINF) {
785       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
786       // NaNs/-0.0/+0.0 in the same way.
787       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
788           rewriter, loc, llvmType, operand, acc,
789           /*isMin=*/true);
790     } else if (kind == vector::CombiningKind::MAXF) {
791       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
792       // NaNs/-0.0/+0.0 in the same way.
793       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
794           rewriter, loc, llvmType, operand, acc,
795           /*isMin=*/false);
796     } else
797       return failure();
798 
799     rewriter.replaceOp(reductionOp, result);
800     return success();
801   }
802 
803 private:
804   const bool reassociateFPReductions;
805 };
806 
807 /// Base class to convert a `vector.mask` operation while matching traits
808 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
809 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
810 /// method performs a second match against the maskable operation `MaskedOp`.
811 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
812 /// implemented by the concrete conversion classes. This method can match
813 /// against specific traits of the `vector.mask` and the maskable operation. It
814 /// must replace the `vector.mask` operation.
815 template <class MaskedOp>
816 class VectorMaskOpConversionBase
817     : public ConvertOpToLLVMPattern<vector::MaskOp> {
818 public:
819   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
820 
821   LogicalResult
822   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
823                   ConversionPatternRewriter &rewriter) const final {
824     // Match against the maskable operation kind.
825     auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
826     if (!maskedOp)
827       return failure();
828     return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
829   }
830 
831 protected:
832   virtual LogicalResult
833   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
834                             vector::MaskableOpInterface maskableOp,
835                             ConversionPatternRewriter &rewriter) const = 0;
836 };
837 
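/// Conversion of a `vector.reduction` nested under a `vector.mask` into the
/// corresponding `llvm.intr.vp.reduce.*` intrinsic. For example (illustrative):
/// ```
///   %r = vector.mask %m {
///          vector.reduction <add>, %v : vector<8xi32> into i32
///        } : vector<8xi1> -> i32
/// ```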
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
                                            ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
                                            ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                            ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                            ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
                                            ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
                                            ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
                                            ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                            ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};

class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getResultVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.getMask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
#ifndef NDEBUG
    bool wellFormed0DCase =
        v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
    bool wellFormedNDCase =
        v1Type.getRank() == rank && v2Type.getRank() == rank;
    assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
#endif

    // For rank 0 and 1, where both operands have *exactly* the same vector
    // type, there is direct shuffle support in LLVM. Use it!
    if (rank <= 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.getV1(), adaptor.getV2(),
          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract the individual values and insert them one
    // by one.
    int64_t v1Dim = v1Type.getDimSize(0);
    Type eltType;
    if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
      eltType = arrayType.getElementType();
    else
      eltType = llvmType.cast<VectorType>().getElementType();
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.getV1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.getV2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 eltType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = extractEltOp.getSourceVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = extractEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
          extractEltOp, llvmType, adaptor.getVector(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
    return success();
  }
};

class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      SmallVector<int64_t> indices;
      for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
        indices.push_back(idx.getInt());
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), indices);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      SmallVector<int64_t> nMinusOnePosition;
      for (auto idx : positionAttrs.drop_back())
        nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// It does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() > 1)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
        fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    if (vectorType.getRank() == 0) {
      Location loc = insertEltOp.getLoc();
      auto idxType = rewriter.getIndexType();
      auto zero = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter->convertType(idxType),
          rewriter.getIntegerAttr(idxType, 0));
      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
      return success();
    }

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
        adaptor.getPosition());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.getPosition();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (positionArrayAttr.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(),
          LLVM::convertArrayToIndices(positionArrayAttr));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          LLVM::convertArrayToIndices(positionAttrs.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Lower vector.scalable.insert ops to LLVM vector.insert
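/// For example (illustrative):
/// ```
///   %2 = vector.scalable.insert %sub, %vec[0]
///        : vector<4xf32> into vector<[4]xf32>
/// ```
/// maps onto the `llvm.intr.vector.insert` intrinsic.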
1212 struct VectorScalableInsertOpLowering
1213     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1214   using ConvertOpToLLVMPattern<
1215       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1216 
1217   LogicalResult
1218   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1219                   ConversionPatternRewriter &rewriter) const override {
1220     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1221         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
1222     return success();
1223   }
1224 };
1225 
1226 /// Lower vector.scalable.extract ops to LLVM vector.extract
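/// A minimal sketch, assuming the usual assembly spelling:
/// ```
///   %0 = vector.scalable.extract %vec[0] : vector<4xf32> from vector<[8]xf32>
/// ```
/// maps 1-1 onto a single llvm.intr.vector.extract with the converted result
/// type, source and position.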
1227 struct VectorScalableExtractOpLowering
1228     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1229   using ConvertOpToLLVMPattern<
1230       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1231 
1232   LogicalResult
1233   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1234                   ConversionPatternRewriter &rewriter) const override {
1235     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1236         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1237         adaptor.getSource(), adaptor.getPos());
1238     return success();
1239   }
1240 };
1241 
1242 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1243 ///
1244 /// Example:
1245 /// ```
1246 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
1247 /// ```
1248 /// is rewritten into:
1249 /// ```
///  %f0 = arith.constant 0.0 : f32
///  %r = vector.splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1261 ///  // %r3 holds the final value.
1262 /// ```
1263 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1264 public:
1265   using OpRewritePattern<FMAOp>::OpRewritePattern;
1266 
1267   void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank strictly decreases with each application.
1270     setHasBoundedRewriteRecursion();
1271   }
1272 
1273   LogicalResult matchAndRewrite(FMAOp op,
1274                                 PatternRewriter &rewriter) const override {
1275     auto vType = op.getVectorType();
1276     if (vType.getRank() < 2)
1277       return failure();
1278 
1279     auto loc = op.getLoc();
1280     auto elemType = vType.getElementType();
1281     Value zero = rewriter.create<arith::ConstantOp>(
1282         loc, elemType, rewriter.getZeroAttr(elemType));
1283     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1284     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1285       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1286       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1287       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1288       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1289       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1290     }
1291     rewriter.replaceOp(op, desc);
1292     return success();
1293   }
1294 };
1295 
1296 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1297 /// static layout.
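/// For example, under the default row-major layout, memref<4x8xf32> has
/// strides [8, 1] and is contiguous (8 == 1 * 8), whereas a view such as
/// memref<4x8xf32, strided<[16, 1]>> is not (16 != 1 * 8); the strided<>
/// notation here is purely illustrative.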
1298 static std::optional<SmallVector<int64_t, 4>>
1299 computeContiguousStrides(MemRefType memRefType) {
1300   int64_t offset;
1301   SmallVector<int64_t, 4> strides;
1302   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1303     return std::nullopt;
1304   if (!strides.empty() && strides.back() != 1)
1305     return std::nullopt;
1306   // If no layout or identity layout, this is contiguous by definition.
1307   if (memRefType.getLayout().isIdentity())
1308     return strides;
1309 
  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType cannot represent contiguous
  // dynamic shapes in any way other than an empty/identity layout.
1314   auto sizes = memRefType.getShape();
1315   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1316     if (ShapedType::isDynamic(sizes[index + 1]) ||
1317         ShapedType::isDynamic(strides[index]) ||
1318         ShapedType::isDynamic(strides[index + 1]))
1319       return std::nullopt;
1320     if (strides[index] != strides[index + 1] * sizes[index + 1])
1321       return std::nullopt;
1322   }
1323   return strides;
1324 }
1325 
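/// Conversion pattern for vector.type_cast on statically shaped, contiguous
/// memrefs. The lowering reuses the allocated and aligned pointers of the
/// source memref descriptor and rebuilds the offset, sizes and strides for the
/// target type. A rough sketch (exact assembly may differ):
/// ```
///   %0 = vector.type_cast %A : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// yields a new descriptor that points at the same buffer but has rank 0,
/// offset 0, and no size/stride entries.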
1326 class VectorTypeCastOpConversion
1327     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1328 public:
1329   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1330 
1331   LogicalResult
1332   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1333                   ConversionPatternRewriter &rewriter) const override {
1334     auto loc = castOp->getLoc();
1335     MemRefType sourceMemRefType =
1336         castOp.getOperand().getType().cast<MemRefType>();
1337     MemRefType targetMemRefType = castOp.getType();
1338 
1339     // Only static shape casts supported atm.
1340     if (!sourceMemRefType.hasStaticShape() ||
1341         !targetMemRefType.hasStaticShape())
1342       return failure();
1343 
1344     auto llvmSourceDescriptorTy =
1345         adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
1346     if (!llvmSourceDescriptorTy)
1347       return failure();
1348     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1349 
1350     auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
1351                                       .dyn_cast_or_null<LLVM::LLVMStructType>();
1352     if (!llvmTargetDescriptorTy)
1353       return failure();
1354 
1355     // Only contiguous source buffers supported atm.
1356     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1357     if (!sourceStrides)
1358       return failure();
1359     auto targetStrides = computeContiguousStrides(targetMemRefType);
1360     if (!targetStrides)
1361       return failure();
1362     // Only support static strides for now, regardless of contiguity.
1363     if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1364       return failure();
1365 
1366     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1367 
1368     // Create descriptor.
1369     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1370     Type llvmTargetElementTy = desc.getElementPtrType();
1371     // Set allocated ptr.
1372     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1373     if (!getTypeConverter()->useOpaquePointers())
1374       allocated =
1375           rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
1376     desc.setAllocatedPtr(rewriter, loc, allocated);
1377 
1378     // Set aligned ptr.
1379     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1380     if (!getTypeConverter()->useOpaquePointers())
1381       ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
1382 
1383     desc.setAlignedPtr(rewriter, loc, ptr);
1384     // Fill offset 0.
1385     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1386     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1387     desc.setOffset(rewriter, loc, zero);
1388 
1389     // Fill size and stride descriptors in memref.
1390     for (const auto &indexedSize :
1391          llvm::enumerate(targetMemRefType.getShape())) {
1392       int64_t index = indexedSize.index();
1393       auto sizeAttr =
1394           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1395       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1396       desc.setSize(rewriter, loc, index, size);
1397       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1398                                                 (*targetStrides)[index]);
1399       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1400       desc.setStride(rewriter, loc, index, stride);
1401     }
1402 
1403     rewriter.replaceOp(castOp, {desc});
1404     return success();
1405   }
1406 };
1407 
1408 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1409 /// Non-scalable versions of this operation are handled in Vector Transforms.
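/// A minimal sketch of the lowering (op spellings are approximate):
/// ```
///   %m = vector.create_mask %b : vector<[4]xi1>
/// ```
/// becomes, with i32 indices when force32BitVectorIndices is set:
/// ```
///   %step  = llvm.intr.experimental.stepvector : vector<[4]xi32>
///   %bound = <cast %b from index to i32>
///   %bnds  = vector.splat %bound : vector<[4]xi32>
///   %m     = arith.cmpi slt, %step, %bnds : vector<[4]xi32>
/// ```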
1410 class VectorCreateMaskOpRewritePattern
1411     : public OpRewritePattern<vector::CreateMaskOp> {
1412 public:
1413   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1414                                             bool enableIndexOpt)
1415       : OpRewritePattern<vector::CreateMaskOp>(context),
1416         force32BitVectorIndices(enableIndexOpt) {}
1417 
1418   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1419                                 PatternRewriter &rewriter) const override {
1420     auto dstType = op.getType();
1421     if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
1422       return failure();
1423     IntegerType idxType =
1424         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1425     auto loc = op->getLoc();
1426     Value indices = rewriter.create<LLVM::StepVectorOp>(
1427         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1428                                  /*isScalable=*/true));
1429     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1430                                                  op.getOperand(0));
1431     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1432     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1433                                                 indices, bounds);
1434     rewriter.replaceOp(op, comp);
1435     return success();
1436   }
1437 
1438 private:
1439   const bool force32BitVectorIndices;
1440 };
1441 
1442 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1443 public:
1444   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1445 
1446   // Proof-of-concept lowering implementation that relies on a small
1447   // runtime support library, which only needs to provide a few
1448   // printing methods (single value for all data types, opening/closing
1449   // bracket, comma, newline). The lowering fully unrolls a vector
1450   // in terms of these elementary printing operations. The advantage
1451   // of this approach is that the library can remain unaware of all
1452   // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
1455   //
1456   // TODO: rely solely on libc in future? something else?
1457   //
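  // As an illustration (the runtime symbol names below are assumptions based
  // on the lookupOrCreatePrint* helpers used in this pattern), printing a
  // vector<3xi64> roughly unrolls into:
  //
  //   printOpen();
  //   printI64(v[0]); printComma();
  //   printI64(v[1]); printComma();
  //   printI64(v[2]);
  //   printClose();
  //   printNewline();
  //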
1458   LogicalResult
1459   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1460                   ConversionPatternRewriter &rewriter) const override {
1461     Type printType = printOp.getPrintType();
1462 
1463     if (typeConverter->convertType(printType) == nullptr)
1464       return failure();
1465 
1466     // Make sure element type has runtime support.
1467     PrintConversion conversion = PrintConversion::None;
1468     VectorType vectorType = printType.dyn_cast<VectorType>();
1469     Type eltType = vectorType ? vectorType.getElementType() : printType;
1470     auto parent = printOp->getParentOfType<ModuleOp>();
1471     Operation *printer;
1472     if (eltType.isF32()) {
1473       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1474     } else if (eltType.isF64()) {
1475       printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1476     } else if (eltType.isF16()) {
1477       conversion = PrintConversion::Bitcast16; // bits!
1478       printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1479     } else if (eltType.isBF16()) {
1480       conversion = PrintConversion::Bitcast16; // bits!
1481       printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1482     } else if (eltType.isIndex()) {
1483       printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1484     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1485       // Integers need a zero or sign extension on the operand
1486       // (depending on the source type) as well as a signed or
1487       // unsigned print method. Up to 64-bit is supported.
1488       unsigned width = intTy.getWidth();
1489       if (intTy.isUnsigned()) {
1490         if (width <= 64) {
1491           if (width < 64)
1492             conversion = PrintConversion::ZeroExt64;
1493           printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1494         } else {
1495           return failure();
1496         }
1497       } else {
1498         assert(intTy.isSignless() || intTy.isSigned());
1499         if (width <= 64) {
1500           // Note that we *always* zero extend booleans (1-bit integers),
1501           // so that true/false is printed as 1/0 rather than -1/0.
1502           if (width == 1)
1503             conversion = PrintConversion::ZeroExt64;
1504           else if (width < 64)
1505             conversion = PrintConversion::SignExt64;
1506           printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1507         } else {
1508           return failure();
1509         }
1510       }
1511     } else {
1512       return failure();
1513     }
1514 
1515     // Unroll vector into elementary print calls.
1516     int64_t rank = vectorType ? vectorType.getRank() : 0;
1517     Type type = vectorType ? vectorType : eltType;
1518     emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
1519               conversion);
1520     emitCall(rewriter, printOp->getLoc(),
1521              LLVM::lookupOrCreatePrintNewlineFn(parent));
1522     rewriter.eraseOp(printOp);
1523     return success();
1524   }
1525 
1526 private:
1527   enum class PrintConversion {
1528     // clang-format off
1529     None,
1530     ZeroExt64,
1531     SignExt64,
1532     Bitcast16
1533     // clang-format on
1534   };
1535 
1536   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1537                  Value value, Type type, Operation *printer, int64_t rank,
1538                  PrintConversion conversion) const {
1539     VectorType vectorType = type.dyn_cast<VectorType>();
1540     Location loc = op->getLoc();
1541     if (!vectorType) {
1542       assert(rank == 0 && "The scalar case expects rank == 0");
1543       switch (conversion) {
1544       case PrintConversion::ZeroExt64:
1545         value = rewriter.create<arith::ExtUIOp>(
1546             loc, IntegerType::get(rewriter.getContext(), 64), value);
1547         break;
1548       case PrintConversion::SignExt64:
1549         value = rewriter.create<arith::ExtSIOp>(
1550             loc, IntegerType::get(rewriter.getContext(), 64), value);
1551         break;
1552       case PrintConversion::Bitcast16:
1553         value = rewriter.create<LLVM::BitcastOp>(
1554             loc, IntegerType::get(rewriter.getContext(), 16), value);
1555         break;
1556       case PrintConversion::None:
1557         break;
1558       }
1559       emitCall(rewriter, loc, printer, value);
1560       return;
1561     }
1562 
1563     auto parent = op->getParentOfType<ModuleOp>();
1564     emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
1565     Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
1566 
1567     if (rank <= 1) {
1568       auto reducedType = vectorType.getElementType();
1569       auto llvmType = typeConverter->convertType(reducedType);
1570       int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1571       for (int64_t d = 0; d < dim; ++d) {
1572         Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1573                                      llvmType, /*rank=*/0, /*pos=*/d);
1574         emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1575                   conversion);
1576         if (d != dim - 1)
1577           emitCall(rewriter, loc, printComma);
1578       }
1579       emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
1580       return;
1581     }
1582 
1583     int64_t dim = vectorType.getDimSize(0);
1584     for (int64_t d = 0; d < dim; ++d) {
1585       auto reducedType = reducedVectorTypeFront(vectorType);
1586       auto llvmType = typeConverter->convertType(reducedType);
1587       Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1588                                    llvmType, rank, d);
1589       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1590                 conversion);
1591       if (d != dim - 1)
1592         emitCall(rewriter, loc, printComma);
1593     }
1594     emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
1595   }
1596 
1597   // Helper to emit a call.
1598   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1599                        Operation *ref, ValueRange params = ValueRange()) {
1600     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1601                                   params);
1602   }
1603 };
1604 
1605 /// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with 0-D and 1-D vector result types are lowered
/// here; higher-rank results are handled by VectorSplatNdOpLowering.
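/// A rough sketch of the 1-D case (exact assembly may differ):
/// ```
///   %v = vector.splat %s : vector<4xf32>
/// ```
/// becomes approximately:
/// ```
///   %u = llvm.mlir.undef : vector<4xf32>
///   %c = llvm.mlir.constant(0 : i32) : i32
///   %i = llvm.insertelement %s, %u[%c : i32] : vector<4xf32>
///   %v = llvm.shufflevector %i, %u [0, 0, 0, 0] : vector<4xf32>
/// ```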
1607 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1608   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1609 
1610   LogicalResult
1611   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1612                   ConversionPatternRewriter &rewriter) const override {
1613     VectorType resultType = splatOp.getType().cast<VectorType>();
1614     if (resultType.getRank() > 1)
1615       return failure();
1616 
1617     // First insert it into an undef vector so we can shuffle it.
1618     auto vectorType = typeConverter->convertType(splatOp.getType());
1619     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1620     auto zero = rewriter.create<LLVM::ConstantOp>(
1621         splatOp.getLoc(),
1622         typeConverter->convertType(rewriter.getIntegerType(32)),
1623         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1624 
1625     // For 0-d vector, we simply do `insertelement`.
1626     if (resultType.getRank() == 0) {
1627       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1628           splatOp, vectorType, undef, adaptor.getInput(), zero);
1629       return success();
1630     }
1631 
    // For a 1-D vector, we additionally do a `shufflevector`.
1633     auto v = rewriter.create<LLVM::InsertElementOp>(
1634         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1635 
1636     int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1637     SmallVector<int32_t> zeroValues(width, 0);
1638 
1639     // Shuffle the value across the desired number of elements.
1640     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1641                                                        zeroValues);
1642     return success();
1643   }
1644 };
1645 
/// The Splat operation is lowered to an insertelement + a shufflevector
/// operation. Only splats with vector result types of rank 2 or higher are
/// lowered by VectorSplatNdOpLowering; the 0-D and 1-D cases are handled by
/// VectorSplatOpLowering.
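/// A rough sketch for a vector<2x4xf32> result (exact assembly may differ): a
/// single 1-D splat of the input is built exactly as in the 1-D pattern above
/// and is then written into every array position of the converted value:
/// ```
///   %d0 = llvm.insertvalue %splat, %undef[0] : !llvm.array<2 x vector<4xf32>>
///   %d1 = llvm.insertvalue %splat, %d0[1] : !llvm.array<2 x vector<4xf32>>
/// ```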
1649 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1650   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1651 
1652   LogicalResult
1653   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1654                   ConversionPatternRewriter &rewriter) const override {
1655     VectorType resultType = splatOp.getType();
1656     if (resultType.getRank() <= 1)
1657       return failure();
1658 
1659     // First insert it into an undef vector so we can shuffle it.
1660     auto loc = splatOp.getLoc();
1661     auto vectorTypeInfo =
1662         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1663     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1664     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1665     if (!llvmNDVectorTy || !llvm1DVectorTy)
1666       return failure();
1667 
1668     // Construct returned value.
1669     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1670 
1671     // Construct a 1-D vector with the splatted value that we insert in all the
1672     // places within the returned descriptor.
1673     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1674     auto zero = rewriter.create<LLVM::ConstantOp>(
1675         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1676         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1677     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1678                                                      adaptor.getInput(), zero);
1679 
1680     // Shuffle the value across the desired number of elements.
1681     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1682     SmallVector<int32_t> zeroValues(width, 0);
1683     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1684 
    // Iterate over the linear index, convert it to an n-D coordinate, and
    // insert the splatted 1-D vector at each position.
1687     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1688       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1689     });
1690     rewriter.replaceOp(splatOp, desc);
1691     return success();
1692   }
1693 };
1694 
1695 } // namespace
1696 
1697 /// Populate the given list with patterns that convert from Vector to LLVM.
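/// A minimal usage sketch (pass boilerplate omitted; the surrounding names are
/// illustrative, not part of this file):
/// ```
///   LLVMTypeConverter converter(ctx);
///   RewritePatternSet patterns(ctx);
///   populateVectorToLLVMConversionPatterns(converter, patterns,
///                                          /*reassociateFPReductions=*/false,
///                                          /*force32BitVectorIndices=*/false);
///   LLVMConversionTarget target(*ctx);
///   if (failed(applyPartialConversion(op, target, std::move(patterns))))
///     signalPassFailure();
/// ```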
1698 void mlir::populateVectorToLLVMConversionPatterns(
1699     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1700     bool reassociateFPReductions, bool force32BitVectorIndices) {
1701   MLIRContext *ctx = converter.getDialect()->getContext();
1702   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1703   populateVectorInsertExtractStridedSliceTransforms(patterns);
1704   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1705   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1706   patterns
1707       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1708            VectorExtractElementOpConversion, VectorExtractOpConversion,
1709            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1710            VectorInsertOpConversion, VectorPrintOpConversion,
1711            VectorTypeCastOpConversion, VectorScaleOpConversion,
1712            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1713            VectorLoadStoreConversion<vector::MaskedLoadOp,
1714                                      vector::MaskedLoadOpAdaptor>,
1715            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1716            VectorLoadStoreConversion<vector::MaskedStoreOp,
1717                                      vector::MaskedStoreOpAdaptor>,
1718            VectorGatherOpConversion, VectorScatterOpConversion,
1719            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1720            VectorSplatOpLowering, VectorSplatNdOpLowering,
1721            VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1722            MaskedReductionOpConversion>(converter);
1723   // Transfer ops with rank > 1 are handled by VectorToSCF.
1724   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1725 }
1726 
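/// Populate the given list with patterns that convert the matrix-related
/// Vector ops to the corresponding LLVM matrix intrinsics.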
1727 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1728     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1729   patterns.add<VectorMatmulOpConversion>(converter);
1730   patterns.add<VectorFlatTransposeOpConversion>(converter);
1731 }
1732