1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/BuiltinTypeInterfaces.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper that reduces a vector type to its innermost (trailing) dimension.
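// For example (illustrative), vector<2x3x8xf32> reduces to vector<8xf32>.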
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41                          tp.getScalableDims().take_back());
42 }
43 
44 // Helper that picks the proper sequence for inserting.
45 static Value insertOne(ConversionPatternRewriter &rewriter,
46                        const LLVMTypeConverter &typeConverter, Location loc,
47                        Value val1, Value val2, Type llvmType, int64_t rank,
48                        int64_t pos) {
49   assert(rank > 0 && "0-D vector corner case should have been handled already");
50   if (rank == 1) {
51     auto idxType = rewriter.getIndexType();
52     auto constant = rewriter.create<LLVM::ConstantOp>(
53         loc, typeConverter.convertType(idxType),
54         rewriter.getIntegerAttr(idxType, pos));
55     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56                                                   constant);
57   }
58   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59 }
60 
61 // Helper that picks the proper sequence for extracting.
62 static Value extractOne(ConversionPatternRewriter &rewriter,
63                         const LLVMTypeConverter &typeConverter, Location loc,
64                         Value val, Type llvmType, int64_t rank, int64_t pos) {
65   if (rank <= 1) {
66     auto idxType = rewriter.getIndexType();
67     auto constant = rewriter.create<LLVM::ConstantOp>(
68         loc, typeConverter.convertType(idxType),
69         rewriter.getIntegerAttr(idxType, pos));
70     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71                                                    constant);
72   }
73   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74 }
75 
76 // Helper that returns the data layout alignment of a memref's element type.
77 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
78                                  MemRefType memrefType, unsigned &align) {
79   Type elementTy = typeConverter.convertType(memrefType.getElementType());
80   if (!elementTy)
81     return failure();
82 
83   // TODO: this should use the MLIR data layout when it becomes available and
84   // stop depending on translation.
85   llvm::LLVMContext llvmContext;
86   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88   return success();
89 }
90 
91 // Check that the last memref dimension has unit stride and a supported memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93                                            const LLVMTypeConverter &converter) {
94   if (!isLastMemrefDimUnitStride(memRefType))
95     return failure();
96   if (failed(converter.getMemRefAddressSpace(memRefType)))
97     return failure();
98   return success();
99 }
100 
101 // Add an index vector component to a base pointer.
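// For illustration (types shown approximately): given a base element pointer
// and a vector<4xindex> of offsets, this emits a single llvm.getelementptr
// that yields a vector of four element pointers, one per lane.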
102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
103                             const LLVMTypeConverter &typeConverter,
104                             MemRefType memRefType, Value llvmMemref, Value base,
105                             Value index, uint64_t vLen) {
106   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107          "unsupported memref type");
108   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
109   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
110   return rewriter.create<LLVM::GEPOp>(
111       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
112       base, index);
113 }
114 
115 // Casts a strided element pointer to a vector pointer.  The vector pointer
116 // will be in the same address space as the incoming memref type.
117 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
118                          Value ptr, MemRefType memRefType, Type vt,
119                          const LLVMTypeConverter &converter) {
120   if (converter.useOpaquePointers())
121     return ptr;
122 
123   unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
124   auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
125   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
126 }
127 
128 /// Converts `foldResult` into a Value. An integer attribute is materialized
129 /// as an LLVM constant op.
130 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
131                             OpFoldResult foldResult) {
132   if (auto attr = foldResult.dyn_cast<Attribute>()) {
133     auto intAttr = cast<IntegerAttr>(attr);
134     return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
135   }
136 
137   return foldResult.get<Value>();
138 }
139 
140 namespace {
141 
142 /// Trivial Vector to LLVM conversions
143 using VectorScaleOpConversion =
144     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
145 
146 /// Conversion pattern for a vector.bitcast.
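/// For illustration (approximate IR): a width-preserving cast such as
///   %0 = vector.bitcast %src : vector<2xf32> to vector<4xf16>
/// is rewritten to a single llvm.bitcast with the converted result type.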
147 class VectorBitCastOpConversion
148     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
149 public:
150   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
151 
152   LogicalResult
153   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
154                   ConversionPatternRewriter &rewriter) const override {
155     // Only 0-D and 1-D vectors can be lowered to LLVM.
156     VectorType resultTy = bitCastOp.getResultVectorType();
157     if (resultTy.getRank() > 1)
158       return failure();
159     Type newResultTy = typeConverter->convertType(resultTy);
160     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
161                                                  adaptor.getOperands()[0]);
162     return success();
163   }
164 };
165 
166 /// Conversion pattern for a vector.matrix_multiply.
167 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
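/// For illustration (approximate IR; shapes chosen arbitrarily):
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///          : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// becomes an llvm.intr.matrix.multiply with the same shape attributes.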
168 class VectorMatmulOpConversion
169     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
170 public:
171   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
172 
173   LogicalResult
174   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override {
176     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
177         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
178         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
179         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
180     return success();
181   }
182 };
183 
184 /// Conversion pattern for a vector.flat_transpose.
185 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
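/// For illustration (approximate IR):
///   %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///          : vector<16xf32> -> vector<16xf32>
/// becomes an llvm.intr.matrix.transpose with the same row/column attributes.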
186 class VectorFlatTransposeOpConversion
187     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
188 public:
189   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
190 
191   LogicalResult
192   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
193                   ConversionPatternRewriter &rewriter) const override {
194     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
195         transOp, typeConverter->convertType(transOp.getRes().getType()),
196         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
197     return success();
198   }
199 };
200 
201 /// Overloaded utility that replaces a vector.load, vector.store,
202 /// vector.maskedload, or vector.maskedstore with its respective LLVM
203 /// counterpart.
204 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
205                                  vector::LoadOpAdaptor adaptor,
206                                  VectorType vectorTy, Value ptr, unsigned align,
207                                  ConversionPatternRewriter &rewriter) {
208   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
209 }
210 
211 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
212                                  vector::MaskedLoadOpAdaptor adaptor,
213                                  VectorType vectorTy, Value ptr, unsigned align,
214                                  ConversionPatternRewriter &rewriter) {
215   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
216       loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
217 }
218 
219 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
220                                  vector::StoreOpAdaptor adaptor,
221                                  VectorType vectorTy, Value ptr, unsigned align,
222                                  ConversionPatternRewriter &rewriter) {
223   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
224                                              ptr, align);
225 }
226 
227 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
228                                  vector::MaskedStoreOpAdaptor adaptor,
229                                  VectorType vectorTy, Value ptr, unsigned align,
230                                  ConversionPatternRewriter &rewriter) {
231   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
232       storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
233 }
234 
235 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
236 /// vector.maskedstore.
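/// For illustration (approximate IR): a 1-D load such as
///   %v = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
/// becomes an llvm.load of vector<8xf32> from the strided element pointer
/// computed for %base[%i], using the data layout alignment of the element type.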
237 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
238 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
239 public:
240   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
241 
242   LogicalResult
243   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
244                   typename LoadOrStoreOp::Adaptor adaptor,
245                   ConversionPatternRewriter &rewriter) const override {
246     // Only 1-D vectors can be lowered to LLVM.
247     VectorType vectorTy = loadOrStoreOp.getVectorType();
248     if (vectorTy.getRank() > 1)
249       return failure();
250 
251     auto loc = loadOrStoreOp->getLoc();
252     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
253 
254     // Resolve alignment.
255     unsigned align;
256     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
257       return failure();
258 
259     // Resolve address.
260     auto vtype = cast<VectorType>(
261         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
262     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
263                                                adaptor.getIndices(), rewriter);
264     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
265                             *this->getTypeConverter());
266 
267     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
268     return success();
269   }
270 };
271 
272 /// Conversion pattern for a vector.gather.
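/// For illustration (approximate IR), the 1-D case
///   %g = vector.gather %base[%c0][%idxs], %mask, %pass
///          : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
///            into vector<4xf32>
/// is lowered to an llvm.intr.masked.gather over a vector of per-lane element
/// pointers obtained with llvm.getelementptr.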
273 class VectorGatherOpConversion
274     : public ConvertOpToLLVMPattern<vector::GatherOp> {
275 public:
276   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
277 
278   LogicalResult
279   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
280                   ConversionPatternRewriter &rewriter) const override {
281     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
282     assert(memRefType && "The base should be bufferized");
283 
284     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
285       return failure();
286 
287     auto loc = gather->getLoc();
288 
289     // Resolve alignment.
290     unsigned align;
291     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
292       return failure();
293 
294     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
295                                      adaptor.getIndices(), rewriter);
296     Value base = adaptor.getBase();
297 
298     auto llvmNDVectorTy = adaptor.getIndexVec().getType();
299     // Handle the simple case of 1-D vector.
300     if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
301       auto vType = gather.getVectorType();
302       // Resolve address.
303       Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
304                                   memRefType, base, ptr, adaptor.getIndexVec(),
305                                   /*vLen=*/vType.getDimSize(0));
306       // Replace with the gather intrinsic.
307       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
308           gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
309           adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
310       return success();
311     }
312 
313     const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
314     auto callback = [align, memRefType, base, ptr, loc, &rewriter,
315                      &typeConverter](Type llvm1DVectorTy,
316                                      ValueRange vectorOperands) {
317       // Resolve address.
318       Value ptrs = getIndexedPtrs(
319           rewriter, loc, typeConverter, memRefType, base, ptr,
320           /*index=*/vectorOperands[0],
321           LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
322       // Create the gather intrinsic.
323       return rewriter.create<LLVM::masked_gather>(
324           loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
325           /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
326     };
327     SmallVector<Value> vectorOperands = {
328         adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
329     return LLVM::detail::handleMultidimensionalVectors(
330         gather, vectorOperands, *getTypeConverter(), callback, rewriter);
331   }
332 };
333 
334 /// Conversion pattern for a vector.scatter.
335 class VectorScatterOpConversion
336     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
337 public:
338   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
339 
340   LogicalResult
341   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
342                   ConversionPatternRewriter &rewriter) const override {
343     auto loc = scatter->getLoc();
344     MemRefType memRefType = scatter.getMemRefType();
345 
346     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
347       return failure();
348 
349     // Resolve alignment.
350     unsigned align;
351     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
352       return failure();
353 
354     // Resolve address.
355     VectorType vType = scatter.getVectorType();
356     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
357                                      adaptor.getIndices(), rewriter);
358     Value ptrs = getIndexedPtrs(
359         rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
360         ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
361 
362     // Replace with the scatter intrinsic.
363     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
364         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
365         rewriter.getI32IntegerAttr(align));
366     return success();
367   }
368 };
369 
370 /// Conversion pattern for a vector.expandload.
371 class VectorExpandLoadOpConversion
372     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
373 public:
374   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
375 
376   LogicalResult
377   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
378                   ConversionPatternRewriter &rewriter) const override {
379     auto loc = expand->getLoc();
380     MemRefType memRefType = expand.getMemRefType();
381 
382     // Resolve address.
383     auto vtype = typeConverter->convertType(expand.getVectorType());
384     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
385                                      adaptor.getIndices(), rewriter);
386 
387     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
388         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
389     return success();
390   }
391 };
392 
393 /// Conversion pattern for a vector.compressstore.
394 class VectorCompressStoreOpConversion
395     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
396 public:
397   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
398 
399   LogicalResult
400   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
401                   ConversionPatternRewriter &rewriter) const override {
402     auto loc = compress->getLoc();
403     MemRefType memRefType = compress.getMemRefType();
404 
405     // Resolve address.
406     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
407                                      adaptor.getIndices(), rewriter);
408 
409     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
410         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
411     return success();
412   }
413 };
414 
415 /// Reduction neutral classes for overloading.
416 class ReductionNeutralZero {};
417 class ReductionNeutralIntOne {};
418 class ReductionNeutralFPOne {};
419 class ReductionNeutralAllOnes {};
420 class ReductionNeutralSIntMin {};
421 class ReductionNeutralUIntMin {};
422 class ReductionNeutralSIntMax {};
423 class ReductionNeutralUIntMax {};
424 class ReductionNeutralFPMin {};
425 class ReductionNeutralFPMax {};
426 
427 /// Create the reduction neutral zero value.
428 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
429                                          ConversionPatternRewriter &rewriter,
430                                          Location loc, Type llvmType) {
431   return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
432                                            rewriter.getZeroAttr(llvmType));
433 }
434 
435 /// Create the reduction neutral integer one value.
436 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
437                                          ConversionPatternRewriter &rewriter,
438                                          Location loc, Type llvmType) {
439   return rewriter.create<LLVM::ConstantOp>(
440       loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
441 }
442 
443 /// Create the reduction neutral fp one value.
444 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
445                                          ConversionPatternRewriter &rewriter,
446                                          Location loc, Type llvmType) {
447   return rewriter.create<LLVM::ConstantOp>(
448       loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
449 }
450 
451 /// Create the reduction neutral all-ones value.
452 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
453                                          ConversionPatternRewriter &rewriter,
454                                          Location loc, Type llvmType) {
455   return rewriter.create<LLVM::ConstantOp>(
456       loc, llvmType,
457       rewriter.getIntegerAttr(
458           llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
459 }
460 
461 /// Create the reduction neutral signed int minimum value.
462 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
463                                          ConversionPatternRewriter &rewriter,
464                                          Location loc, Type llvmType) {
465   return rewriter.create<LLVM::ConstantOp>(
466       loc, llvmType,
467       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
468                                             llvmType.getIntOrFloatBitWidth())));
469 }
470 
471 /// Create the reduction neutral unsigned int minimum value.
472 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
473                                          ConversionPatternRewriter &rewriter,
474                                          Location loc, Type llvmType) {
475   return rewriter.create<LLVM::ConstantOp>(
476       loc, llvmType,
477       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
478                                             llvmType.getIntOrFloatBitWidth())));
479 }
480 
481 /// Create the reduction neutral signed int maximum value.
482 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
483                                          ConversionPatternRewriter &rewriter,
484                                          Location loc, Type llvmType) {
485   return rewriter.create<LLVM::ConstantOp>(
486       loc, llvmType,
487       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
488                                             llvmType.getIntOrFloatBitWidth())));
489 }
490 
491 /// Create the reduction neutral unsigned int maximum value.
492 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
493                                          ConversionPatternRewriter &rewriter,
494                                          Location loc, Type llvmType) {
495   return rewriter.create<LLVM::ConstantOp>(
496       loc, llvmType,
497       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
498                                             llvmType.getIntOrFloatBitWidth())));
499 }
500 
501 /// Create the reduction neutral fp minimum value.
502 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
503                                          ConversionPatternRewriter &rewriter,
504                                          Location loc, Type llvmType) {
505   auto floatType = cast<FloatType>(llvmType);
506   return rewriter.create<LLVM::ConstantOp>(
507       loc, llvmType,
508       rewriter.getFloatAttr(
509           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
510                                            /*Negative=*/false)));
511 }
512 
513 /// Create the reduction neutral fp maximum value.
514 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
515                                          ConversionPatternRewriter &rewriter,
516                                          Location loc, Type llvmType) {
517   auto floatType = cast<FloatType>(llvmType);
518   return rewriter.create<LLVM::ConstantOp>(
519       loc, llvmType,
520       rewriter.getFloatAttr(
521           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
522                                            /*Negative=*/true)));
523 }
524 
525 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
526 /// returns a new accumulator value using `ReductionNeutral`.
527 template <class ReductionNeutral>
528 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
529                                     Location loc, Type llvmType,
530                                     Value accumulator) {
531   if (accumulator)
532     return accumulator;
533 
534   return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
535                                      llvmType);
536 }
537 
538 /// Creates an i32 constant holding the number of elements of the 1-D vector
539 /// type provided in `llvmType`. This is used as the effective vector length
540 /// by intrinsics that support dynamic vector lengths at runtime.
541 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
542                                      Location loc, Type llvmType) {
543   VectorType vType = cast<VectorType>(llvmType);
544   auto vShape = vType.getShape();
545   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
546 
547   return rewriter.create<LLVM::ConstantOp>(
548       loc, rewriter.getI32Type(),
549       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
550 }
551 
552 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
553 /// operation like add or mul. `LLVMRedIntrinOp` is the LLVM vector reduction
554 /// intrinsic to use and `ScalarOp` is the scalar operation used to combine the
555 /// result with the accumulator value, if non-null.
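/// For example (illustrative IR), `vector.reduction <add>, %v, %acc` lowers
/// roughly to:
///   %r = llvm.intr.vector.reduce.add(%v) : (vector<8xi32>) -> i32
///   %res = llvm.add %acc, %r : i32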
556 template <class LLVMRedIntrinOp, class ScalarOp>
557 static Value createIntegerReductionArithmeticOpLowering(
558     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
559     Value vectorOperand, Value accumulator) {
560 
561   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
562 
563   if (accumulator)
564     result = rewriter.create<ScalarOp>(loc, accumulator, result);
565   return result;
566 }
567 
568 /// Helper method to lower a `vector.reduction` operation that performs
569 /// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
570 /// vector reduction intrinsic to use and `predicate` is used to compare and
571 /// combine the result with the accumulator value, if non-null.
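/// For example (illustrative IR), `vector.reduction <minui>, %v, %acc` lowers
/// roughly to:
///   %r = llvm.intr.vector.reduce.umin(%v) : (vector<8xi32>) -> i32
///   %c = llvm.icmp "ule" %acc, %r : i32
///   %res = llvm.select %c, %acc, %r : i1, i32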
572 template <class LLVMRedIntrinOp>
573 static Value createIntegerReductionComparisonOpLowering(
574     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
575     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
576   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
577   if (accumulator) {
578     Value cmp =
579         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
580     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
581   }
582   return result;
583 }
584 
585 namespace {
586 template <typename Source>
587 struct VectorToScalarMapper;
588 template <>
589 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
590   using Type = LLVM::MaximumOp;
591 };
592 template <>
593 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
594   using Type = LLVM::MinimumOp;
595 };
596 template <>
597 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
598   using Type = LLVM::MaxNumOp;
599 };
600 template <>
601 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
602   using Type = LLVM::MinNumOp;
603 };
604 } // namespace
605 
606 template <class LLVMRedIntrinOp>
607 static Value createFPReductionComparisonOpLowering(
608     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
609     Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
610   Value result =
611       rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
612 
613   if (accumulator) {
614     result =
615         rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
616             loc, result, accumulator);
617   }
618 
619   return result;
620 }
621 
622 /// Mask neutral classes for overloading.
623 class MaskNeutralFMaximum {};
624 class MaskNeutralFMinimum {};
625 
626 /// Get the mask neutral floating point maximum value
627 static llvm::APFloat
628 getMaskNeutralValue(MaskNeutralFMaximum,
629                     const llvm::fltSemantics &floatSemantics) {
630   return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
631 }
632 /// Get the mask neutral floating point minimum value
633 static llvm::APFloat
634 getMaskNeutralValue(MaskNeutralFMinimum,
635                     const llvm::fltSemantics &floatSemantics) {
636   return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
637 }
638 
639 /// Create the mask neutral floating point MLIR vector constant
640 template <typename MaskNeutral>
641 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
642                                     Location loc, Type llvmType,
643                                     Type vectorType) {
644   const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
645   auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
646   auto denseValue =
647       DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
648   return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
649 }
650 
651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
653 /// `fmaximum`/`fminimum`.
654 /// More information: https://github.com/llvm/llvm-project/issues/64940
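/// Sketch of the emitted IR (illustrative, for a masked `fmaximum` reduction):
///   %neutral = llvm.mlir.constant(dense<...>) : vector<4xf32>  // mask-neutral splat
///   %sel     = llvm.select %mask, %v, %neutral : vector<4xi1>, vector<4xf32>
///   %r       = llvm.intr.vector.reduce.fmaximum(%sel) : (vector<4xf32>) -> f32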
655 template <class LLVMRedIntrinOp, class MaskNeutral>
656 static Value
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
658                                 Location loc, Type llvmType,
659                                 Value vectorOperand, Value accumulator,
660                                 Value mask, LLVM::FastmathFlagsAttr fmf) {
661   const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662       rewriter, loc, llvmType, vectorOperand.getType());
663   const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
664       loc, mask, vectorOperand, vectorMaskNeutral);
665   return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666       rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
667 }
668 
669 template <class LLVMRedIntrinOp, class ReductionNeutral>
670 static Value
671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
672                              Type llvmType, Value vectorOperand,
673                              Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675                                                          llvmType, accumulator);
676   return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
677                                           /*startValue=*/accumulator,
678                                           vectorOperand, fmf);
679 }
680 
681 /// Overloaded methods to lower a *predicated* reduction to an LLVM intrinsic
682 /// that requires a start value. This start-value form is shared by the
683 /// unmasked fp reduction intrinsics and by all masked reduction intrinsics.
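/// For example (illustrative, generic form), a masked integer add reduction
/// maps to a VP intrinsic taking the start value, the vector, the mask, and
/// an explicit vector length:
///   %r = "llvm.intr.vp.reduce.add"(%acc, %v, %mask, %evl)
///        : (i32, vector<8xi32>, vector<8xi1>, i32) -> i32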
684 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
685 static Value
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
687                                        Location loc, Type llvmType,
688                                        Value vectorOperand, Value accumulator) {
689   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690                                                          llvmType, accumulator);
691   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
692                                             /*startValue=*/accumulator,
693                                             vectorOperand);
694 }
695 
696 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
698     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
699     Value vectorOperand, Value accumulator, Value mask) {
700   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701                                                          llvmType, accumulator);
702   Value vectorLength =
703       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
704   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
705                                             /*startValue=*/accumulator,
706                                             vectorOperand, mask, vectorLength);
707 }
708 
709 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
710           class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
712     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
713     Value vectorOperand, Value accumulator, Value mask) {
714   if (llvmType.isIntOrIndex())
715     return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716                                                   IntReductionNeutral>(
717         rewriter, loc, llvmType, vectorOperand, accumulator, mask);
718 
719   // FP dispatch.
720   return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
721                                                 FPReductionNeutral>(
722       rewriter, loc, llvmType, vectorOperand, accumulator, mask);
723 }
724 
725 /// Conversion pattern for all vector reductions.
726 class VectorReductionOpConversion
727     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
728 public:
729   explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
730                                        bool reassociateFPRed)
731       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
732         reassociateFPReductions(reassociateFPRed) {}
733 
734   LogicalResult
735   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
736                   ConversionPatternRewriter &rewriter) const override {
737     auto kind = reductionOp.getKind();
738     Type eltType = reductionOp.getDest().getType();
739     Type llvmType = typeConverter->convertType(eltType);
740     Value operand = adaptor.getVector();
741     Value acc = adaptor.getAcc();
742     Location loc = reductionOp.getLoc();
743 
744     if (eltType.isIntOrIndex()) {
745       // Integer reductions: add/mul/min/max/and/or/xor.
746       Value result;
747       switch (kind) {
748       case vector::CombiningKind::ADD:
749         result =
750             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
751                                                        LLVM::AddOp>(
752                 rewriter, loc, llvmType, operand, acc);
753         break;
754       case vector::CombiningKind::MUL:
755         result =
756             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
757                                                        LLVM::MulOp>(
758                 rewriter, loc, llvmType, operand, acc);
759         break;
760       case vector::CombiningKind::MINUI:
761         result = createIntegerReductionComparisonOpLowering<
762             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763                                       LLVM::ICmpPredicate::ule);
764         break;
765       case vector::CombiningKind::MINSI:
766         result = createIntegerReductionComparisonOpLowering<
767             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768                                       LLVM::ICmpPredicate::sle);
769         break;
770       case vector::CombiningKind::MAXUI:
771         result = createIntegerReductionComparisonOpLowering<
772             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773                                       LLVM::ICmpPredicate::uge);
774         break;
775       case vector::CombiningKind::MAXSI:
776         result = createIntegerReductionComparisonOpLowering<
777             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778                                       LLVM::ICmpPredicate::sge);
779         break;
780       case vector::CombiningKind::AND:
781         result =
782             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
783                                                        LLVM::AndOp>(
784                 rewriter, loc, llvmType, operand, acc);
785         break;
786       case vector::CombiningKind::OR:
787         result =
788             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
789                                                        LLVM::OrOp>(
790                 rewriter, loc, llvmType, operand, acc);
791         break;
792       case vector::CombiningKind::XOR:
793         result =
794             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
795                                                        LLVM::XOrOp>(
796                 rewriter, loc, llvmType, operand, acc);
797         break;
798       default:
799         return failure();
800       }
801       rewriter.replaceOp(reductionOp, result);
802 
803       return success();
804     }
805 
806     if (!isa<FloatType>(eltType))
807       return failure();
808 
809     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
810     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
811         reductionOp.getContext(),
812         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
813     fmf = LLVM::FastmathFlagsAttr::get(
814         reductionOp.getContext(),
815         fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816                                                   : LLVM::FastmathFlags::none));
817 
818     // Floating-point reductions: add/mul/min/max
819     Value result;
820     if (kind == vector::CombiningKind::ADD) {
821       result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822                                             ReductionNeutralZero>(
823           rewriter, loc, llvmType, operand, acc, fmf);
824     } else if (kind == vector::CombiningKind::MUL) {
825       result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826                                             ReductionNeutralFPOne>(
827           rewriter, loc, llvmType, operand, acc, fmf);
828     } else if (kind == vector::CombiningKind::MINIMUMF) {
829       result =
830           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831               rewriter, loc, llvmType, operand, acc, fmf);
832     } else if (kind == vector::CombiningKind::MAXIMUMF) {
833       result =
834           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835               rewriter, loc, llvmType, operand, acc, fmf);
836     } else if (kind == vector::CombiningKind::MINF) {
837       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838           rewriter, loc, llvmType, operand, acc, fmf);
839     } else if (kind == vector::CombiningKind::MAXF) {
840       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841           rewriter, loc, llvmType, operand, acc, fmf);
842     } else
843       return failure();
844 
845     rewriter.replaceOp(reductionOp, result);
846     return success();
847   }
848 
849 private:
850   const bool reassociateFPReductions;
851 };
852 
853 /// Base class to convert a `vector.mask` operation while matching traits
854 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
855 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
856 /// method performs a second match against the maskable operation `MaskedOp`.
857 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
858 /// implemented by the concrete conversion classes. This method can match
859 /// against specific traits of the `vector.mask` and the maskable operation. It
860 /// must replace the `vector.mask` operation.
861 template <class MaskedOp>
862 class VectorMaskOpConversionBase
863     : public ConvertOpToLLVMPattern<vector::MaskOp> {
864 public:
865   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
866 
867   LogicalResult
868   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
869                   ConversionPatternRewriter &rewriter) const final {
870     // Match against the maskable operation kind.
871     auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
872     if (!maskedOp)
873       return failure();
874     return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
875   }
876 
877 protected:
878   virtual LogicalResult
879   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880                             vector::MaskableOpInterface maskableOp,
881                             ConversionPatternRewriter &rewriter) const = 0;
882 };
883 
884 class MaskedReductionOpConversion
885     : public VectorMaskOpConversionBase<vector::ReductionOp> {
886 
887 public:
888   using VectorMaskOpConversionBase<
889       vector::ReductionOp>::VectorMaskOpConversionBase;
890 
891   LogicalResult matchAndRewriteMaskableOp(
892       vector::MaskOp maskOp, MaskableOpInterface maskableOp,
893       ConversionPatternRewriter &rewriter) const override {
894     auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895     auto kind = reductionOp.getKind();
896     Type eltType = reductionOp.getDest().getType();
897     Type llvmType = typeConverter->convertType(eltType);
898     Value operand = reductionOp.getVector();
899     Value acc = reductionOp.getAcc();
900     Location loc = reductionOp.getLoc();
901 
902     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
903     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
904         reductionOp.getContext(),
905         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
906 
907     Value result;
908     switch (kind) {
909     case vector::CombiningKind::ADD:
910       result = lowerPredicatedReductionWithStartValue<
911           LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912           ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
913                                 maskOp.getMask());
914       break;
915     case vector::CombiningKind::MUL:
916       result = lowerPredicatedReductionWithStartValue<
917           LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918           ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
919                                  maskOp.getMask());
920       break;
921     case vector::CombiningKind::MINUI:
922       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923                                                       ReductionNeutralUIntMax>(
924           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
925       break;
926     case vector::CombiningKind::MINSI:
927       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928                                                       ReductionNeutralSIntMax>(
929           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
930       break;
931     case vector::CombiningKind::MAXUI:
932       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933                                                       ReductionNeutralUIntMin>(
934           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
935       break;
936     case vector::CombiningKind::MAXSI:
937       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938                                                       ReductionNeutralSIntMin>(
939           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
940       break;
941     case vector::CombiningKind::AND:
942       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943                                                       ReductionNeutralAllOnes>(
944           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
945       break;
946     case vector::CombiningKind::OR:
947       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948                                                       ReductionNeutralZero>(
949           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
950       break;
951     case vector::CombiningKind::XOR:
952       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953                                                       ReductionNeutralZero>(
954           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
955       break;
956     case vector::CombiningKind::MINF:
957       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958                                                       ReductionNeutralFPMax>(
959           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
960       break;
961     case vector::CombiningKind::MAXF:
962       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963                                                       ReductionNeutralFPMin>(
964           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
965       break;
966     case CombiningKind::MAXIMUMF:
967       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968                                                MaskNeutralFMaximum>(
969           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
970       break;
971     case CombiningKind::MINIMUMF:
972       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973                                                MaskNeutralFMinimum>(
974           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
975       break;
976     }
977 
978     // Replace `vector.mask` operation altogether.
979     rewriter.replaceOp(maskOp, result);
980     return success();
981   }
982 };
983 
984 class VectorShuffleOpConversion
985     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
986 public:
987   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
988 
989   LogicalResult
990   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
991                   ConversionPatternRewriter &rewriter) const override {
992     auto loc = shuffleOp->getLoc();
993     auto v1Type = shuffleOp.getV1VectorType();
994     auto v2Type = shuffleOp.getV2VectorType();
995     auto vectorType = shuffleOp.getResultVectorType();
996     Type llvmType = typeConverter->convertType(vectorType);
997     auto maskArrayAttr = shuffleOp.getMask();
998 
999     // Bail if result type cannot be lowered.
1000     if (!llvmType)
1001       return failure();
1002 
1003     // Get rank and dimension sizes.
1004     int64_t rank = vectorType.getRank();
1005 #ifndef NDEBUG
1006     bool wellFormed0DCase =
1007         v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008     bool wellFormedNDCase =
1009         v1Type.getRank() == rank && v2Type.getRank() == rank;
1010     assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1011 #endif
1012 
1013     // For rank 0 and 1, where both operands have *exactly* the same vector
1014     // type, there is direct shuffle support in LLVM. Use it!
1015     if (rank <= 1 && v1Type == v2Type) {
1016       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1017           loc, adaptor.getV1(), adaptor.getV2(),
1018           LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1019       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1020       return success();
1021     }
1022 
1023 // For all other cases, extract and insert the values one by one.
1024     int64_t v1Dim = v1Type.getDimSize(0);
1025     Type eltType;
1026     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027       eltType = arrayType.getElementType();
1028     else
1029       eltType = cast<VectorType>(llvmType).getElementType();
1030     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1031     int64_t insPos = 0;
1032     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1033       int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1034       Value value = adaptor.getV1();
1035       if (extPos >= v1Dim) {
1036         extPos -= v1Dim;
1037         value = adaptor.getV2();
1038       }
1039       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1040                                  eltType, rank, extPos);
1041       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1042                          llvmType, rank, insPos++);
1043     }
1044     rewriter.replaceOp(shuffleOp, insert);
1045     return success();
1046   }
1047 };
1048 
1049 class VectorExtractElementOpConversion
1050     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1051 public:
1052   using ConvertOpToLLVMPattern<
1053       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1054 
1055   LogicalResult
1056   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1057                   ConversionPatternRewriter &rewriter) const override {
1058     auto vectorType = extractEltOp.getSourceVectorType();
1059     auto llvmType = typeConverter->convertType(vectorType.getElementType());
1060 
1061     // Bail if result type cannot be lowered.
1062     if (!llvmType)
1063       return failure();
1064 
1065     if (vectorType.getRank() == 0) {
1066       Location loc = extractEltOp.getLoc();
1067       auto idxType = rewriter.getIndexType();
1068       auto zero = rewriter.create<LLVM::ConstantOp>(
1069           loc, typeConverter->convertType(idxType),
1070           rewriter.getIntegerAttr(idxType, 0));
1071       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1072           extractEltOp, llvmType, adaptor.getVector(), zero);
1073       return success();
1074     }
1075 
1076     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1077         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1078     return success();
1079   }
1080 };
1081 
1082 class VectorExtractOpConversion
1083     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1084 public:
1085   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1086 
1087   LogicalResult
1088   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1089                   ConversionPatternRewriter &rewriter) const override {
1090     auto loc = extractOp->getLoc();
1091     auto resultType = extractOp.getResult().getType();
1092     auto llvmResultType = typeConverter->convertType(resultType);
1093     // Bail if result type cannot be lowered.
1094     if (!llvmResultType)
1095       return failure();
1096 
1097     SmallVector<OpFoldResult> positionVec;
1098     for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1099       if (pos.is<Value>())
1100         // Make sure we use the value that has been already converted to LLVM.
1101         positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1102       else
1103         positionVec.push_back(pos);
1104     }
1105 
1106     // Extract entire vector. Should be handled by folder, but just to be safe.
1107     ArrayRef<OpFoldResult> position(positionVec);
1108     if (position.empty()) {
1109       rewriter.replaceOp(extractOp, adaptor.getVector());
1110       return success();
1111     }
1112 
1113     // One-shot extraction of vector from array (only requires extractvalue).
1114     if (isa<VectorType>(resultType)) {
1115       if (extractOp.hasDynamicPosition())
1116         return failure();
1117 
1118       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1119           loc, adaptor.getVector(), getAsIntegers(position));
1120       rewriter.replaceOp(extractOp, extracted);
1121       return success();
1122     }
1123 
1124     // Potential extraction of 1-D vector from array.
1125     Value extracted = adaptor.getVector();
1126     if (position.size() > 1) {
1127       if (extractOp.hasDynamicPosition())
1128         return failure();
1129 
1130       SmallVector<int64_t> nMinusOnePosition =
1131           getAsIntegers(position.drop_back());
1132       extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1133                                                         nMinusOnePosition);
1134     }
1135 
1136     Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1137     // Remaining extraction of element from 1-D LLVM vector.
1138     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1139                                                         lastPosition);
1140     return success();
1141   }
1142 };
1143 
1144 /// Conversion pattern that turns a vector.fma on a 1-D vector
1145 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1146 /// It does not match vectors of rank >= 2.
1147 ///
1148 /// Example:
1149 /// ```
1150 ///  vector.fma %a, %a, %a : vector<8xf32>
1151 /// ```
1152 /// is converted to:
1153 /// ```
1154 ///  llvm.intr.fmuladd %va, %va, %va:
1155 ///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1156 ///    -> !llvm<"<8 x f32>">
1157 /// ```
1158 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1159 public:
1160   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1161 
1162   LogicalResult
1163   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1164                   ConversionPatternRewriter &rewriter) const override {
1165     VectorType vType = fmaOp.getVectorType();
1166     if (vType.getRank() > 1)
1167       return failure();
1168 
1169     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1170         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1171     return success();
1172   }
1173 };
1174 
1175 class VectorInsertElementOpConversion
1176     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1177 public:
1178   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1179 
1180   LogicalResult
1181   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1182                   ConversionPatternRewriter &rewriter) const override {
1183     auto vectorType = insertEltOp.getDestVectorType();
1184     auto llvmType = typeConverter->convertType(vectorType);
1185 
1186     // Bail if result type cannot be lowered.
1187     if (!llvmType)
1188       return failure();
1189 
1190     if (vectorType.getRank() == 0) {
1191       Location loc = insertEltOp.getLoc();
1192       auto idxType = rewriter.getIndexType();
1193       auto zero = rewriter.create<LLVM::ConstantOp>(
1194           loc, typeConverter->convertType(idxType),
1195           rewriter.getIntegerAttr(idxType, 0));
1196       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1197           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1198       return success();
1199     }
1200 
1201     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1202         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1203         adaptor.getPosition());
1204     return success();
1205   }
1206 };
1207 
1208 class VectorInsertOpConversion
1209     : public ConvertOpToLLVMPattern<vector::InsertOp> {
1210 public:
1211   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1212 
1213   LogicalResult
1214   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1215                   ConversionPatternRewriter &rewriter) const override {
1216     auto loc = insertOp->getLoc();
1217     auto sourceType = insertOp.getSourceType();
1218     auto destVectorType = insertOp.getDestVectorType();
1219     auto llvmResultType = typeConverter->convertType(destVectorType);
1220     // Bail if result type cannot be lowered.
1221     if (!llvmResultType)
1222       return failure();
1223 
1224     SmallVector<OpFoldResult> positionVec;
1225     for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
1226       if (pos.is<Value>())
1227         // Make sure we use the value that has been already converted to LLVM.
1228         positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1229       else
1230         positionVec.push_back(pos);
1231     }
1232 
1233     // Overwrite the entire vector with the value. This should be handled by the
1234     // folder, but is kept here just to be safe.
1235     ArrayRef<OpFoldResult> position(positionVec);
1236     if (position.empty()) {
1237       rewriter.replaceOp(insertOp, adaptor.getSource());
1238       return success();
1239     }
1240 
1241     // One-shot insertion of a vector into an array (only requires insertvalue).
1242     if (isa<VectorType>(sourceType)) {
1243       if (insertOp.hasDynamicPosition())
1244         return failure();
1245 
1246       Value inserted = rewriter.create<LLVM::InsertValueOp>(
1247           loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1248       rewriter.replaceOp(insertOp, inserted);
1249       return success();
1250     }
1251 
1252     // Potential extraction of 1-D vector from array.
1253     Value extracted = adaptor.getDest();
1254     auto oneDVectorType = destVectorType;
1255     if (position.size() > 1) {
1256       if (insertOp.hasDynamicPosition())
1257         return failure();
1258 
1259       oneDVectorType = reducedVectorTypeBack(destVectorType);
1260       extracted = rewriter.create<LLVM::ExtractValueOp>(
1261           loc, extracted, getAsIntegers(position.drop_back()));
1262     }
1263 
1264     // Insertion of an element into a 1-D LLVM vector.
1265     Value inserted = rewriter.create<LLVM::InsertElementOp>(
1266         loc, typeConverter->convertType(oneDVectorType), extracted,
1267         adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1268 
1269     // Potential insertion of resulting 1-D vector into array.
1270     if (position.size() > 1) {
1271       if (insertOp.hasDynamicPosition())
1272         return failure();
1273 
1274       inserted = rewriter.create<LLVM::InsertValueOp>(
1275           loc, adaptor.getDest(), inserted,
1276           getAsIntegers(position.drop_back()));
1277     }
1278 
1279     rewriter.replaceOp(insertOp, inserted);
1280     return success();
1281   }
1282 };
1283 
1284 /// Lower vector.scalable.insert ops to LLVM vector.insert
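/// As a rough sketch (illustrative only):
/// ```
///   %1 = vector.scalable.insert %v, %d[0] : vector<4xf32> into vector<[8]xf32>
/// ```
/// is converted to something like:
/// ```
///   %1 = llvm.intr.vector.insert %v, %d[0] : vector<4xf32> into vector<[8]xf32>
/// ```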
1285 struct VectorScalableInsertOpLowering
1286     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1287   using ConvertOpToLLVMPattern<
1288       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1289 
1290   LogicalResult
1291   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1292                   ConversionPatternRewriter &rewriter) const override {
1293     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1294         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
1295     return success();
1296   }
1297 };
1298 
1299 /// Lower vector.scalable.extract ops to LLVM vector.extract
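/// As a rough sketch (illustrative only):
/// ```
///   %1 = vector.scalable.extract %d[0] : vector<4xf32> from vector<[8]xf32>
/// ```
/// is converted to something like:
/// ```
///   %1 = llvm.intr.vector.extract %d[0] : vector<4xf32> from vector<[8]xf32>
/// ```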
1300 struct VectorScalableExtractOpLowering
1301     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1302   using ConvertOpToLLVMPattern<
1303       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1304 
1305   LogicalResult
1306   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1307                   ConversionPatternRewriter &rewriter) const override {
1308     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1309         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1310         adaptor.getSource(), adaptor.getPos());
1311     return success();
1312   }
1313 };
1314 
1315 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1316 ///
1317 /// Example:
1318 /// ```
1319 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
1320 /// ```
1321 /// is rewritten into:
1322 /// ```
1323 ///  %r = vector.splat %f0 : vector<2x4xf32>
1324 ///  %va = vector.extract %a[0] : vector<2x4xf32>
1325 ///  %vb = vector.extract %b[0] : vector<2x4xf32>
1326 ///  %vc = vector.extract %c[0] : vector<2x4xf32>
1327 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1328 ///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1329 ///  %va2 = vector.extract %a[1] : vector<2x4xf32>
1330 ///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
1331 ///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
1332 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1333 ///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1334 ///  // %r3 holds the final value.
1335 /// ```
1336 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1337 public:
1338   using OpRewritePattern<FMAOp>::OpRewritePattern;
1339 
1340   void initialize() {
1341     // This pattern recursively unpacks one dimension at a time. The recursion
1342     // is bounded because the rank is strictly decreasing.
1343     setHasBoundedRewriteRecursion();
1344   }
1345 
1346   LogicalResult matchAndRewrite(FMAOp op,
1347                                 PatternRewriter &rewriter) const override {
1348     auto vType = op.getVectorType();
1349     if (vType.getRank() < 2)
1350       return failure();
1351 
1352     auto loc = op.getLoc();
1353     auto elemType = vType.getElementType();
1354     Value zero = rewriter.create<arith::ConstantOp>(
1355         loc, elemType, rewriter.getZeroAttr(elemType));
1356     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1357     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1358       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1359       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1360       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1361       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1362       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1363     }
1364     rewriter.replaceOp(op, desc);
1365     return success();
1366   }
1367 };
1368 
1369 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1370 /// static layout.
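/// For example (worked by hand), `memref<2x3x4xf32>` with the identity layout
/// has strides [12, 4, 1]; each stride equals the next stride times the next
/// size (12 == 4 * 3, 4 == 1 * 4), so the buffer is contiguous.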
1371 static std::optional<SmallVector<int64_t, 4>>
1372 computeContiguousStrides(MemRefType memRefType) {
1373   int64_t offset;
1374   SmallVector<int64_t, 4> strides;
1375   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1376     return std::nullopt;
1377   if (!strides.empty() && strides.back() != 1)
1378     return std::nullopt;
1379   // If the layout is empty or the identity, this is contiguous by definition.
1380   if (memRefType.getLayout().isIdentity())
1381     return strides;
1382 
1383   // Otherwise, we must determine contiguity from the shape. This can only
1384   // ever work in the static case, because MemRefType cannot represent a
1385   // contiguous dynamic shape in any way other than with an empty/identity
1386   // layout.
1387   auto sizes = memRefType.getShape();
1388   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1389     if (ShapedType::isDynamic(sizes[index + 1]) ||
1390         ShapedType::isDynamic(strides[index]) ||
1391         ShapedType::isDynamic(strides[index + 1]))
1392       return std::nullopt;
1393     if (strides[index] != strides[index + 1] * sizes[index + 1])
1394       return std::nullopt;
1395   }
1396   return strides;
1397 }
1398 
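/// Conversion pattern for vector.type_cast on statically shaped, contiguous
/// memrefs, e.g. (sketch):
/// ```
///   %vp = vector.type_cast %p : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// The lowering builds a new memref descriptor that reuses the source's
/// allocated and aligned pointers and fills in constant offset, sizes and
/// strides for the target type.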
1399 class VectorTypeCastOpConversion
1400     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1401 public:
1402   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1403 
1404   LogicalResult
1405   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1406                   ConversionPatternRewriter &rewriter) const override {
1407     auto loc = castOp->getLoc();
1408     MemRefType sourceMemRefType =
1409         cast<MemRefType>(castOp.getOperand().getType());
1410     MemRefType targetMemRefType = castOp.getType();
1411 
1412     // Only static shape casts supported atm.
1413     if (!sourceMemRefType.hasStaticShape() ||
1414         !targetMemRefType.hasStaticShape())
1415       return failure();
1416 
1417     auto llvmSourceDescriptorTy =
1418         dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1419     if (!llvmSourceDescriptorTy)
1420       return failure();
1421     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1422 
1423     auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1424         typeConverter->convertType(targetMemRefType));
1425     if (!llvmTargetDescriptorTy)
1426       return failure();
1427 
1428     // Only contiguous source buffers supported atm.
1429     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1430     if (!sourceStrides)
1431       return failure();
1432     auto targetStrides = computeContiguousStrides(targetMemRefType);
1433     if (!targetStrides)
1434       return failure();
1435     // Only support static strides for now, regardless of contiguity.
1436     if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1437       return failure();
1438 
1439     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1440 
1441     // Create descriptor.
1442     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1443     Type llvmTargetElementTy = desc.getElementPtrType();
1444     // Set allocated ptr.
1445     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1446     if (!getTypeConverter()->useOpaquePointers())
1447       allocated =
1448           rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
1449     desc.setAllocatedPtr(rewriter, loc, allocated);
1450 
1451     // Set aligned ptr.
1452     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1453     if (!getTypeConverter()->useOpaquePointers())
1454       ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
1455 
1456     desc.setAlignedPtr(rewriter, loc, ptr);
1457     // Fill offset 0.
1458     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1459     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1460     desc.setOffset(rewriter, loc, zero);
1461 
1462     // Fill size and stride descriptors in memref.
1463     for (const auto &indexedSize :
1464          llvm::enumerate(targetMemRefType.getShape())) {
1465       int64_t index = indexedSize.index();
1466       auto sizeAttr =
1467           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1468       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1469       desc.setSize(rewriter, loc, index, size);
1470       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1471                                                 (*targetStrides)[index]);
1472       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1473       desc.setStride(rewriter, loc, index, stride);
1474     }
1475 
1476     rewriter.replaceOp(castOp, {desc});
1477     return success();
1478   }
1479 };
1480 
1481 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1482 /// Non-scalable versions of this operation are handled in Vector Transforms.
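/// As a rough sketch (illustrative only; the exact step-vector intrinsic
/// spelling may differ, and i32 assumes 32-bit vector indices):
/// ```
///   %m = vector.create_mask %b : vector<[4]xi1>
/// ```
/// is rewritten to something like:
/// ```
///   %step   = llvm.intr.experimental.stepvector : vector<[4]xi32>
///   %bound  = ... %b cast to i32 ...
///   %bounds = vector.splat %bound : vector<[4]xi32>
///   %m      = arith.cmpi slt, %step, %bounds : vector<[4]xi32>
/// ```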
1483 class VectorCreateMaskOpRewritePattern
1484     : public OpRewritePattern<vector::CreateMaskOp> {
1485 public:
1486   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1487                                             bool enableIndexOpt)
1488       : OpRewritePattern<vector::CreateMaskOp>(context),
1489         force32BitVectorIndices(enableIndexOpt) {}
1490 
1491   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1492                                 PatternRewriter &rewriter) const override {
1493     auto dstType = op.getType();
1494     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1495       return failure();
1496     IntegerType idxType =
1497         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1498     auto loc = op->getLoc();
1499     Value indices = rewriter.create<LLVM::StepVectorOp>(
1500         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1501                                  /*isScalable=*/true));
1502     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1503                                                  op.getOperand(0));
1504     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1505     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1506                                                 indices, bounds);
1507     rewriter.replaceOp(op, comp);
1508     return success();
1509   }
1510 
1511 private:
1512   const bool force32BitVectorIndices;
1513 };
1514 
1515 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1516 public:
1517   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1518 
1519   // Lowering implementation that relies on a small runtime support library,
1520   // which only needs to provide a few printing methods (single value for all
1521   // data types, opening/closing bracket, comma, newline). The lowering splits
1522   // the vector into elementary printing operations. The advantage of this
1523   // approach is that the library can remain unaware of all low-level
1524   // implementation details of vectors while still supporting output of any
1525   // shaped and dimensioned vector.
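  //
  // For example (sketch; the runtime symbol names shown are illustrative):
  //
  //   vector.print %f : f32
  //
  // is lowered to calls along the lines of:
  //
  //   llvm.call @printF32(%f) : (f32) -> ()
  //   llvm.call @printNewline() : () -> ()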
1526   //
1527   // Note: This lowering only handles scalars; n-D vectors are broken up into
1528   // loops of scalar prints by VectorToSCF.
1529   //
1530   // TODO: rely solely on libc in future? something else?
1531   //
1532   LogicalResult
1533   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1534                   ConversionPatternRewriter &rewriter) const override {
1535     auto parent = printOp->getParentOfType<ModuleOp>();
1536     if (!parent)
1537       return failure();
1538 
1539     auto loc = printOp->getLoc();
1540 
1541     if (auto value = adaptor.getSource()) {
1542       Type printType = printOp.getPrintType();
1543       if (isa<VectorType>(printType)) {
1544         // Vectors should be broken into elementary print ops in VectorToSCF.
1545         return failure();
1546       }
1547       if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1548         return failure();
1549     }
1550 
1551     auto punct = printOp.getPunctuation();
1552     if (auto stringLiteral = printOp.getStringLiteral()) {
1553       LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1554                                *stringLiteral, *getTypeConverter());
1555     } else if (punct != PrintPunctuation::NoPunctuation) {
1556       emitCall(rewriter, printOp->getLoc(), [&] {
1557         switch (punct) {
1558         case PrintPunctuation::Close:
1559           return LLVM::lookupOrCreatePrintCloseFn(parent);
1560         case PrintPunctuation::Open:
1561           return LLVM::lookupOrCreatePrintOpenFn(parent);
1562         case PrintPunctuation::Comma:
1563           return LLVM::lookupOrCreatePrintCommaFn(parent);
1564         case PrintPunctuation::NewLine:
1565           return LLVM::lookupOrCreatePrintNewlineFn(parent);
1566         default:
1567           llvm_unreachable("unexpected punctuation");
1568         }
1569       }());
1570     }
1571 
1572     rewriter.eraseOp(printOp);
1573     return success();
1574   }
1575 
1576 private:
1577   enum class PrintConversion {
1578     // clang-format off
1579     None,
1580     ZeroExt64,
1581     SignExt64,
1582     Bitcast16
1583     // clang-format on
1584   };
1585 
1586   LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1587                                 ModuleOp parent, Location loc, Type printType,
1588                                 Value value) const {
1589     if (typeConverter->convertType(printType) == nullptr)
1590       return failure();
1591 
1592     // Make sure element type has runtime support.
1593     PrintConversion conversion = PrintConversion::None;
1594     Operation *printer;
1595     if (printType.isF32()) {
1596       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1597     } else if (printType.isF64()) {
1598       printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1599     } else if (printType.isF16()) {
1600       conversion = PrintConversion::Bitcast16; // bits!
1601       printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1602     } else if (printType.isBF16()) {
1603       conversion = PrintConversion::Bitcast16; // bits!
1604       printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1605     } else if (printType.isIndex()) {
1606       printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1607     } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1608       // Integers need a zero or sign extension on the operand
1609       // (depending on the source type) as well as a signed or
1610       // unsigned print method. Up to 64-bit is supported.
1611       unsigned width = intTy.getWidth();
1612       if (intTy.isUnsigned()) {
1613         if (width <= 64) {
1614           if (width < 64)
1615             conversion = PrintConversion::ZeroExt64;
1616           printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1617         } else {
1618           return failure();
1619         }
1620       } else {
1621         assert(intTy.isSignless() || intTy.isSigned());
1622         if (width <= 64) {
1623           // Note that we *always* zero extend booleans (1-bit integers),
1624           // so that true/false is printed as 1/0 rather than -1/0.
1625           if (width == 1)
1626             conversion = PrintConversion::ZeroExt64;
1627           else if (width < 64)
1628             conversion = PrintConversion::SignExt64;
1629           printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1630         } else {
1631           return failure();
1632         }
1633       }
1634     } else {
1635       return failure();
1636     }
1637 
1638     switch (conversion) {
1639     case PrintConversion::ZeroExt64:
1640       value = rewriter.create<arith::ExtUIOp>(
1641           loc, IntegerType::get(rewriter.getContext(), 64), value);
1642       break;
1643     case PrintConversion::SignExt64:
1644       value = rewriter.create<arith::ExtSIOp>(
1645           loc, IntegerType::get(rewriter.getContext(), 64), value);
1646       break;
1647     case PrintConversion::Bitcast16:
1648       value = rewriter.create<LLVM::BitcastOp>(
1649           loc, IntegerType::get(rewriter.getContext(), 16), value);
1650       break;
1651     case PrintConversion::None:
1652       break;
1653     }
1654     emitCall(rewriter, loc, printer, value);
1655     return success();
1656   }
1657 
1658   // Helper to emit a call.
1659   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1660                        Operation *ref, ValueRange params = ValueRange()) {
1661     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1662                                   params);
1663   }
1664 };
1665 
1666 /// The Splat operation is lowered to an insertelement + a shufflevector
1667 /// operation. Only splats with 0-D and 1-D result types are lowered here.
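/// As a rough sketch for the 1-D case (illustrative only):
/// ```
///   %t = vector.splat %s : vector<4xf32>
/// ```
/// is converted to something like:
/// ```
///   %u = llvm.mlir.undef : vector<4xf32>
///   %c = llvm.mlir.constant(0 : i32) : i32
///   %v = llvm.insertelement %s, %u[%c : i32] : vector<4xf32>
///   %t = llvm.shufflevector %v, %u [0, 0, 0, 0] : vector<4xf32>
/// ```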
1668 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1669   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1670 
1671   LogicalResult
1672   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1673                   ConversionPatternRewriter &rewriter) const override {
1674     VectorType resultType = cast<VectorType>(splatOp.getType());
1675     if (resultType.getRank() > 1)
1676       return failure();
1677 
1678     // First insert it into an undef vector so we can shuffle it.
1679     auto vectorType = typeConverter->convertType(splatOp.getType());
1680     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1681     auto zero = rewriter.create<LLVM::ConstantOp>(
1682         splatOp.getLoc(),
1683         typeConverter->convertType(rewriter.getIntegerType(32)),
1684         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1685 
1686     // For 0-d vector, we simply do `insertelement`.
1687     if (resultType.getRank() == 0) {
1688       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1689           splatOp, vectorType, undef, adaptor.getInput(), zero);
1690       return success();
1691     }
1692 
1693     // For a 1-D vector, we additionally do a `shufflevector`.
1694     auto v = rewriter.create<LLVM::InsertElementOp>(
1695         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1696 
1697     int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1698     SmallVector<int32_t> zeroValues(width, 0);
1699 
1700     // Shuffle the value across the desired number of elements.
1701     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1702                                                        zeroValues);
1703     return success();
1704   }
1705 };
1706 
1707 /// The Splat operation is lowered to an insertelement + a shufflevector
1708 /// operation. Only splats with a result rank of 2 or more are lowered by
1709 /// VectorSplatNdOpLowering; lower ranks are handled by VectorSplatOpLowering.
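/// As a rough sketch (illustrative only): for
/// ```
///   %t = vector.splat %s : vector<2x4xf32>
/// ```
/// a single splatted 1-D vector<4xf32> is built as in the 1-D case and then
/// inserted at every outer position of the result descriptor:
/// ```
///   %d0 = llvm.insertvalue %v, %undef[0] : !llvm.array<2 x vector<4xf32>>
///   %d1 = llvm.insertvalue %v, %d0[1] : !llvm.array<2 x vector<4xf32>>
/// ```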
1710 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1711   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1712 
1713   LogicalResult
1714   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1715                   ConversionPatternRewriter &rewriter) const override {
1716     VectorType resultType = splatOp.getType();
1717     if (resultType.getRank() <= 1)
1718       return failure();
1719 
1720     // First insert it into an undef vector so we can shuffle it.
1721     auto loc = splatOp.getLoc();
1722     auto vectorTypeInfo =
1723         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1724     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1725     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1726     if (!llvmNDVectorTy || !llvm1DVectorTy)
1727       return failure();
1728 
1729     // Construct returned value.
1730     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1731 
1732     // Construct a 1-D vector with the splatted value that we insert in all the
1733     // places within the returned descriptor.
1734     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1735     auto zero = rewriter.create<LLVM::ConstantOp>(
1736         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1737         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1738     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1739                                                      adaptor.getInput(), zero);
1740 
1741     // Shuffle the value across the desired number of elements.
1742     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1743     SmallVector<int32_t> zeroValues(width, 0);
1744     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1745 
1746     // Iterate over the linear index, convert it to a coordinate, and insert the
1747     // splatted 1-D vector at each position.
1748     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1749       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1750     });
1751     rewriter.replaceOp(splatOp, desc);
1752     return success();
1753   }
1754 };
1755 
1756 } // namespace
1757 
1758 /// Populate the given list with patterns that convert from Vector to LLVM.
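/// A minimal usage sketch from a conversion pass (assuming `op` and `target`
/// are provided by the surrounding pass):
/// ```
///   LLVMTypeConverter converter(&getContext());
///   RewritePatternSet patterns(&getContext());
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   if (failed(applyPartialConversion(op, target, std::move(patterns))))
///     signalPassFailure();
/// ```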
1759 void mlir::populateVectorToLLVMConversionPatterns(
1760     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1761     bool reassociateFPReductions, bool force32BitVectorIndices) {
1762   MLIRContext *ctx = converter.getDialect()->getContext();
1763   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1764   populateVectorInsertExtractStridedSliceTransforms(patterns);
1765   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1766   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1767   patterns
1768       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1769            VectorExtractElementOpConversion, VectorExtractOpConversion,
1770            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1771            VectorInsertOpConversion, VectorPrintOpConversion,
1772            VectorTypeCastOpConversion, VectorScaleOpConversion,
1773            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1774            VectorLoadStoreConversion<vector::MaskedLoadOp,
1775                                      vector::MaskedLoadOpAdaptor>,
1776            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1777            VectorLoadStoreConversion<vector::MaskedStoreOp,
1778                                      vector::MaskedStoreOpAdaptor>,
1779            VectorGatherOpConversion, VectorScatterOpConversion,
1780            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1781            VectorSplatOpLowering, VectorSplatNdOpLowering,
1782            VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1783            MaskedReductionOpConversion>(converter);
1784   // Transfer ops with rank > 1 are handled by VectorToSCF.
1785   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1786 }
1787 
1788 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1789     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1790   patterns.add<VectorMatmulOpConversion>(converter);
1791   patterns.add<VectorFlatTransposeOpConversion>(converter);
1792 }
1793