xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 78c49743c7fb2353b1ca56d56f9ffda7f47e4bae)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypeInterfaces.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/Support/Casting.h"
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::vector;
35 
36 // Helper to reduce vector type by *all* but one rank at back.
37 static VectorType reducedVectorTypeBack(VectorType tp) {
38   assert((tp.getRank() > 1) && "unlowerable vector type");
39   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
40                          tp.getScalableDims().take_back());
41 }
42 
43 // Helper that picks the proper sequence for inserting.
44 static Value insertOne(ConversionPatternRewriter &rewriter,
45                        const LLVMTypeConverter &typeConverter, Location loc,
46                        Value val1, Value val2, Type llvmType, int64_t rank,
47                        int64_t pos) {
48   assert(rank > 0 && "0-D vector corner case should have been handled already");
49   if (rank == 1) {
50     auto idxType = rewriter.getIndexType();
51     auto constant = rewriter.create<LLVM::ConstantOp>(
52         loc, typeConverter.convertType(idxType),
53         rewriter.getIntegerAttr(idxType, pos));
54     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
55                                                   constant);
56   }
57   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
58 }
59 
60 // Helper that picks the proper sequence for extracting.
61 static Value extractOne(ConversionPatternRewriter &rewriter,
62                         const LLVMTypeConverter &typeConverter, Location loc,
63                         Value val, Type llvmType, int64_t rank, int64_t pos) {
64   if (rank <= 1) {
65     auto idxType = rewriter.getIndexType();
66     auto constant = rewriter.create<LLVM::ConstantOp>(
67         loc, typeConverter.convertType(idxType),
68         rewriter.getIntegerAttr(idxType, pos));
69     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
70                                                    constant);
71   }
72   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
73 }
74 
75 // Helper that returns data layout alignment of a memref.
76 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
77                                  MemRefType memrefType, unsigned &align) {
78   Type elementTy = typeConverter.convertType(memrefType.getElementType());
79   if (!elementTy)
80     return failure();
81 
82   // TODO: this should use the MLIR data layout when it becomes available and
83   // stop depending on translation.
84   llvm::LLVMContext llvmContext;
85   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
86               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
87   return success();
88 }
89 
90 // Check if the last stride is non-unit and has a valid memory space.
91 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
92                                            const LLVMTypeConverter &converter) {
93   if (!isLastMemrefDimUnitStride(memRefType))
94     return failure();
95   if (failed(converter.getMemRefAddressSpace(memRefType)))
96     return failure();
97   return success();
98 }
99 
100 // Add an index vector component to a base pointer.
101 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
102                             const LLVMTypeConverter &typeConverter,
103                             MemRefType memRefType, Value llvmMemref, Value base,
104                             Value index, uint64_t vLen) {
105   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
106          "unsupported memref type");
107   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
108   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
109   return rewriter.create<LLVM::GEPOp>(
110       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
111       base, index);
112 }
113 
114 // Casts a strided element pointer to a vector pointer.  The vector pointer
115 // will be in the same address space as the incoming memref type.
116 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
117                          Value ptr, MemRefType memRefType, Type vt,
118                          const LLVMTypeConverter &converter) {
119   if (converter.useOpaquePointers())
120     return ptr;
121 
122   unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
123   auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
124   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
125 }
126 
127 /// Convert `foldResult` into a Value. Integer attribute is converted to
128 /// an LLVM constant op.
129 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
130                             OpFoldResult foldResult) {
131   if (auto attr = foldResult.dyn_cast<Attribute>()) {
132     auto intAttr = cast<IntegerAttr>(attr);
133     return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
134   }
135 
136   return foldResult.get<Value>();
137 }
138 
139 namespace {
140 
/// Trivial Vector to LLVM conversions: vector.vscale maps one-to-one onto
/// llvm.intr.vscale, with only type conversion applied.
using VectorScaleOpConversion =
    OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
144 
145 /// Conversion pattern for a vector.bitcast.
146 class VectorBitCastOpConversion
147     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
148 public:
149   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
150 
151   LogicalResult
152   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
153                   ConversionPatternRewriter &rewriter) const override {
154     // Only 0-D and 1-D vectors can be lowered to LLVM.
155     VectorType resultTy = bitCastOp.getResultVectorType();
156     if (resultTy.getRank() > 1)
157       return failure();
158     Type newResultTy = typeConverter->convertType(resultTy);
159     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
160                                                  adaptor.getOperands()[0]);
161     return success();
162   }
163 };
164 
165 /// Conversion pattern for a vector.matrix_multiply.
166 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
167 class VectorMatmulOpConversion
168     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
169 public:
170   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
171 
172   LogicalResult
173   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
174                   ConversionPatternRewriter &rewriter) const override {
175     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
176         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
177         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
178         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
179     return success();
180   }
181 };
182 
183 /// Conversion pattern for a vector.flat_transpose.
184 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
185 class VectorFlatTransposeOpConversion
186     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
187 public:
188   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
189 
190   LogicalResult
191   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
192                   ConversionPatternRewriter &rewriter) const override {
193     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
194         transOp, typeConverter->convertType(transOp.getRes().getType()),
195         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
196     return success();
197   }
198 };
199 
200 /// Overloaded utility that replaces a vector.load, vector.store,
201 /// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
203 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
204                                  vector::LoadOpAdaptor adaptor,
205                                  VectorType vectorTy, Value ptr, unsigned align,
206                                  ConversionPatternRewriter &rewriter) {
207   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
208 }
209 
210 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
211                                  vector::MaskedLoadOpAdaptor adaptor,
212                                  VectorType vectorTy, Value ptr, unsigned align,
213                                  ConversionPatternRewriter &rewriter) {
214   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
215       loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
216 }
217 
218 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
219                                  vector::StoreOpAdaptor adaptor,
220                                  VectorType vectorTy, Value ptr, unsigned align,
221                                  ConversionPatternRewriter &rewriter) {
222   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
223                                              ptr, align);
224 }
225 
226 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
227                                  vector::MaskedStoreOpAdaptor adaptor,
228                                  VectorType vectorTy, Value ptr, unsigned align,
229                                  ConversionPatternRewriter &rewriter) {
230   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
231       storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
232 }
233 
234 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
235 /// vector.maskedstore.
236 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
237 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
238 public:
239   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
240 
241   LogicalResult
242   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
243                   typename LoadOrStoreOp::Adaptor adaptor,
244                   ConversionPatternRewriter &rewriter) const override {
245     // Only 1-D vectors can be lowered to LLVM.
246     VectorType vectorTy = loadOrStoreOp.getVectorType();
247     if (vectorTy.getRank() > 1)
248       return failure();
249 
250     auto loc = loadOrStoreOp->getLoc();
251     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
252 
253     // Resolve alignment.
254     unsigned align;
255     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
256       return failure();
257 
258     // Resolve address.
259     auto vtype = cast<VectorType>(
260         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
261     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
262                                                adaptor.getIndices(), rewriter);
263     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
264                             *this->getTypeConverter());
265 
266     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
267     return success();
268   }
269 };
270 
/// Conversion pattern for a vector.gather.
///
/// 1-D gathers lower directly to llvm.intr.masked.gather on a vector of
/// per-lane pointers. n-D gathers (whose converted type is an LLVM array of
/// 1-D vectors) are unrolled into one masked_gather per innermost slice via
/// handleMultidimensionalVectors.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Require unit innermost stride and a convertible memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Scalar start address; the per-lane index vector is applied on top of
    // this via getIndexedPtrs below.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
                                  memRefType, base, ptr, adaptor.getIndexVec(),
                                  /*vLen=*/vType.getDimSize(0));
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: build a callback that lowers a single 1-D slice; it receives
    // the slice's index vector, mask, and pass-through as `vectorOperands`
    // (in the order assembled below).
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0],
          LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
332 
333 /// Conversion pattern for a vector.scatter.
334 class VectorScatterOpConversion
335     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
336 public:
337   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
338 
339   LogicalResult
340   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
341                   ConversionPatternRewriter &rewriter) const override {
342     auto loc = scatter->getLoc();
343     MemRefType memRefType = scatter.getMemRefType();
344 
345     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
346       return failure();
347 
348     // Resolve alignment.
349     unsigned align;
350     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
351       return failure();
352 
353     // Resolve address.
354     VectorType vType = scatter.getVectorType();
355     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
356                                      adaptor.getIndices(), rewriter);
357     Value ptrs = getIndexedPtrs(
358         rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
359         ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
360 
361     // Replace with the scatter intrinsic.
362     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
363         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
364         rewriter.getI32IntegerAttr(align));
365     return success();
366   }
367 };
368 
369 /// Conversion pattern for a vector.expandload.
370 class VectorExpandLoadOpConversion
371     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
372 public:
373   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
374 
375   LogicalResult
376   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
377                   ConversionPatternRewriter &rewriter) const override {
378     auto loc = expand->getLoc();
379     MemRefType memRefType = expand.getMemRefType();
380 
381     // Resolve address.
382     auto vtype = typeConverter->convertType(expand.getVectorType());
383     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
384                                      adaptor.getIndices(), rewriter);
385 
386     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
387         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
388     return success();
389   }
390 };
391 
392 /// Conversion pattern for a vector.compressstore.
393 class VectorCompressStoreOpConversion
394     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
395 public:
396   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
397 
398   LogicalResult
399   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
400                   ConversionPatternRewriter &rewriter) const override {
401     auto loc = compress->getLoc();
402     MemRefType memRefType = compress.getMemRefType();
403 
404     // Resolve address.
405     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
406                                      adaptor.getIndices(), rewriter);
407 
408     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
409         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
410     return success();
411   }
412 };
413 
/// Reduction neutral classes for overloading. Each empty tag type selects,
/// via overload resolution, the createReductionNeutralValue overload below
/// that materializes the corresponding neutral constant.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
425 
426 /// Create the reduction neutral zero value.
427 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
428                                          ConversionPatternRewriter &rewriter,
429                                          Location loc, Type llvmType) {
430   return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
431                                            rewriter.getZeroAttr(llvmType));
432 }
433 
434 /// Create the reduction neutral integer one value.
435 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
436                                          ConversionPatternRewriter &rewriter,
437                                          Location loc, Type llvmType) {
438   return rewriter.create<LLVM::ConstantOp>(
439       loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
440 }
441 
442 /// Create the reduction neutral fp one value.
443 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
444                                          ConversionPatternRewriter &rewriter,
445                                          Location loc, Type llvmType) {
446   return rewriter.create<LLVM::ConstantOp>(
447       loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
448 }
449 
450 /// Create the reduction neutral all-ones value.
451 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
452                                          ConversionPatternRewriter &rewriter,
453                                          Location loc, Type llvmType) {
454   return rewriter.create<LLVM::ConstantOp>(
455       loc, llvmType,
456       rewriter.getIntegerAttr(
457           llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
458 }
459 
460 /// Create the reduction neutral signed int minimum value.
461 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
462                                          ConversionPatternRewriter &rewriter,
463                                          Location loc, Type llvmType) {
464   return rewriter.create<LLVM::ConstantOp>(
465       loc, llvmType,
466       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
467                                             llvmType.getIntOrFloatBitWidth())));
468 }
469 
470 /// Create the reduction neutral unsigned int minimum value.
471 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
472                                          ConversionPatternRewriter &rewriter,
473                                          Location loc, Type llvmType) {
474   return rewriter.create<LLVM::ConstantOp>(
475       loc, llvmType,
476       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
477                                             llvmType.getIntOrFloatBitWidth())));
478 }
479 
480 /// Create the reduction neutral signed int maximum value.
481 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
482                                          ConversionPatternRewriter &rewriter,
483                                          Location loc, Type llvmType) {
484   return rewriter.create<LLVM::ConstantOp>(
485       loc, llvmType,
486       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
487                                             llvmType.getIntOrFloatBitWidth())));
488 }
489 
490 /// Create the reduction neutral unsigned int maximum value.
491 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
492                                          ConversionPatternRewriter &rewriter,
493                                          Location loc, Type llvmType) {
494   return rewriter.create<LLVM::ConstantOp>(
495       loc, llvmType,
496       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
497                                             llvmType.getIntOrFloatBitWidth())));
498 }
499 
500 /// Create the reduction neutral fp minimum value.
501 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
502                                          ConversionPatternRewriter &rewriter,
503                                          Location loc, Type llvmType) {
504   auto floatType = cast<FloatType>(llvmType);
505   return rewriter.create<LLVM::ConstantOp>(
506       loc, llvmType,
507       rewriter.getFloatAttr(
508           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
509                                            /*Negative=*/false)));
510 }
511 
512 /// Create the reduction neutral fp maximum value.
513 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
514                                          ConversionPatternRewriter &rewriter,
515                                          Location loc, Type llvmType) {
516   auto floatType = cast<FloatType>(llvmType);
517   return rewriter.create<LLVM::ConstantOp>(
518       loc, llvmType,
519       rewriter.getFloatAttr(
520           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
521                                            /*Negative=*/true)));
522 }
523 
524 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
525 /// returns a new accumulator value using `ReductionNeutral`.
526 template <class ReductionNeutral>
527 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
528                                     Location loc, Type llvmType,
529                                     Value accumulator) {
530   if (accumulator)
531     return accumulator;
532 
533   return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
534                                      llvmType);
535 }
536 
537 /// Creates a constant value with the 1-D vector shape provided in `llvmType`.
538 /// This is used as effective vector length by some intrinsics supporting
539 /// dynamic vector lengths at runtime.
540 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
541                                      Location loc, Type llvmType) {
542   VectorType vType = cast<VectorType>(llvmType);
543   auto vShape = vType.getShape();
544   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
545 
546   return rewriter.create<LLVM::ConstantOp>(
547       loc, rewriter.getI32Type(),
548       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
549 }
550 
551 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
552 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
553 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
554 /// non-null.
555 template <class LLVMRedIntrinOp, class ScalarOp>
556 static Value createIntegerReductionArithmeticOpLowering(
557     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
558     Value vectorOperand, Value accumulator) {
559 
560   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
561 
562   if (accumulator)
563     result = rewriter.create<ScalarOp>(loc, accumulator, result);
564   return result;
565 }
566 
567 /// Helper method to lower a `vector.reduction` operation that performs
568 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
569 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
570 /// the accumulator value if non-null.
571 template <class LLVMRedIntrinOp>
572 static Value createIntegerReductionComparisonOpLowering(
573     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
574     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
575   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
576   if (accumulator) {
577     Value cmp =
578         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
579     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
580   }
581   return result;
582 }
583 
namespace {
/// Maps an LLVM vector reduction intrinsic to the scalar LLVM op with the
/// same combining semantics; used below to fold an accumulator into the
/// reduced scalar result.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
604 
605 template <class LLVMRedIntrinOp>
606 static Value createFPReductionComparisonOpLowering(
607     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
608     Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
609   Value result =
610       rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
611 
612   if (accumulator) {
613     result =
614         rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
615             loc, result, accumulator);
616   }
617 
618   return result;
619 }
620 
/// Mask neutral classes for overloading: empty tag types selecting the
/// getMaskNeutralValue overload for the corresponding masked reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
624 
625 /// Get the mask neutral floating point maximum value
626 static llvm::APFloat
627 getMaskNeutralValue(MaskNeutralFMaximum,
628                     const llvm::fltSemantics &floatSemantics) {
629   return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
630 }
/// Get the mask neutral floating point minimum value: the largest positive
/// finite value, which compares >= every finite lane value and therefore
/// never wins an fminimum reduction over the active lanes.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}
637 
638 /// Create the mask neutral floating point MLIR vector constant
639 template <typename MaskNeutral>
640 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
641                                     Location loc, Type llvmType,
642                                     Type vectorType) {
643   const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
644   auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
645   auto denseValue =
646       DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
647   return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
648 }
649 
650 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
651 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
652 /// `fmaximum`/`fminimum`.
653 /// More information: https://github.com/llvm/llvm-project/issues/64940
654 template <class LLVMRedIntrinOp, class MaskNeutral>
655 static Value
656 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
657                                 Location loc, Type llvmType,
658                                 Value vectorOperand, Value accumulator,
659                                 Value mask, LLVM::FastmathFlagsAttr fmf) {
660   const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
661       rewriter, loc, llvmType, vectorOperand.getType());
662   const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
663       loc, mask, vectorOperand, vectorMaskNeutral);
664   return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
665       rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
666 }
667 
668 template <class LLVMRedIntrinOp, class ReductionNeutral>
669 static Value
670 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
671                              Type llvmType, Value vectorOperand,
672                              Value accumulator, LLVM::FastmathFlagsAttr fmf) {
673   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
674                                                          llvmType, accumulator);
675   return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
676                                           /*startValue=*/accumulator,
677                                           vectorOperand, fmf);
678 }
679 
680 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic
681 /// that requires a start value. This start value format spans across fp
682 /// reductions without mask and all the masked reduction intrinsics.
683 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
684 static Value
685 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
686                                        Location loc, Type llvmType,
687                                        Value vectorOperand, Value accumulator) {
688   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
689                                                          llvmType, accumulator);
690   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
691                                             /*startValue=*/accumulator,
692                                             vectorOperand);
693 }
694 
695 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
696 static Value lowerPredicatedReductionWithStartValue(
697     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
698     Value vectorOperand, Value accumulator, Value mask) {
699   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
700                                                          llvmType, accumulator);
701   Value vectorLength =
702       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
703   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
704                                             /*startValue=*/accumulator,
705                                             vectorOperand, mask, vectorLength);
706 }
707 
708 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
709           class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
710 static Value lowerPredicatedReductionWithStartValue(
711     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
712     Value vectorOperand, Value accumulator, Value mask) {
713   if (llvmType.isIntOrIndex())
714     return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
715                                                   IntReductionNeutral>(
716         rewriter, loc, llvmType, vectorOperand, accumulator, mask);
717 
718   // FP dispatch.
719   return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
720                                                 FPReductionNeutral>(
721       rewriter, loc, llvmType, vectorOperand, accumulator, mask);
722 }
723 
724 /// Conversion pattern for all vector reductions.
725 class VectorReductionOpConversion
726     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
727 public:
728   explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
729                                        bool reassociateFPRed)
730       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
731         reassociateFPReductions(reassociateFPRed) {}
732 
733   LogicalResult
734   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
735                   ConversionPatternRewriter &rewriter) const override {
736     auto kind = reductionOp.getKind();
737     Type eltType = reductionOp.getDest().getType();
738     Type llvmType = typeConverter->convertType(eltType);
739     Value operand = adaptor.getVector();
740     Value acc = adaptor.getAcc();
741     Location loc = reductionOp.getLoc();
742 
743     if (eltType.isIntOrIndex()) {
744       // Integer reductions: add/mul/min/max/and/or/xor.
745       Value result;
746       switch (kind) {
747       case vector::CombiningKind::ADD:
748         result =
749             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
750                                                        LLVM::AddOp>(
751                 rewriter, loc, llvmType, operand, acc);
752         break;
753       case vector::CombiningKind::MUL:
754         result =
755             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
756                                                        LLVM::MulOp>(
757                 rewriter, loc, llvmType, operand, acc);
758         break;
759       case vector::CombiningKind::MINUI:
760         result = createIntegerReductionComparisonOpLowering<
761             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
762                                       LLVM::ICmpPredicate::ule);
763         break;
764       case vector::CombiningKind::MINSI:
765         result = createIntegerReductionComparisonOpLowering<
766             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
767                                       LLVM::ICmpPredicate::sle);
768         break;
769       case vector::CombiningKind::MAXUI:
770         result = createIntegerReductionComparisonOpLowering<
771             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
772                                       LLVM::ICmpPredicate::uge);
773         break;
774       case vector::CombiningKind::MAXSI:
775         result = createIntegerReductionComparisonOpLowering<
776             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
777                                       LLVM::ICmpPredicate::sge);
778         break;
779       case vector::CombiningKind::AND:
780         result =
781             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
782                                                        LLVM::AndOp>(
783                 rewriter, loc, llvmType, operand, acc);
784         break;
785       case vector::CombiningKind::OR:
786         result =
787             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
788                                                        LLVM::OrOp>(
789                 rewriter, loc, llvmType, operand, acc);
790         break;
791       case vector::CombiningKind::XOR:
792         result =
793             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
794                                                        LLVM::XOrOp>(
795                 rewriter, loc, llvmType, operand, acc);
796         break;
797       default:
798         return failure();
799       }
800       rewriter.replaceOp(reductionOp, result);
801 
802       return success();
803     }
804 
805     if (!isa<FloatType>(eltType))
806       return failure();
807 
808     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
809     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
810         reductionOp.getContext(),
811         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
812     fmf = LLVM::FastmathFlagsAttr::get(
813         reductionOp.getContext(),
814         fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
815                                                   : LLVM::FastmathFlags::none));
816 
817     // Floating-point reductions: add/mul/min/max
818     Value result;
819     if (kind == vector::CombiningKind::ADD) {
820       result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
821                                             ReductionNeutralZero>(
822           rewriter, loc, llvmType, operand, acc, fmf);
823     } else if (kind == vector::CombiningKind::MUL) {
824       result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
825                                             ReductionNeutralFPOne>(
826           rewriter, loc, llvmType, operand, acc, fmf);
827     } else if (kind == vector::CombiningKind::MINIMUMF) {
828       result =
829           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
830               rewriter, loc, llvmType, operand, acc, fmf);
831     } else if (kind == vector::CombiningKind::MAXIMUMF) {
832       result =
833           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
834               rewriter, loc, llvmType, operand, acc, fmf);
835     } else if (kind == vector::CombiningKind::MINF) {
836       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
837           rewriter, loc, llvmType, operand, acc, fmf);
838     } else if (kind == vector::CombiningKind::MAXF) {
839       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
840           rewriter, loc, llvmType, operand, acc, fmf);
841     } else
842       return failure();
843 
844     rewriter.replaceOp(reductionOp, result);
845     return success();
846   }
847 
848 private:
849   const bool reassociateFPReductions;
850 };
851 
852 /// Base class to convert a `vector.mask` operation while matching traits
853 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
854 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
855 /// method performs a second match against the maskable operation `MaskedOp`.
856 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
857 /// implemented by the concrete conversion classes. This method can match
858 /// against specific traits of the `vector.mask` and the maskable operation. It
859 /// must replace the `vector.mask` operation.
860 template <class MaskedOp>
861 class VectorMaskOpConversionBase
862     : public ConvertOpToLLVMPattern<vector::MaskOp> {
863 public:
864   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
865 
866   LogicalResult
867   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
868                   ConversionPatternRewriter &rewriter) const final {
869     // Match against the maskable operation kind.
870     auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
871     if (!maskedOp)
872       return failure();
873     return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
874   }
875 
876 protected:
877   virtual LogicalResult
878   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
879                             vector::MaskableOpInterface maskableOp,
880                             ConversionPatternRewriter &rewriter) const = 0;
881 };
882 
/// Conversion pattern for masked `vector.reduction` ops. Most combining kinds
/// map onto the corresponding `llvm.intr.vp.reduce.*` intrinsic; maximumf and
/// minimumf have no VP intrinsic and are instead emulated by selecting a
/// mask-neutral value for inactive lanes and running the regular reduction
/// intrinsic (see `lowerMaskedReductionWithRegular`).
class MaskedReductionOpConversion
    : public VectorMaskOpConversionBase<vector::ReductionOp> {

public:
  using VectorMaskOpConversionBase<
      vector::ReductionOp>::VectorMaskOpConversionBase;

  LogicalResult matchAndRewriteMaskableOp(
      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
      ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
    auto kind = reductionOp.getKind();
    Type eltType = reductionOp.getDest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    // Operands are taken from the reduction op itself, not from an adaptor:
    // the reduction is nested inside the matched `vector.mask` op.
    Value operand = reductionOp.getVector();
    Value acc = reductionOp.getAcc();
    Location loc = reductionOp.getLoc();

    // Fast-math flags are only consumed by the maximumf/minimumf emulation
    // below; the VP reduction intrinsics are created without them.
    arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        reductionOp.getContext(),
        convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));

    // Dispatch on the combining kind. The four-parameter helper overload
    // picks the integer or FP intrinsic based on `llvmType`.
    Value result;
    switch (kind) {
    case vector::CombiningKind::ADD:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
                                maskOp.getMask());
      break;
    case vector::CombiningKind::MUL:
      result = lowerPredicatedReductionWithStartValue<
          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
                                 maskOp.getMask());
      break;
    case vector::CombiningKind::MINUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
                                                      ReductionNeutralUIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
                                                      ReductionNeutralSIntMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXUI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
                                                      ReductionNeutralUIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXSI:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
                                                      ReductionNeutralSIntMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::AND:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
                                                      ReductionNeutralAllOnes>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::OR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::XOR:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
                                                      ReductionNeutralZero>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MINF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
                                                      ReductionNeutralFPMax>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case vector::CombiningKind::MAXF:
      result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
                                                      ReductionNeutralFPMin>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
      break;
    case CombiningKind::MAXIMUMF:
      // No VP intrinsic: emulate by masking with the neutral value.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
                                               MaskNeutralFMaximum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    case CombiningKind::MINIMUMF:
      // No VP intrinsic: emulate by masking with the neutral value.
      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
                                               MaskNeutralFMinimum>(
          rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
      break;
    }

    // Replace `vector.mask` operation altogether.
    rewriter.replaceOp(maskOp, result);
    return success();
  }
};
982 
983 class VectorShuffleOpConversion
984     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
985 public:
986   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
987 
988   LogicalResult
989   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
990                   ConversionPatternRewriter &rewriter) const override {
991     auto loc = shuffleOp->getLoc();
992     auto v1Type = shuffleOp.getV1VectorType();
993     auto v2Type = shuffleOp.getV2VectorType();
994     auto vectorType = shuffleOp.getResultVectorType();
995     Type llvmType = typeConverter->convertType(vectorType);
996     auto maskArrayAttr = shuffleOp.getMask();
997 
998     // Bail if result type cannot be lowered.
999     if (!llvmType)
1000       return failure();
1001 
1002     // Get rank and dimension sizes.
1003     int64_t rank = vectorType.getRank();
1004 #ifndef NDEBUG
1005     bool wellFormed0DCase =
1006         v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1007     bool wellFormedNDCase =
1008         v1Type.getRank() == rank && v2Type.getRank() == rank;
1009     assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1010 #endif
1011 
1012     // For rank 0 and 1, where both operands have *exactly* the same vector
1013     // type, there is direct shuffle support in LLVM. Use it!
1014     if (rank <= 1 && v1Type == v2Type) {
1015       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1016           loc, adaptor.getV1(), adaptor.getV2(),
1017           LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1018       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1019       return success();
1020     }
1021 
1022     // For all other cases, insert the individual values individually.
1023     int64_t v1Dim = v1Type.getDimSize(0);
1024     Type eltType;
1025     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1026       eltType = arrayType.getElementType();
1027     else
1028       eltType = cast<VectorType>(llvmType).getElementType();
1029     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1030     int64_t insPos = 0;
1031     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1032       int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1033       Value value = adaptor.getV1();
1034       if (extPos >= v1Dim) {
1035         extPos -= v1Dim;
1036         value = adaptor.getV2();
1037       }
1038       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1039                                  eltType, rank, extPos);
1040       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041                          llvmType, rank, insPos++);
1042     }
1043     rewriter.replaceOp(shuffleOp, insert);
1044     return success();
1045   }
1046 };
1047 
1048 class VectorExtractElementOpConversion
1049     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1050 public:
1051   using ConvertOpToLLVMPattern<
1052       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1053 
1054   LogicalResult
1055   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1056                   ConversionPatternRewriter &rewriter) const override {
1057     auto vectorType = extractEltOp.getSourceVectorType();
1058     auto llvmType = typeConverter->convertType(vectorType.getElementType());
1059 
1060     // Bail if result type cannot be lowered.
1061     if (!llvmType)
1062       return failure();
1063 
1064     if (vectorType.getRank() == 0) {
1065       Location loc = extractEltOp.getLoc();
1066       auto idxType = rewriter.getIndexType();
1067       auto zero = rewriter.create<LLVM::ConstantOp>(
1068           loc, typeConverter->convertType(idxType),
1069           rewriter.getIntegerAttr(idxType, 0));
1070       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1071           extractEltOp, llvmType, adaptor.getVector(), zero);
1072       return success();
1073     }
1074 
1075     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1076         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1077     return success();
1078   }
1079 };
1080 
/// Conversion pattern for `vector.extract`. Depending on the position rank
/// and the result type, this lowers to a single `llvm.extractvalue`, a single
/// `llvm.extractelement`, or an extractvalue/extractelement pair.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Collect the mixed (static/dynamic) position as OpFoldResults.
    SmallVector<OpFoldResult> positionVec;
    for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      if (pos.is<Value>())
        // Make sure we use the value that has been already converted to LLVM.
        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
      else
        positionVec.push_back(pos);
    }

    // Extract entire vector. Should be handled by folder, but just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(extractOp, adaptor.getVector());
      return success();
    }

    // One-shot extraction of vector from array (only requires extractvalue).
    // Dynamic positions are rejected here: llvm.extractvalue only takes
    // constant indices.
    if (isa<VectorType>(resultType)) {
      if (extractOp.hasDynamicPosition())
        return failure();

      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, adaptor.getVector(), getAsIntegers(position));
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getVector();
    if (position.size() > 1) {
      if (extractOp.hasDynamicPosition())
        return failure();

      // Peel off all but the last index with one extractvalue.
      SmallVector<int64_t> nMinusOnePosition =
          getAsIntegers(position.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
                                                        nMinusOnePosition);
    }

    // The trailing index may be dynamic; llvm.extractelement accepts a
    // runtime position.
    Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
    // Remaining extraction of element from 1-D LLVM vector.
    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
                                                        lastPosition);
    return success();
  }
};
1142 
1143 /// Conversion pattern that turns a vector.fma on a 1-D vector
1144 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1145 /// This does not match vectors of n >= 2 rank.
1146 ///
1147 /// Example:
1148 /// ```
1149 ///  vector.fma %a, %a, %a : vector<8xf32>
1150 /// ```
1151 /// is converted to:
1152 /// ```
1153 ///  llvm.intr.fmuladd %va, %va, %va:
1154 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1155 ///    -> !llvm."<8 x f32>">
1156 /// ```
1157 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1158 public:
1159   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1160 
1161   LogicalResult
1162   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1163                   ConversionPatternRewriter &rewriter) const override {
1164     VectorType vType = fmaOp.getVectorType();
1165     if (vType.getRank() > 1)
1166       return failure();
1167 
1168     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1169         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1170     return success();
1171   }
1172 };
1173 
1174 class VectorInsertElementOpConversion
1175     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1176 public:
1177   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1178 
1179   LogicalResult
1180   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1181                   ConversionPatternRewriter &rewriter) const override {
1182     auto vectorType = insertEltOp.getDestVectorType();
1183     auto llvmType = typeConverter->convertType(vectorType);
1184 
1185     // Bail if result type cannot be lowered.
1186     if (!llvmType)
1187       return failure();
1188 
1189     if (vectorType.getRank() == 0) {
1190       Location loc = insertEltOp.getLoc();
1191       auto idxType = rewriter.getIndexType();
1192       auto zero = rewriter.create<LLVM::ConstantOp>(
1193           loc, typeConverter->convertType(idxType),
1194           rewriter.getIntegerAttr(idxType, 0));
1195       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1196           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1197       return success();
1198     }
1199 
1200     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1201         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1202         adaptor.getPosition());
1203     return success();
1204   }
1205 };
1206 
/// Conversion pattern for `vector.insert`. Depending on the position rank and
/// the source type, this lowers to a single `llvm.insertvalue`, a single
/// `llvm.insertelement`, or an extractvalue/insertelement/insertvalue
/// sequence.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // Collect the mixed (static/dynamic) position as OpFoldResults.
    SmallVector<OpFoldResult> positionVec;
    for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
      if (pos.is<Value>())
        // Make sure we use the value that has been already converted to LLVM.
        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
      else
        positionVec.push_back(pos);
    }

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    ArrayRef<OpFoldResult> position(positionVec);
    if (position.empty()) {
      rewriter.replaceOp(insertOp, adaptor.getSource());
      return success();
    }

    // One-shot insertion of a vector into an array (only requires insertvalue).
    // Dynamic positions are rejected here: llvm.insertvalue only takes
    // constant indices.
    if (isa<VectorType>(sourceType)) {
      if (insertOp.hasDynamicPosition())
        return failure();

      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    Value extracted = adaptor.getDest();
    auto oneDVectorType = destVectorType;
    if (position.size() > 1) {
      if (insertOp.hasDynamicPosition())
        return failure();

      oneDVectorType = reducedVectorTypeBack(destVectorType);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, extracted, getAsIntegers(position.drop_back()));
    }

    // Insertion of an element into a 1-D LLVM vector. The trailing index may
    // be dynamic: llvm.insertelement accepts a runtime position.
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));

    // Potential insertion of resulting 1-D vector into array.
    if (position.size() > 1) {
      // NOTE(review): this dynamic-position check appears unreachable — the
      // identical check above already returned failure() for this case.
      if (insertOp.hasDynamicPosition())
        return failure();

      inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, adaptor.getDest(), inserted,
          getAsIntegers(position.drop_back()));
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
1282 
1283 /// Lower vector.scalable.insert ops to LLVM vector.insert
1284 struct VectorScalableInsertOpLowering
1285     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1286   using ConvertOpToLLVMPattern<
1287       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1288 
1289   LogicalResult
1290   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1291                   ConversionPatternRewriter &rewriter) const override {
1292     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1293         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
1294     return success();
1295   }
1296 };
1297 
1298 /// Lower vector.scalable.extract ops to LLVM vector.extract
1299 struct VectorScalableExtractOpLowering
1300     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1301   using ConvertOpToLLVMPattern<
1302       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1303 
1304   LogicalResult
1305   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1306                   ConversionPatternRewriter &rewriter) const override {
1307     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1308         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1309         adaptor.getSource(), adaptor.getPos());
1310     return success();
1311   }
1312 };
1313 
1314 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1315 ///
1316 /// Example:
1317 /// ```
1318 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
1319 /// ```
1320 /// is rewritten into:
1321 /// ```
1322 ///  %r = splat %f0: vector<2x4xf32>
1323 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
1324 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1325 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1326 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1327 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1328 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1329 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1330 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1331 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1332 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1333 ///  // %r3 holds the final value.
1334 /// ```
1335 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1336 public:
1337   using OpRewritePattern<FMAOp>::OpRewritePattern;
1338 
1339   void initialize() {
1340     // This pattern recursively unpacks one dimension at a time. The recursion
1341     // bounded as the rank is strictly decreasing.
1342     setHasBoundedRewriteRecursion();
1343   }
1344 
1345   LogicalResult matchAndRewrite(FMAOp op,
1346                                 PatternRewriter &rewriter) const override {
1347     auto vType = op.getVectorType();
1348     if (vType.getRank() < 2)
1349       return failure();
1350 
1351     auto loc = op.getLoc();
1352     auto elemType = vType.getElementType();
1353     Value zero = rewriter.create<arith::ConstantOp>(
1354         loc, elemType, rewriter.getZeroAttr(elemType));
1355     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1356     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1357       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1358       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1359       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1360       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1361       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1362     }
1363     rewriter.replaceOp(op, desc);
1364     return success();
1365   }
1366 };
1367 
1368 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1369 /// static layout.
1370 static std::optional<SmallVector<int64_t, 4>>
1371 computeContiguousStrides(MemRefType memRefType) {
1372   int64_t offset;
1373   SmallVector<int64_t, 4> strides;
1374   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1375     return std::nullopt;
1376   if (!strides.empty() && strides.back() != 1)
1377     return std::nullopt;
1378   // If no layout or identity layout, this is contiguous by definition.
1379   if (memRefType.getLayout().isIdentity())
1380     return strides;
1381 
1382   // Otherwise, we must determine contiguity form shapes. This can only ever
1383   // work in static cases because MemRefType is underspecified to represent
1384   // contiguous dynamic shapes in other ways than with just empty/identity
1385   // layout.
1386   auto sizes = memRefType.getShape();
1387   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1388     if (ShapedType::isDynamic(sizes[index + 1]) ||
1389         ShapedType::isDynamic(strides[index]) ||
1390         ShapedType::isDynamic(strides[index + 1]))
1391       return std::nullopt;
1392     if (strides[index] != strides[index + 1] * sizes[index + 1])
1393       return std::nullopt;
1394   }
1395   return strides;
1396 }
1397 
/// Lowers `vector.type_cast` by rebuilding the LLVM memref descriptor for the
/// target type. The underlying buffer is reused unchanged — only the
/// descriptor fields (pointers, offset, sizes, strides) are recomputed — so
/// only statically-shaped, contiguous memrefs are supported.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        cast<MemRefType>(castOp.getOperand().getType());
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The converted source operand must already be an LLVM struct-typed
    // memref descriptor.
    auto llvmSourceDescriptorTy =
        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);

    // The target memref type must likewise convert to a descriptor struct.
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    // With typed pointers the source pointer must be bitcast to the target
    // element pointer type; with opaque pointers no cast is needed.
    if (!getTypeConverter()->useOpaquePointers())
      allocated =
          rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);

    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    if (!getTypeConverter()->useOpaquePointers())
      ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);

    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref. Sizes come from the static
    // target shape; strides come from the contiguity computation above.
    for (const auto &indexedSize :
         llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1479 
1480 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1481 /// Non-scalable versions of this operation are handled in Vector Transforms.
1482 class VectorCreateMaskOpRewritePattern
1483     : public OpRewritePattern<vector::CreateMaskOp> {
1484 public:
1485   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1486                                             bool enableIndexOpt)
1487       : OpRewritePattern<vector::CreateMaskOp>(context),
1488         force32BitVectorIndices(enableIndexOpt) {}
1489 
1490   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1491                                 PatternRewriter &rewriter) const override {
1492     auto dstType = op.getType();
1493     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1494       return failure();
1495     IntegerType idxType =
1496         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1497     auto loc = op->getLoc();
1498     Value indices = rewriter.create<LLVM::StepVectorOp>(
1499         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1500                                  /*isScalable=*/true));
1501     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1502                                                  op.getOperand(0));
1503     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1504     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1505                                                 indices, bounds);
1506     rewriter.replaceOp(op, comp);
1507     return success();
1508   }
1509 
1510 private:
1511   const bool force32BitVectorIndices;
1512 };
1513 
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Lowering implementation that relies on a small runtime support library,
  // which only needs to provide a few printing methods (single value for all
  // data types, opening/closing bracket, comma, newline). The lowering splits
  // the vector into elementary printing operations. The advantage of this
  // approach is that the library can remain unaware of all low-level
  // implementation details of vectors while still supporting output of any
  // shaped and dimensioned vector.
  //
  // Note: This lowering only handles scalars, n-D vectors are broken into
  // printing scalars in loops in VectorToSCF.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The runtime print functions are declared (or looked up) at module scope.
    auto parent = printOp->getParentOfType<ModuleOp>();
    if (!parent)
      return failure();

    auto loc = printOp->getLoc();

    // Print the (optional) scalar source value first.
    if (auto value = adaptor.getSource()) {
      Type printType = printOp.getPrintType();
      if (isa<VectorType>(printType)) {
        // Vectors should be broken into elementary print ops in VectorToSCF.
        return failure();
      }
      if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
        return failure();
    }

    // Then emit a call for the requested punctuation, if any.
    auto punct = printOp.getPunctuation();
    if (punct != PrintPunctuation::NoPunctuation) {
      emitCall(rewriter, printOp->getLoc(), [&] {
        switch (punct) {
        case PrintPunctuation::Close:
          return LLVM::lookupOrCreatePrintCloseFn(parent);
        case PrintPunctuation::Open:
          return LLVM::lookupOrCreatePrintOpenFn(parent);
        case PrintPunctuation::Comma:
          return LLVM::lookupOrCreatePrintCommaFn(parent);
        case PrintPunctuation::NewLine:
          return LLVM::lookupOrCreatePrintNewlineFn(parent);
        default:
          llvm_unreachable("unexpected punctuation");
        }
      }());
    }

    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // How the operand must be adapted before being passed to the runtime print
  // function: no change, zero/sign extension to 64 bits, or a bitcast to the
  // raw 16-bit representation (for f16/bf16).
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64,
    Bitcast16
    // clang-format on
  };

  // Emits a call to the runtime function that prints a single scalar of type
  // `printType`, extending or bitcasting `value` as needed. Fails when the
  // type cannot be converted or has no runtime printing support.
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    Operation *printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
        } else {
          // No runtime support for integers wider than 64 bits.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
        } else {
          // No runtime support for integers wider than 64 bits.
          return failure();
        }
      }
    } else {
      // Unsupported scalar type.
      return failure();
    }

    // Adapt the operand to the argument type of the selected print function.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return success();
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
                                  params);
  }
};
1661 
1662 /// The Splat operation is lowered to an insertelement + a shufflevector
1663 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1664 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1665   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1666 
1667   LogicalResult
1668   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1669                   ConversionPatternRewriter &rewriter) const override {
1670     VectorType resultType = cast<VectorType>(splatOp.getType());
1671     if (resultType.getRank() > 1)
1672       return failure();
1673 
1674     // First insert it into an undef vector so we can shuffle it.
1675     auto vectorType = typeConverter->convertType(splatOp.getType());
1676     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1677     auto zero = rewriter.create<LLVM::ConstantOp>(
1678         splatOp.getLoc(),
1679         typeConverter->convertType(rewriter.getIntegerType(32)),
1680         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1681 
1682     // For 0-d vector, we simply do `insertelement`.
1683     if (resultType.getRank() == 0) {
1684       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1685           splatOp, vectorType, undef, adaptor.getInput(), zero);
1686       return success();
1687     }
1688 
1689     // For 1-d vector, we additionally do a `vectorshuffle`.
1690     auto v = rewriter.create<LLVM::InsertElementOp>(
1691         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1692 
1693     int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1694     SmallVector<int32_t> zeroValues(width, 0);
1695 
1696     // Shuffle the value across the desired number of elements.
1697     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1698                                                        zeroValues);
1699     return success();
1700   }
1701 };
1702 
1703 /// The Splat operation is lowered to an insertelement + a shufflevector
1704 /// operation. Splat to only 2+-d vector result types are lowered by the
1705 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1706 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1707   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1708 
1709   LogicalResult
1710   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1711                   ConversionPatternRewriter &rewriter) const override {
1712     VectorType resultType = splatOp.getType();
1713     if (resultType.getRank() <= 1)
1714       return failure();
1715 
1716     // First insert it into an undef vector so we can shuffle it.
1717     auto loc = splatOp.getLoc();
1718     auto vectorTypeInfo =
1719         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1720     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1721     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1722     if (!llvmNDVectorTy || !llvm1DVectorTy)
1723       return failure();
1724 
1725     // Construct returned value.
1726     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1727 
1728     // Construct a 1-D vector with the splatted value that we insert in all the
1729     // places within the returned descriptor.
1730     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1731     auto zero = rewriter.create<LLVM::ConstantOp>(
1732         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1733         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1734     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1735                                                      adaptor.getInput(), zero);
1736 
1737     // Shuffle the value across the desired number of elements.
1738     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1739     SmallVector<int32_t> zeroValues(width, 0);
1740     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1741 
1742     // Iterate of linear index, convert to coords space and insert splatted 1-D
1743     // vector in each position.
1744     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1745       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1746     });
1747     rewriter.replaceOp(splatOp, desc);
1748     return success();
1749   }
1750 };
1751 
1752 } // namespace
1753 
1754 /// Populate the given list with patterns that convert from Vector to LLVM.
1755 void mlir::populateVectorToLLVMConversionPatterns(
1756     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1757     bool reassociateFPReductions, bool force32BitVectorIndices) {
1758   MLIRContext *ctx = converter.getDialect()->getContext();
1759   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1760   populateVectorInsertExtractStridedSliceTransforms(patterns);
1761   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1762   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1763   patterns
1764       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1765            VectorExtractElementOpConversion, VectorExtractOpConversion,
1766            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1767            VectorInsertOpConversion, VectorPrintOpConversion,
1768            VectorTypeCastOpConversion, VectorScaleOpConversion,
1769            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1770            VectorLoadStoreConversion<vector::MaskedLoadOp,
1771                                      vector::MaskedLoadOpAdaptor>,
1772            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1773            VectorLoadStoreConversion<vector::MaskedStoreOp,
1774                                      vector::MaskedStoreOpAdaptor>,
1775            VectorGatherOpConversion, VectorScatterOpConversion,
1776            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1777            VectorSplatOpLowering, VectorSplatNdOpLowering,
1778            VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1779            MaskedReductionOpConversion>(converter);
1780   // Transfer ops with rank > 1 are handled by VectorToSCF.
1781   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1782 }
1783 
1784 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1785     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1786   patterns.add<VectorMatmulOpConversion>(converter);
1787   patterns.add<VectorFlatTransposeOpConversion>(converter);
1788 }
1789