xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 98f6289a34bdaf7bc6cda8768e26e4405fc7726e)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Arith/Utils/Utils.h"
16 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
21 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypeInterfaces.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/Support/Casting.h"
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::vector;
35 
36 // Helper that reduces a vector type to the 1-D type of its trailing dimension.
37 static VectorType reducedVectorTypeBack(VectorType tp) {
38   assert((tp.getRank() > 1) && "unlowerable vector type");
39   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
40                          tp.getScalableDims().take_back());
41 }
42 
43 // Helper that picks the proper sequence for inserting.
44 static Value insertOne(ConversionPatternRewriter &rewriter,
45                        const LLVMTypeConverter &typeConverter, Location loc,
46                        Value val1, Value val2, Type llvmType, int64_t rank,
47                        int64_t pos) {
48   assert(rank > 0 && "0-D vector corner case should have been handled already");
49   if (rank == 1) {
50     auto idxType = rewriter.getIndexType();
51     auto constant = rewriter.create<LLVM::ConstantOp>(
52         loc, typeConverter.convertType(idxType),
53         rewriter.getIntegerAttr(idxType, pos));
54     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
55                                                   constant);
56   }
57   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
58 }
59 
60 // Helper that picks the proper sequence for extracting.
61 static Value extractOne(ConversionPatternRewriter &rewriter,
62                         const LLVMTypeConverter &typeConverter, Location loc,
63                         Value val, Type llvmType, int64_t rank, int64_t pos) {
64   if (rank <= 1) {
65     auto idxType = rewriter.getIndexType();
66     auto constant = rewriter.create<LLVM::ConstantOp>(
67         loc, typeConverter.convertType(idxType),
68         rewriter.getIntegerAttr(idxType, pos));
69     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
70                                                    constant);
71   }
72   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
73 }
74 
75 // Helper that returns data layout alignment of a memref.
76 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
77                                  MemRefType memrefType, unsigned &align) {
78   Type elementTy = typeConverter.convertType(memrefType.getElementType());
79   if (!elementTy)
80     return failure();
81 
82   // TODO: this should use the MLIR data layout when it becomes available and
83   // stop depending on translation.
84   llvm::LLVMContext llvmContext;
85   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
86               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
87   return success();
88 }
89 
90 // Check that the innermost stride is unit and the memory space is zero.
91 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
92                                            const LLVMTypeConverter &converter) {
93   if (!isLastMemrefDimUnitStride(memRefType))
94     return failure();
95   FailureOr<unsigned> addressSpace =
96       converter.getMemRefAddressSpace(memRefType);
97   if (failed(addressSpace) || *addressSpace != 0)
98     return failure();
99   return success();
100 }
101 
102 // Compute a vector of pointers by adding an index vector to a base pointer.
103 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
104                             const LLVMTypeConverter &typeConverter,
105                             MemRefType memRefType, Value llvmMemref, Value base,
106                             Value index, uint64_t vLen) {
107   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
108          "unsupported memref type");
109   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
110   auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
111   return rewriter.create<LLVM::GEPOp>(
112       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
113       base, index);
114 }
115 
116 // Casts a strided element pointer to a vector pointer.  The vector pointer
117 // will be in the same address space as the incoming memref type.
118 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
119                          Value ptr, MemRefType memRefType, Type vt,
120                          const LLVMTypeConverter &converter) {
121   if (converter.useOpaquePointers())
122     return ptr;
123 
124   unsigned addressSpace = *converter.getMemRefAddressSpace(memRefType);
125   auto pType = LLVM::LLVMPointerType::get(vt, addressSpace);
126   return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
127 }
128 
129 /// Convert `foldResult` into a Value. An integer attribute is materialized
130 /// as an LLVM constant op.
131 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
132                             OpFoldResult foldResult) {
133   if (auto attr = foldResult.dyn_cast<Attribute>()) {
134     auto intAttr = cast<IntegerAttr>(attr);
135     return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
136   }
137 
138   return foldResult.get<Value>();
139 }
140 
141 namespace {
142 
143 /// Trivial Vector to LLVM conversions
144 using VectorScaleOpConversion =
145     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
146 
147 /// Conversion pattern for a vector.bitcast.
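///
/// Only 0-D and 1-D vectors lower directly. As an illustrative sketch (IR
/// spelling approximate, not taken from the tests):
/// ```
///  %1 = vector.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```
/// becomes, roughly:
/// ```
///  %1 = llvm.bitcast %0 : vector<4xf32> to vector<4xi32>
/// ```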
148 class VectorBitCastOpConversion
149     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
150 public:
151   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
152 
153   LogicalResult
154   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
155                   ConversionPatternRewriter &rewriter) const override {
156     // Only 0-D and 1-D vectors can be lowered to LLVM.
157     VectorType resultTy = bitCastOp.getResultVectorType();
158     if (resultTy.getRank() > 1)
159       return failure();
160     Type newResultTy = typeConverter->convertType(resultTy);
161     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
162                                                  adaptor.getOperands()[0]);
163     return success();
164   }
165 };
166 
167 /// Conversion pattern for a vector.matrix_multiply.
168 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
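///
/// As an illustrative sketch (attribute/type spelling approximate):
/// ```
///  %c = vector.matrix_multiply %a, %b
///         {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///         : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// becomes, roughly, an llvm.intr.matrix.multiply with the same attributes and
/// converted operand/result types.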
169 class VectorMatmulOpConversion
170     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
171 public:
172   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
173 
174   LogicalResult
175   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
176                   ConversionPatternRewriter &rewriter) const override {
177     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
178         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
179         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
180         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
181     return success();
182   }
183 };
184 
185 /// Conversion pattern for a vector.flat_transpose.
186 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
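///
/// As an illustrative sketch (attribute/type spelling approximate):
/// ```
///  %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///         : vector<16xf32> -> vector<16xf32>
/// ```
/// becomes, roughly, an llvm.intr.matrix.transpose with the same attributes
/// and a converted result type.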
187 class VectorFlatTransposeOpConversion
188     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
189 public:
190   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
191 
192   LogicalResult
193   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
194                   ConversionPatternRewriter &rewriter) const override {
195     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
196         transOp, typeConverter->convertType(transOp.getRes().getType()),
197         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
198     return success();
199   }
200 };
201 
202 /// Overloaded utility that replaces a vector.load, vector.store,
203 /// vector.maskedload, and vector.maskedstore with their respective LLVM
204 /// counterparts.
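///
/// As an illustrative sketch of the unmasked and masked load cases (IR
/// spelling approximate):
/// ```
///  %0 = vector.load %base[%i] : memref<100xf32>, vector<8xf32>
///  %1 = vector.maskedload %base[%i], %mask, %pass
///         : memref<100xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// become, roughly, an llvm.load of vector<8xf32> from the strided element
/// pointer and an llvm.intr.masked.load with the resolved alignment.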
205 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
206                                  vector::LoadOpAdaptor adaptor,
207                                  VectorType vectorTy, Value ptr, unsigned align,
208                                  ConversionPatternRewriter &rewriter) {
209   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
210 }
211 
212 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
213                                  vector::MaskedLoadOpAdaptor adaptor,
214                                  VectorType vectorTy, Value ptr, unsigned align,
215                                  ConversionPatternRewriter &rewriter) {
216   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
217       loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
218 }
219 
220 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
221                                  vector::StoreOpAdaptor adaptor,
222                                  VectorType vectorTy, Value ptr, unsigned align,
223                                  ConversionPatternRewriter &rewriter) {
224   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
225                                              ptr, align);
226 }
227 
228 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
229                                  vector::MaskedStoreOpAdaptor adaptor,
230                                  VectorType vectorTy, Value ptr, unsigned align,
231                                  ConversionPatternRewriter &rewriter) {
232   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
233       storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
234 }
235 
236 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
237 /// vector.maskedstore.
238 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
239 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
240 public:
241   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
242 
243   LogicalResult
244   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
245                   typename LoadOrStoreOp::Adaptor adaptor,
246                   ConversionPatternRewriter &rewriter) const override {
247     // Only 1-D vectors can be lowered to LLVM.
248     VectorType vectorTy = loadOrStoreOp.getVectorType();
249     if (vectorTy.getRank() > 1)
250       return failure();
251 
252     auto loc = loadOrStoreOp->getLoc();
253     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
254 
255     // Resolve alignment.
256     unsigned align;
257     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
258       return failure();
259 
260     // Resolve address.
261     auto vtype = cast<VectorType>(
262         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
263     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
264                                                adaptor.getIndices(), rewriter);
265     Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
266                             *this->getTypeConverter());
267 
268     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
269     return success();
270   }
271 };
272 
273 /// Conversion pattern for a vector.gather.
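///
/// For the 1-D case this materializes a vector of pointers with a GEP and
/// calls the masked gather intrinsic. An illustrative sketch (IR spelling
/// approximate):
/// ```
///  %0 = vector.gather %base[%c0][%idxs], %mask, %pass
///         : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///           into vector<16xf32>
/// ```
/// becomes, roughly, an llvm.getelementptr yielding a vector of pointers
/// followed by llvm.intr.masked.gather. n-D cases are unrolled into 1-D ones.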
274 class VectorGatherOpConversion
275     : public ConvertOpToLLVMPattern<vector::GatherOp> {
276 public:
277   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
278 
279   LogicalResult
280   matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
281                   ConversionPatternRewriter &rewriter) const override {
282     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
283     assert(memRefType && "The base should be bufferized");
284 
285     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
286       return failure();
287 
288     auto loc = gather->getLoc();
289 
290     // Resolve alignment.
291     unsigned align;
292     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
293       return failure();
294 
295     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
296                                      adaptor.getIndices(), rewriter);
297     Value base = adaptor.getBase();
298 
299     auto llvmNDVectorTy = adaptor.getIndexVec().getType();
300     // Handle the simple case of 1-D vector.
301     if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
302       auto vType = gather.getVectorType();
303       // Resolve address.
304       Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
305                                   memRefType, base, ptr, adaptor.getIndexVec(),
306                                   /*vLen=*/vType.getDimSize(0));
307       // Replace with the gather intrinsic.
308       rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
309           gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
310           adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
311       return success();
312     }
313 
314     const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
315     auto callback = [align, memRefType, base, ptr, loc, &rewriter,
316                      &typeConverter](Type llvm1DVectorTy,
317                                      ValueRange vectorOperands) {
318       // Resolve address.
319       Value ptrs = getIndexedPtrs(
320           rewriter, loc, typeConverter, memRefType, base, ptr,
321           /*index=*/vectorOperands[0],
322           LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
323       // Create the gather intrinsic.
324       return rewriter.create<LLVM::masked_gather>(
325           loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
326           /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
327     };
328     SmallVector<Value> vectorOperands = {
329         adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
330     return LLVM::detail::handleMultidimensionalVectors(
331         gather, vectorOperands, *getTypeConverter(), callback, rewriter);
332   }
333 };
334 
335 /// Conversion pattern for a vector.scatter.
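///
/// Mirrors the gather lowering. An illustrative sketch (IR spelling
/// approximate):
/// ```
///  vector.scatter %base[%c0][%idxs], %mask, %val
///    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
/// ```
/// becomes, roughly, an llvm.getelementptr producing a vector of pointers
/// followed by llvm.intr.masked.scatter with the resolved alignment.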
336 class VectorScatterOpConversion
337     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
338 public:
339   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
340 
341   LogicalResult
342   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
343                   ConversionPatternRewriter &rewriter) const override {
344     auto loc = scatter->getLoc();
345     MemRefType memRefType = scatter.getMemRefType();
346 
347     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
348       return failure();
349 
350     // Resolve alignment.
351     unsigned align;
352     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
353       return failure();
354 
355     // Resolve address.
356     VectorType vType = scatter.getVectorType();
357     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
358                                      adaptor.getIndices(), rewriter);
359     Value ptrs = getIndexedPtrs(
360         rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
361         ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
362 
363     // Replace with the scatter intrinsic.
364     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
365         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
366         rewriter.getI32IntegerAttr(align));
367     return success();
368   }
369 };
370 
371 /// Conversion pattern for a vector.expandload.
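///
/// An illustrative sketch (IR spelling approximate):
/// ```
///  %0 = vector.expandload %base[%i], %mask, %pass
///         : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// ```
/// becomes, roughly, llvm.intr.masked.expandload on the strided element
/// pointer.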
372 class VectorExpandLoadOpConversion
373     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
374 public:
375   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
376 
377   LogicalResult
378   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
379                   ConversionPatternRewriter &rewriter) const override {
380     auto loc = expand->getLoc();
381     MemRefType memRefType = expand.getMemRefType();
382 
383     // Resolve address.
384     auto vtype = typeConverter->convertType(expand.getVectorType());
385     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
386                                      adaptor.getIndices(), rewriter);
387 
388     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
389         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
390     return success();
391   }
392 };
393 
394 /// Conversion pattern for a vector.compressstore.
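///
/// An illustrative sketch (IR spelling approximate):
/// ```
///  vector.compressstore %base[%i], %mask, %val
///    : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// ```
/// becomes, roughly, llvm.intr.masked.compressstore on the strided element
/// pointer.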
395 class VectorCompressStoreOpConversion
396     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
397 public:
398   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
399 
400   LogicalResult
401   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
402                   ConversionPatternRewriter &rewriter) const override {
403     auto loc = compress->getLoc();
404     MemRefType memRefType = compress.getMemRefType();
405 
406     // Resolve address.
407     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
408                                      adaptor.getIndices(), rewriter);
409 
410     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
411         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
412     return success();
413   }
414 };
415 
416 /// Reduction neutral classes for overloading.
417 class ReductionNeutralZero {};
418 class ReductionNeutralIntOne {};
419 class ReductionNeutralFPOne {};
420 class ReductionNeutralAllOnes {};
421 class ReductionNeutralSIntMin {};
422 class ReductionNeutralUIntMin {};
423 class ReductionNeutralSIntMax {};
424 class ReductionNeutralUIntMax {};
425 class ReductionNeutralFPMin {};
426 class ReductionNeutralFPMax {};
427 
428 /// Create the reduction neutral zero value.
429 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
430                                          ConversionPatternRewriter &rewriter,
431                                          Location loc, Type llvmType) {
432   return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
433                                            rewriter.getZeroAttr(llvmType));
434 }
435 
436 /// Create the reduction neutral integer one value.
437 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
438                                          ConversionPatternRewriter &rewriter,
439                                          Location loc, Type llvmType) {
440   return rewriter.create<LLVM::ConstantOp>(
441       loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
442 }
443 
444 /// Create the reduction neutral fp one value.
445 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
446                                          ConversionPatternRewriter &rewriter,
447                                          Location loc, Type llvmType) {
448   return rewriter.create<LLVM::ConstantOp>(
449       loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
450 }
451 
452 /// Create the reduction neutral all-ones value.
453 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
454                                          ConversionPatternRewriter &rewriter,
455                                          Location loc, Type llvmType) {
456   return rewriter.create<LLVM::ConstantOp>(
457       loc, llvmType,
458       rewriter.getIntegerAttr(
459           llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
460 }
461 
462 /// Create the reduction neutral signed int minimum value.
463 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
464                                          ConversionPatternRewriter &rewriter,
465                                          Location loc, Type llvmType) {
466   return rewriter.create<LLVM::ConstantOp>(
467       loc, llvmType,
468       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
469                                             llvmType.getIntOrFloatBitWidth())));
470 }
471 
472 /// Create the reduction neutral unsigned int minimum value.
473 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
474                                          ConversionPatternRewriter &rewriter,
475                                          Location loc, Type llvmType) {
476   return rewriter.create<LLVM::ConstantOp>(
477       loc, llvmType,
478       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
479                                             llvmType.getIntOrFloatBitWidth())));
480 }
481 
482 /// Create the reduction neutral signed int maximum value.
483 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
484                                          ConversionPatternRewriter &rewriter,
485                                          Location loc, Type llvmType) {
486   return rewriter.create<LLVM::ConstantOp>(
487       loc, llvmType,
488       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
489                                             llvmType.getIntOrFloatBitWidth())));
490 }
491 
492 /// Create the reduction neutral unsigned int maximum value.
493 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
494                                          ConversionPatternRewriter &rewriter,
495                                          Location loc, Type llvmType) {
496   return rewriter.create<LLVM::ConstantOp>(
497       loc, llvmType,
498       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
499                                             llvmType.getIntOrFloatBitWidth())));
500 }
501 
502 /// Create the reduction neutral fp minimum value.
503 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
504                                          ConversionPatternRewriter &rewriter,
505                                          Location loc, Type llvmType) {
506   auto floatType = cast<FloatType>(llvmType);
507   return rewriter.create<LLVM::ConstantOp>(
508       loc, llvmType,
509       rewriter.getFloatAttr(
510           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
511                                            /*Negative=*/false)));
512 }
513 
514 /// Create the reduction neutral fp maximum value.
515 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
516                                          ConversionPatternRewriter &rewriter,
517                                          Location loc, Type llvmType) {
518   auto floatType = cast<FloatType>(llvmType);
519   return rewriter.create<LLVM::ConstantOp>(
520       loc, llvmType,
521       rewriter.getFloatAttr(
522           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
523                                            /*Negative=*/true)));
524 }
525 
526 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
527 /// returns a new accumulator value using `ReductionNeutral`.
528 template <class ReductionNeutral>
529 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
530                                     Location loc, Type llvmType,
531                                     Value accumulator) {
532   if (accumulator)
533     return accumulator;
534 
535   return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
536                                      llvmType);
537 }
538 
539 /// Creates an i32 constant holding the number of elements of the 1-D vector
540 /// type `llvmType`. This is used as the effective vector length by intrinsics
541 /// that support dynamic vector lengths at runtime.
542 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
543                                      Location loc, Type llvmType) {
544   VectorType vType = cast<VectorType>(llvmType);
545   auto vShape = vType.getShape();
546   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
547 
548   return rewriter.create<LLVM::ConstantOp>(
549       loc, rewriter.getI32Type(),
550       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
551 }
552 
553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
554 /// operation such as add or mul. `LLVMRedIntrinOp` is the LLVM vector
555 /// reduction intrinsic to use and `ScalarOp` is the scalar operation used to
556 /// combine the accumulator value, if non-null.
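///
/// As an illustrative sketch (IR spelling approximate):
/// ```
///  %0 = vector.reduction <add>, %v, %acc : vector<8xi32> into i32
/// ```
/// becomes, roughly, an llvm.intr.vector.reduce.add of %v followed by an
/// llvm.add with %acc.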
557 template <class LLVMRedIntrinOp, class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
559     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
560     Value vectorOperand, Value accumulator) {
561 
562   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
563 
564   if (accumulator)
565     result = rewriter.create<ScalarOp>(loc, accumulator, result);
566   return result;
567 }
568 
569 /// Helper method to lower a `vector.reduction` operation that performs
570 /// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
571 /// vector reduction intrinsic to use and `predicate` is the comparison
572 /// predicate used to combine the accumulator value, if non-null.
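///
/// As an illustrative sketch (IR spelling approximate):
/// ```
///  %0 = vector.reduction <minui>, %v, %acc : vector<8xi32> into i32
/// ```
/// becomes, roughly, an llvm.intr.vector.reduce.umin of %v followed by an
/// llvm.icmp "ule" against %acc and an llvm.select of the smaller value.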
573 template <class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
575     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
576     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
577   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
578   if (accumulator) {
579     Value cmp =
580         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
582   }
583   return result;
584 }
585 
586 namespace {
587 template <typename Source>
588 struct VectorToScalarMapper;
589 template <>
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591   using Type = LLVM::MaximumOp;
592 };
593 template <>
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595   using Type = LLVM::MinimumOp;
596 };
597 template <>
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599   using Type = LLVM::MaxNumOp;
600 };
601 template <>
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603   using Type = LLVM::MinNumOp;
604 };
605 } // namespace
606 
607 template <class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
609     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
610     Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
611   Value result =
612       rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
613 
614   if (accumulator) {
615     result =
616         rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617             loc, result, accumulator);
618   }
619 
620   return result;
621 }
622 
623 /// Mask neutral classes for overloading.
624 class MaskNeutralFMaximum {};
625 class MaskNeutralFMinimum {};
626 
627 /// Get the mask neutral floating point maximum value
628 static llvm::APFloat
629 getMaskNeutralValue(MaskNeutralFMaximum,
630                     const llvm::fltSemantics &floatSemantics) {
631   return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
632 }
633 /// Get the mask neutral floating point minimum value
634 static llvm::APFloat
635 getMaskNeutralValue(MaskNeutralFMinimum,
636                     const llvm::fltSemantics &floatSemantics) {
637   return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
638 }
639 
640 /// Create the mask neutral floating point MLIR vector constant
641 template <typename MaskNeutral>
642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
643                                     Location loc, Type llvmType,
644                                     Type vectorType) {
645   const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646   auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
647   auto denseValue =
648       DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
649   return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
650 }
651 
652 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
653 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
654 /// `fmaximum`/`fminimum`.
655 /// More information: https://github.com/llvm/llvm-project/issues/64940
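///
/// A rough sketch of the idea (names illustrative only): masked-off lanes are
/// first overwritten with the mask-neutral value, after which the regular
/// reduction applies.
/// ```
///  %neutral = llvm.mlir.constant(dense<...> : vector<8xf32>) : vector<8xf32>
///  %masked  = llvm.select %mask, %v, %neutral
///  %red     = llvm.intr.vector.reduce.fmaximum(%masked) ...
/// ```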
656 template <class LLVMRedIntrinOp, class MaskNeutral>
657 static Value
658 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
659                                 Location loc, Type llvmType,
660                                 Value vectorOperand, Value accumulator,
661                                 Value mask, LLVM::FastmathFlagsAttr fmf) {
662   const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
663       rewriter, loc, llvmType, vectorOperand.getType());
664   const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
665       loc, mask, vectorOperand, vectorMaskNeutral);
666   return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
667       rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
668 }
669 
670 template <class LLVMRedIntrinOp, class ReductionNeutral>
671 static Value
672 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
673                              Type llvmType, Value vectorOperand,
674                              Value accumulator, LLVM::FastmathFlagsAttr fmf) {
675   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
676                                                          llvmType, accumulator);
677   return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
678                                           /*startValue=*/accumulator,
679                                           vectorOperand, fmf);
680 }
681 
682 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
683 /// that requires a start value. This start-value form is shared by unmasked
684 /// fp reductions and by all the masked reduction intrinsics.
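///
/// As an illustrative sketch for a masked integer add reduction (IR spelling
/// approximate):
/// ```
///  %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xi32> into i32 }
///         : vector<8xi1> -> i32
/// ```
/// becomes, roughly, an llvm.intr.vp.reduce.add whose start value is the
/// neutral element (or the accumulator), with %m as the mask and a constant 8
/// as the effective vector length.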
685 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
686 static Value
687 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
688                                        Location loc, Type llvmType,
689                                        Value vectorOperand, Value accumulator) {
690   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
691                                                          llvmType, accumulator);
692   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
693                                             /*startValue=*/accumulator,
694                                             vectorOperand);
695 }
696 
697 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
698 static Value lowerPredicatedReductionWithStartValue(
699     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
700     Value vectorOperand, Value accumulator, Value mask) {
701   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
702                                                          llvmType, accumulator);
703   Value vectorLength =
704       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
705   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
706                                             /*startValue=*/accumulator,
707                                             vectorOperand, mask, vectorLength);
708 }
709 
710 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
711           class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
712 static Value lowerPredicatedReductionWithStartValue(
713     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
714     Value vectorOperand, Value accumulator, Value mask) {
715   if (llvmType.isIntOrIndex())
716     return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
717                                                   IntReductionNeutral>(
718         rewriter, loc, llvmType, vectorOperand, accumulator, mask);
719 
720   // FP dispatch.
721   return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
722                                                 FPReductionNeutral>(
723       rewriter, loc, llvmType, vectorOperand, accumulator, mask);
724 }
725 
726 /// Conversion pattern for all vector reductions.
727 class VectorReductionOpConversion
728     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
729 public:
730   explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
731                                        bool reassociateFPRed)
732       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
733         reassociateFPReductions(reassociateFPRed) {}
734 
735   LogicalResult
736   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
737                   ConversionPatternRewriter &rewriter) const override {
738     auto kind = reductionOp.getKind();
739     Type eltType = reductionOp.getDest().getType();
740     Type llvmType = typeConverter->convertType(eltType);
741     Value operand = adaptor.getVector();
742     Value acc = adaptor.getAcc();
743     Location loc = reductionOp.getLoc();
744 
745     if (eltType.isIntOrIndex()) {
746       // Integer reductions: add/mul/min/max/and/or/xor.
747       Value result;
748       switch (kind) {
749       case vector::CombiningKind::ADD:
750         result =
751             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
752                                                        LLVM::AddOp>(
753                 rewriter, loc, llvmType, operand, acc);
754         break;
755       case vector::CombiningKind::MUL:
756         result =
757             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
758                                                        LLVM::MulOp>(
759                 rewriter, loc, llvmType, operand, acc);
760         break;
761       case vector::CombiningKind::MINUI:
762         result = createIntegerReductionComparisonOpLowering<
763             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
764                                       LLVM::ICmpPredicate::ule);
765         break;
766       case vector::CombiningKind::MINSI:
767         result = createIntegerReductionComparisonOpLowering<
768             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
769                                       LLVM::ICmpPredicate::sle);
770         break;
771       case vector::CombiningKind::MAXUI:
772         result = createIntegerReductionComparisonOpLowering<
773             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
774                                       LLVM::ICmpPredicate::uge);
775         break;
776       case vector::CombiningKind::MAXSI:
777         result = createIntegerReductionComparisonOpLowering<
778             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
779                                       LLVM::ICmpPredicate::sge);
780         break;
781       case vector::CombiningKind::AND:
782         result =
783             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
784                                                        LLVM::AndOp>(
785                 rewriter, loc, llvmType, operand, acc);
786         break;
787       case vector::CombiningKind::OR:
788         result =
789             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
790                                                        LLVM::OrOp>(
791                 rewriter, loc, llvmType, operand, acc);
792         break;
793       case vector::CombiningKind::XOR:
794         result =
795             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
796                                                        LLVM::XOrOp>(
797                 rewriter, loc, llvmType, operand, acc);
798         break;
799       default:
800         return failure();
801       }
802       rewriter.replaceOp(reductionOp, result);
803 
804       return success();
805     }
806 
807     if (!isa<FloatType>(eltType))
808       return failure();
809 
810     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
811     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
812         reductionOp.getContext(),
813         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
814     fmf = LLVM::FastmathFlagsAttr::get(
815         reductionOp.getContext(),
816         fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
817                                                   : LLVM::FastmathFlags::none));
818 
819     // Floating-point reductions: add/mul/min/max
820     Value result;
821     if (kind == vector::CombiningKind::ADD) {
822       result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
823                                             ReductionNeutralZero>(
824           rewriter, loc, llvmType, operand, acc, fmf);
825     } else if (kind == vector::CombiningKind::MUL) {
826       result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
827                                             ReductionNeutralFPOne>(
828           rewriter, loc, llvmType, operand, acc, fmf);
829     } else if (kind == vector::CombiningKind::MINIMUMF) {
830       result =
831           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
832               rewriter, loc, llvmType, operand, acc, fmf);
833     } else if (kind == vector::CombiningKind::MAXIMUMF) {
834       result =
835           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
836               rewriter, loc, llvmType, operand, acc, fmf);
837     } else if (kind == vector::CombiningKind::MINF) {
838       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
839           rewriter, loc, llvmType, operand, acc, fmf);
840     } else if (kind == vector::CombiningKind::MAXF) {
841       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
842           rewriter, loc, llvmType, operand, acc, fmf);
843     } else
844       return failure();
845 
846     rewriter.replaceOp(reductionOp, result);
847     return success();
848   }
849 
850 private:
851   const bool reassociateFPReductions;
852 };
853 
854 /// Base class to convert a `vector.mask` operation while matching traits
855 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
856 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
857 /// method performs a second match against the maskable operation `MaskedOp`.
858 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
859 /// implemented by the concrete conversion classes. This method can match
860 /// against specific traits of the `vector.mask` and the maskable operation. It
861 /// must replace the `vector.mask` operation.
862 template <class MaskedOp>
863 class VectorMaskOpConversionBase
864     : public ConvertOpToLLVMPattern<vector::MaskOp> {
865 public:
866   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
867 
868   LogicalResult
869   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
870                   ConversionPatternRewriter &rewriter) const final {
871     // Match against the maskable operation kind.
872     auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
873     if (!maskedOp)
874       return failure();
875     return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
876   }
877 
878 protected:
879   virtual LogicalResult
880   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
881                             vector::MaskableOpInterface maskableOp,
882                             ConversionPatternRewriter &rewriter) const = 0;
883 };
884 
885 class MaskedReductionOpConversion
886     : public VectorMaskOpConversionBase<vector::ReductionOp> {
887 
888 public:
889   using VectorMaskOpConversionBase<
890       vector::ReductionOp>::VectorMaskOpConversionBase;
891 
892   LogicalResult matchAndRewriteMaskableOp(
893       vector::MaskOp maskOp, MaskableOpInterface maskableOp,
894       ConversionPatternRewriter &rewriter) const override {
895     auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
896     auto kind = reductionOp.getKind();
897     Type eltType = reductionOp.getDest().getType();
898     Type llvmType = typeConverter->convertType(eltType);
899     Value operand = reductionOp.getVector();
900     Value acc = reductionOp.getAcc();
901     Location loc = reductionOp.getLoc();
902 
903     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
904     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
905         reductionOp.getContext(),
906         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
907 
908     Value result;
909     switch (kind) {
910     case vector::CombiningKind::ADD:
911       result = lowerPredicatedReductionWithStartValue<
912           LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
913           ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
914                                 maskOp.getMask());
915       break;
916     case vector::CombiningKind::MUL:
917       result = lowerPredicatedReductionWithStartValue<
918           LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
919           ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
920                                  maskOp.getMask());
921       break;
922     case vector::CombiningKind::MINUI:
923       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
924                                                       ReductionNeutralUIntMax>(
925           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
926       break;
927     case vector::CombiningKind::MINSI:
928       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
929                                                       ReductionNeutralSIntMax>(
930           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
931       break;
932     case vector::CombiningKind::MAXUI:
933       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
934                                                       ReductionNeutralUIntMin>(
935           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
936       break;
937     case vector::CombiningKind::MAXSI:
938       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
939                                                       ReductionNeutralSIntMin>(
940           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
941       break;
942     case vector::CombiningKind::AND:
943       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
944                                                       ReductionNeutralAllOnes>(
945           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
946       break;
947     case vector::CombiningKind::OR:
948       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
949                                                       ReductionNeutralZero>(
950           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951       break;
952     case vector::CombiningKind::XOR:
953       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
954                                                       ReductionNeutralZero>(
955           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
956       break;
957     case vector::CombiningKind::MINF:
958       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
959                                                       ReductionNeutralFPMax>(
960           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
961       break;
962     case vector::CombiningKind::MAXF:
963       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
964                                                       ReductionNeutralFPMin>(
965           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
966       break;
967     case CombiningKind::MAXIMUMF:
968       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
969                                                MaskNeutralFMaximum>(
970           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
971       break;
972     case CombiningKind::MINIMUMF:
973       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
974                                                MaskNeutralFMinimum>(
975           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
976       break;
977     }
978 
979     // Replace `vector.mask` operation altogether.
980     rewriter.replaceOp(maskOp, result);
981     return success();
982   }
983 };
984 
985 class VectorShuffleOpConversion
986     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
987 public:
988   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
989 
990   LogicalResult
991   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
992                   ConversionPatternRewriter &rewriter) const override {
993     auto loc = shuffleOp->getLoc();
994     auto v1Type = shuffleOp.getV1VectorType();
995     auto v2Type = shuffleOp.getV2VectorType();
996     auto vectorType = shuffleOp.getResultVectorType();
997     Type llvmType = typeConverter->convertType(vectorType);
998     auto maskArrayAttr = shuffleOp.getMask();
999 
1000     // Bail if result type cannot be lowered.
1001     if (!llvmType)
1002       return failure();
1003 
1004     // Get rank and dimension sizes.
1005     int64_t rank = vectorType.getRank();
1006 #ifndef NDEBUG
1007     bool wellFormed0DCase =
1008         v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1009     bool wellFormedNDCase =
1010         v1Type.getRank() == rank && v2Type.getRank() == rank;
1011     assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1012 #endif
1013 
1014     // For rank 0 and 1, where both operands have *exactly* the same vector
1015     // type, there is direct shuffle support in LLVM. Use it!
1016     if (rank <= 1 && v1Type == v2Type) {
1017       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1018           loc, adaptor.getV1(), adaptor.getV2(),
1019           LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1020       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1021       return success();
1022     }
1023 
1024     // For all other cases, extract and insert the values one by one.
1025     int64_t v1Dim = v1Type.getDimSize(0);
1026     Type eltType;
1027     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1028       eltType = arrayType.getElementType();
1029     else
1030       eltType = cast<VectorType>(llvmType).getElementType();
1031     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1032     int64_t insPos = 0;
1033     for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1034       int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1035       Value value = adaptor.getV1();
1036       if (extPos >= v1Dim) {
1037         extPos -= v1Dim;
1038         value = adaptor.getV2();
1039       }
1040       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1041                                  eltType, rank, extPos);
1042       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1043                          llvmType, rank, insPos++);
1044     }
1045     rewriter.replaceOp(shuffleOp, insert);
1046     return success();
1047   }
1048 };
1049 
1050 class VectorExtractElementOpConversion
1051     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1052 public:
1053   using ConvertOpToLLVMPattern<
1054       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1055 
1056   LogicalResult
1057   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1058                   ConversionPatternRewriter &rewriter) const override {
1059     auto vectorType = extractEltOp.getSourceVectorType();
1060     auto llvmType = typeConverter->convertType(vectorType.getElementType());
1061 
1062     // Bail if result type cannot be lowered.
1063     if (!llvmType)
1064       return failure();
1065 
1066     if (vectorType.getRank() == 0) {
1067       Location loc = extractEltOp.getLoc();
1068       auto idxType = rewriter.getIndexType();
1069       auto zero = rewriter.create<LLVM::ConstantOp>(
1070           loc, typeConverter->convertType(idxType),
1071           rewriter.getIntegerAttr(idxType, 0));
1072       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1073           extractEltOp, llvmType, adaptor.getVector(), zero);
1074       return success();
1075     }
1076 
1077     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1078         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1079     return success();
1080   }
1081 };
1082 
1083 class VectorExtractOpConversion
1084     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1085 public:
1086   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1087 
1088   LogicalResult
1089   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1090                   ConversionPatternRewriter &rewriter) const override {
1091     auto loc = extractOp->getLoc();
1092     auto resultType = extractOp.getResult().getType();
1093     auto llvmResultType = typeConverter->convertType(resultType);
1094     // Bail if result type cannot be lowered.
1095     if (!llvmResultType)
1096       return failure();
1097 
1098     SmallVector<OpFoldResult> positionVec;
1099     for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1100       if (pos.is<Value>())
1101         // Make sure we use the value that has been already converted to LLVM.
1102         positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1103       else
1104         positionVec.push_back(pos);
1105     }
1106 
1107     // Extract entire vector. Should be handled by folder, but just to be safe.
1108     ArrayRef<OpFoldResult> position(positionVec);
1109     if (position.empty()) {
1110       rewriter.replaceOp(extractOp, adaptor.getVector());
1111       return success();
1112     }
1113 
1114     // One-shot extraction of vector from array (only requires extractvalue).
1115     if (isa<VectorType>(resultType)) {
1116       if (extractOp.hasDynamicPosition())
1117         return failure();
1118 
1119       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1120           loc, adaptor.getVector(), getAsIntegers(position));
1121       rewriter.replaceOp(extractOp, extracted);
1122       return success();
1123     }
1124 
1125     // Potential extraction of 1-D vector from array.
1126     Value extracted = adaptor.getVector();
1127     if (position.size() > 1) {
1128       if (extractOp.hasDynamicPosition())
1129         return failure();
1130 
1131       SmallVector<int64_t> nMinusOnePosition =
1132           getAsIntegers(position.drop_back());
1133       extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1134                                                         nMinusOnePosition);
1135     }
1136 
1137     Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1138     // Remaining extraction of element from 1-D LLVM vector.
1139     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1140                                                         lastPosition);
1141     return success();
1142   }
1143 };
1144 
1145 /// Conversion pattern that turns a vector.fma on a 1-D vector
1146 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1147 /// It does not match vectors of rank >= 2.
1148 ///
1149 /// Example:
1150 /// ```
1151 ///  vector.fma %a, %a, %a : vector<8xf32>
1152 /// ```
1153 /// is converted to:
1154 /// ```
1155 ///  llvm.intr.fmuladd %va, %va, %va:
1156 ///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1157 ///    -> !llvm<"<8 x f32>">
1158 /// ```
1159 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1160 public:
1161   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1162 
1163   LogicalResult
1164   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1165                   ConversionPatternRewriter &rewriter) const override {
1166     VectorType vType = fmaOp.getVectorType();
1167     if (vType.getRank() > 1)
1168       return failure();
1169 
1170     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1171         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1172     return success();
1173   }
1174 };
1175 
1176 class VectorInsertElementOpConversion
1177     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1178 public:
1179   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1180 
1181   LogicalResult
1182   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1183                   ConversionPatternRewriter &rewriter) const override {
1184     auto vectorType = insertEltOp.getDestVectorType();
1185     auto llvmType = typeConverter->convertType(vectorType);
1186 
1187     // Bail if result type cannot be lowered.
1188     if (!llvmType)
1189       return failure();
1190 
1191     if (vectorType.getRank() == 0) {
1192       Location loc = insertEltOp.getLoc();
1193       auto idxType = rewriter.getIndexType();
1194       auto zero = rewriter.create<LLVM::ConstantOp>(
1195           loc, typeConverter->convertType(idxType),
1196           rewriter.getIntegerAttr(idxType, 0));
1197       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1198           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1199       return success();
1200     }
1201 
1202     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1203         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1204         adaptor.getPosition());
1205     return success();
1206   }
1207 };
1208 
1209 class VectorInsertOpConversion
1210     : public ConvertOpToLLVMPattern<vector::InsertOp> {
1211 public:
1212   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1213 
1214   LogicalResult
1215   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1216                   ConversionPatternRewriter &rewriter) const override {
1217     auto loc = insertOp->getLoc();
1218     auto sourceType = insertOp.getSourceType();
1219     auto destVectorType = insertOp.getDestVectorType();
1220     auto llvmResultType = typeConverter->convertType(destVectorType);
1221     // Bail if result type cannot be lowered.
1222     if (!llvmResultType)
1223       return failure();
1224 
1225     SmallVector<OpFoldResult> positionVec;
1226     for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
1227       if (pos.is<Value>())
1228         // Make sure we use the value that has already been converted to LLVM.
1229         positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1230       else
1231         positionVec.push_back(pos);
1232     }
1233 
1234     // Overwrite the entire vector with the value. This should be handled by the
1235     // folder, but we do it here just to be safe.
1236     ArrayRef<OpFoldResult> position(positionVec);
1237     if (position.empty()) {
1238       rewriter.replaceOp(insertOp, adaptor.getSource());
1239       return success();
1240     }
1241 
1242     // One-shot insertion of a vector into an array (only requires insertvalue).
1243     if (isa<VectorType>(sourceType)) {
1244       if (insertOp.hasDynamicPosition())
1245         return failure();
1246 
1247       Value inserted = rewriter.create<LLVM::InsertValueOp>(
1248           loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1249       rewriter.replaceOp(insertOp, inserted);
1250       return success();
1251     }
1252 
1253     // Potential extraction of 1-D vector from array.
1254     Value extracted = adaptor.getDest();
1255     auto oneDVectorType = destVectorType;
1256     if (position.size() > 1) {
1257       if (insertOp.hasDynamicPosition())
1258         return failure();
1259 
1260       oneDVectorType = reducedVectorTypeBack(destVectorType);
1261       extracted = rewriter.create<LLVM::ExtractValueOp>(
1262           loc, extracted, getAsIntegers(position.drop_back()));
1263     }
1264 
1265     // Insertion of an element into a 1-D LLVM vector.
1266     Value inserted = rewriter.create<LLVM::InsertElementOp>(
1267         loc, typeConverter->convertType(oneDVectorType), extracted,
1268         adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1269 
1270     // Potential insertion of resulting 1-D vector into array.
1271     if (position.size() > 1) {
1272       if (insertOp.hasDynamicPosition())
1273         return failure();
1274 
1275       inserted = rewriter.create<LLVM::InsertValueOp>(
1276           loc, adaptor.getDest(), inserted,
1277           getAsIntegers(position.drop_back()));
1278     }
1279 
1280     rewriter.replaceOp(insertOp, inserted);
1281     return success();
1282   }
1283 };
1284 
1285 /// Lower vector.scalable.insert ops to LLVM vector.insert
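/// (the llvm.intr.vector.insert intrinsic). A rough sketch of the kind of IR
/// that is matched (operand names are hypothetical):
/// ```
///  vector.scalable.insert %sub, %vec[0] : vector<4xf32> into vector<[4]xf32>
/// ```
/// which maps 1-1 onto the corresponding LLVM intrinsic with the same
/// operands and static position.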
1286 struct VectorScalableInsertOpLowering
1287     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1288   using ConvertOpToLLVMPattern<
1289       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1290 
1291   LogicalResult
1292   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1293                   ConversionPatternRewriter &rewriter) const override {
1294     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1295         insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
1296     return success();
1297   }
1298 };
1299 
1300 /// Lower vector.scalable.extract ops to LLVM vector.extract
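/// (the llvm.intr.vector.extract intrinsic). A rough sketch of the kind of IR
/// that is matched (operand names are hypothetical):
/// ```
///  vector.scalable.extract %vec[0] : vector<4xf32> from vector<[4]xf32>
/// ```
/// which maps 1-1 onto the corresponding LLVM intrinsic with the same
/// operand and static position.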
1301 struct VectorScalableExtractOpLowering
1302     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1303   using ConvertOpToLLVMPattern<
1304       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1305 
1306   LogicalResult
1307   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1308                   ConversionPatternRewriter &rewriter) const override {
1309     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1310         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1311         adaptor.getSource(), adaptor.getPos());
1312     return success();
1313   }
1314 };
1315 
1316 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1317 ///
1318 /// Example:
1319 /// ```
1320 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
1321 /// ```
1322 /// is rewritten into:
1323 /// ```
1324 ///  %r = vector.splat %f0 : vector<2x4xf32>
1325 ///  %va = vector.extract %a[0] : vector<2x4xf32>
1326 ///  %vb = vector.extract %b[0] : vector<2x4xf32>
1327 ///  %vc = vector.extract %c[0] : vector<2x4xf32>
1328 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1329 ///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1330 ///  %va2 = vector.extract %a[1] : vector<2x4xf32>
1331 ///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
1332 ///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
1333 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1334 ///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1335 ///  // %r3 holds the final value.
1336 /// ```
1337 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1338 public:
1339   using OpRewritePattern<FMAOp>::OpRewritePattern;
1340 
1341   void initialize() {
1342     // This pattern recursively unpacks one dimension at a time. The recursion
1343     // is bounded because the rank strictly decreases with each step.
1344     setHasBoundedRewriteRecursion();
1345   }
1346 
1347   LogicalResult matchAndRewrite(FMAOp op,
1348                                 PatternRewriter &rewriter) const override {
1349     auto vType = op.getVectorType();
1350     if (vType.getRank() < 2)
1351       return failure();
1352 
1353     auto loc = op.getLoc();
1354     auto elemType = vType.getElementType();
1355     Value zero = rewriter.create<arith::ConstantOp>(
1356         loc, elemType, rewriter.getZeroAttr(elemType));
1357     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1358     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1359       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1360       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1361       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1362       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1363       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1364     }
1365     rewriter.replaceOp(op, desc);
1366     return success();
1367   }
1368 };
1369 
1370 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1371 /// static layout.
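/// For example, an identity-layout memref<2x3x4xf32> yields strides
/// [12, 4, 1], whereas a layout whose innermost stride is not 1, or whose
/// contiguity cannot be proven statically, yields std::nullopt.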
1372 static std::optional<SmallVector<int64_t, 4>>
1373 computeContiguousStrides(MemRefType memRefType) {
1374   int64_t offset;
1375   SmallVector<int64_t, 4> strides;
1376   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1377     return std::nullopt;
1378   if (!strides.empty() && strides.back() != 1)
1379     return std::nullopt;
1380   // If no layout or identity layout, this is contiguous by definition.
1381   if (memRefType.getLayout().isIdentity())
1382     return strides;
1383 
1384   // Otherwise, we must determine contiguity from the shapes. This can only
1385   // ever work in static cases because MemRefType is underspecified: it cannot
1386   // represent contiguous dynamic shapes in any way other than with an
1387   // empty/identity layout.
1388   auto sizes = memRefType.getShape();
1389   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1390     if (ShapedType::isDynamic(sizes[index + 1]) ||
1391         ShapedType::isDynamic(strides[index]) ||
1392         ShapedType::isDynamic(strides[index + 1]))
1393       return std::nullopt;
1394     if (strides[index] != strides[index + 1] * sizes[index + 1])
1395       return std::nullopt;
1396   }
1397   return strides;
1398 }
1399 
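/// Conversion pattern for vector.type_cast, which reinterprets a statically
/// shaped memref of scalars as a memref of a single vector, e.g. (types are
/// illustrative):
/// ```
///  %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// The lowering rebuilds the target memref descriptor from the source
/// descriptor, with offset 0 and static sizes and strides.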
1400 class VectorTypeCastOpConversion
1401     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1402 public:
1403   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1404 
1405   LogicalResult
1406   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1407                   ConversionPatternRewriter &rewriter) const override {
1408     auto loc = castOp->getLoc();
1409     MemRefType sourceMemRefType =
1410         cast<MemRefType>(castOp.getOperand().getType());
1411     MemRefType targetMemRefType = castOp.getType();
1412 
1413     // Only static shape casts supported atm.
1414     if (!sourceMemRefType.hasStaticShape() ||
1415         !targetMemRefType.hasStaticShape())
1416       return failure();
1417 
1418     auto llvmSourceDescriptorTy =
1419         dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1420     if (!llvmSourceDescriptorTy)
1421       return failure();
1422     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1423 
1424     auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1425         typeConverter->convertType(targetMemRefType));
1426     if (!llvmTargetDescriptorTy)
1427       return failure();
1428 
1429     // Only contiguous source buffers supported atm.
1430     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1431     if (!sourceStrides)
1432       return failure();
1433     auto targetStrides = computeContiguousStrides(targetMemRefType);
1434     if (!targetStrides)
1435       return failure();
1436     // Only support static strides for now, regardless of contiguity.
1437     if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1438       return failure();
1439 
1440     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1441 
1442     // Create descriptor.
1443     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1444     Type llvmTargetElementTy = desc.getElementPtrType();
1445     // Set allocated ptr.
1446     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1447     if (!getTypeConverter()->useOpaquePointers())
1448       allocated =
1449           rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
1450     desc.setAllocatedPtr(rewriter, loc, allocated);
1451 
1452     // Set aligned ptr.
1453     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1454     if (!getTypeConverter()->useOpaquePointers())
1455       ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
1456 
1457     desc.setAlignedPtr(rewriter, loc, ptr);
1458     // Fill offset 0.
1459     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1460     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1461     desc.setOffset(rewriter, loc, zero);
1462 
1463     // Fill size and stride descriptors in memref.
1464     for (const auto &indexedSize :
1465          llvm::enumerate(targetMemRefType.getShape())) {
1466       int64_t index = indexedSize.index();
1467       auto sizeAttr =
1468           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1469       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1470       desc.setSize(rewriter, loc, index, size);
1471       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1472                                                 (*targetStrides)[index]);
1473       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1474       desc.setStride(rewriter, loc, index, stride);
1475     }
1476 
1477     rewriter.replaceOp(castOp, {desc});
1478     return success();
1479   }
1480 };
1481 
1482 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1483 /// Non-scalable versions of this operation are handled in Vector Transforms.
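/// A rough sketch of the rewrite (operand names are hypothetical, and the
/// exact op spellings may differ):
/// ```
///  %m = vector.create_mask %n : vector<[4]xi1>
/// ```
/// becomes, approximately:
/// ```
///  %step   = llvm.intr.experimental.stepvector : vector<[4]xi32>
///  %bound  = arith.index_cast %n : index to i32
///  %bounds = vector.splat %bound : vector<[4]xi32>
///  %m      = arith.cmpi slt, %step, %bounds : vector<[4]xi32>
/// ```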
1484 class VectorCreateMaskOpRewritePattern
1485     : public OpRewritePattern<vector::CreateMaskOp> {
1486 public:
1487   explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1488                                             bool enableIndexOpt)
1489       : OpRewritePattern<vector::CreateMaskOp>(context),
1490         force32BitVectorIndices(enableIndexOpt) {}
1491 
1492   LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1493                                 PatternRewriter &rewriter) const override {
1494     auto dstType = op.getType();
1495     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1496       return failure();
1497     IntegerType idxType =
1498         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1499     auto loc = op->getLoc();
1500     Value indices = rewriter.create<LLVM::StepVectorOp>(
1501         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1502                                  /*isScalable=*/true));
1503     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1504                                                  op.getOperand(0));
1505     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1506     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1507                                                 indices, bounds);
1508     rewriter.replaceOp(op, comp);
1509     return success();
1510   }
1511 
1512 private:
1513   const bool force32BitVectorIndices;
1514 };
1515 
1516 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1517 public:
1518   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1519 
1520   // Lowering implementation that relies on a small runtime support library,
1521   // which only needs to provide a few printing methods (single value for all
1522   // data types, opening/closing bracket, comma, newline). The lowering splits
1523   // the vector into elementary printing operations. The advantage of this
1524   // approach is that the library can remain unaware of all low-level
1525   // implementation details of vectors while still supporting output of any
1526   // shaped and dimensioned vector.
1527   //
1528   // Note: This lowering only handles scalars; n-D vectors are broken down into
1529   // scalar prints by loops introduced in VectorToSCF.
1530   //
1531   // TODO: rely solely on libc in future? something else?
1532   //
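  // For instance, `vector.print %x : f32` (with default punctuation) is
  // lowered to something like `llvm.call @printF32(%x) : (f32) -> ()`
  // followed by a newline call; the exact callee names depend on the
  // runtime support library.
  //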
1533   LogicalResult
1534   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1535                   ConversionPatternRewriter &rewriter) const override {
1536     auto parent = printOp->getParentOfType<ModuleOp>();
1537     if (!parent)
1538       return failure();
1539 
1540     auto loc = printOp->getLoc();
1541 
1542     if (auto value = adaptor.getSource()) {
1543       Type printType = printOp.getPrintType();
1544       if (isa<VectorType>(printType)) {
1545         // Vectors should be broken into elementary print ops in VectorToSCF.
1546         return failure();
1547       }
1548       if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1549         return failure();
1550     }
1551 
1552     auto punct = printOp.getPunctuation();
1553     if (punct != PrintPunctuation::NoPunctuation) {
1554       emitCall(rewriter, printOp->getLoc(), [&] {
1555         switch (punct) {
1556         case PrintPunctuation::Close:
1557           return LLVM::lookupOrCreatePrintCloseFn(parent);
1558         case PrintPunctuation::Open:
1559           return LLVM::lookupOrCreatePrintOpenFn(parent);
1560         case PrintPunctuation::Comma:
1561           return LLVM::lookupOrCreatePrintCommaFn(parent);
1562         case PrintPunctuation::NewLine:
1563           return LLVM::lookupOrCreatePrintNewlineFn(parent);
1564         default:
1565           llvm_unreachable("unexpected punctuation");
1566         }
1567       }());
1568     }
1569 
1570     rewriter.eraseOp(printOp);
1571     return success();
1572   }
1573 
1574 private:
1575   enum class PrintConversion {
1576     // clang-format off
1577     None,
1578     ZeroExt64,
1579     SignExt64,
1580     Bitcast16
1581     // clang-format on
1582   };
1583 
1584   LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1585                                 ModuleOp parent, Location loc, Type printType,
1586                                 Value value) const {
1587     if (typeConverter->convertType(printType) == nullptr)
1588       return failure();
1589 
1590     // Make sure element type has runtime support.
1591     PrintConversion conversion = PrintConversion::None;
1592     Operation *printer;
1593     if (printType.isF32()) {
1594       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1595     } else if (printType.isF64()) {
1596       printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1597     } else if (printType.isF16()) {
1598       conversion = PrintConversion::Bitcast16; // bits!
1599       printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1600     } else if (printType.isBF16()) {
1601       conversion = PrintConversion::Bitcast16; // bits!
1602       printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1603     } else if (printType.isIndex()) {
1604       printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1605     } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1606       // Integers need a zero or sign extension on the operand
1607       // (depending on the source type) as well as a signed or
1608       // unsigned print method. Up to 64-bit is supported.
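      // For example, a signless i32 value is sign-extended to i64 and printed
      // via the 64-bit signed printer, while a ui16 value is zero-extended and
      // printed via the 64-bit unsigned printer.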
1609       unsigned width = intTy.getWidth();
1610       if (intTy.isUnsigned()) {
1611         if (width <= 64) {
1612           if (width < 64)
1613             conversion = PrintConversion::ZeroExt64;
1614           printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1615         } else {
1616           return failure();
1617         }
1618       } else {
1619         assert(intTy.isSignless() || intTy.isSigned());
1620         if (width <= 64) {
1621           // Note that we *always* zero extend booleans (1-bit integers),
1622           // so that true/false is printed as 1/0 rather than -1/0.
1623           if (width == 1)
1624             conversion = PrintConversion::ZeroExt64;
1625           else if (width < 64)
1626             conversion = PrintConversion::SignExt64;
1627           printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1628         } else {
1629           return failure();
1630         }
1631       }
1632     } else {
1633       return failure();
1634     }
1635 
1636     switch (conversion) {
1637     case PrintConversion::ZeroExt64:
1638       value = rewriter.create<arith::ExtUIOp>(
1639           loc, IntegerType::get(rewriter.getContext(), 64), value);
1640       break;
1641     case PrintConversion::SignExt64:
1642       value = rewriter.create<arith::ExtSIOp>(
1643           loc, IntegerType::get(rewriter.getContext(), 64), value);
1644       break;
1645     case PrintConversion::Bitcast16:
1646       value = rewriter.create<LLVM::BitcastOp>(
1647           loc, IntegerType::get(rewriter.getContext(), 16), value);
1648       break;
1649     case PrintConversion::None:
1650       break;
1651     }
1652     emitCall(rewriter, loc, printer, value);
1653     return success();
1654   }
1655 
1656   // Helper to emit a call.
1657   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1658                        Operation *ref, ValueRange params = ValueRange()) {
1659     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1660                                   params);
1661   }
1662 };
1663 
1664 /// The Splat operation is lowered to an insertelement + a shufflevector
1665 /// operation. Only splats with 0-D and 1-D vector result types are lowered.
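/// A rough sketch of the expected 1-D lowering (operand names are
/// hypothetical):
/// ```
///  %v = vector.splat %f : vector<4xf32>
/// ```
/// becomes, approximately:
/// ```
///  %u  = llvm.mlir.undef : vector<4xf32>
///  %c0 = llvm.mlir.constant(0 : i32) : i32
///  %e  = llvm.insertelement %f, %u[%c0 : i32] : vector<4xf32>
///  %v  = llvm.shufflevector %e, %u [0, 0, 0, 0] : vector<4xf32>
/// ```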
1666 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1667   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1668 
1669   LogicalResult
1670   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1671                   ConversionPatternRewriter &rewriter) const override {
1672     VectorType resultType = cast<VectorType>(splatOp.getType());
1673     if (resultType.getRank() > 1)
1674       return failure();
1675 
1676     // First insert it into an undef vector so we can shuffle it.
1677     auto vectorType = typeConverter->convertType(splatOp.getType());
1678     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1679     auto zero = rewriter.create<LLVM::ConstantOp>(
1680         splatOp.getLoc(),
1681         typeConverter->convertType(rewriter.getIntegerType(32)),
1682         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1683 
1684     // For 0-d vector, we simply do `insertelement`.
1685     if (resultType.getRank() == 0) {
1686       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1687           splatOp, vectorType, undef, adaptor.getInput(), zero);
1688       return success();
1689     }
1690 
1691     // For a 1-d vector, we additionally do a `shufflevector`.
1692     auto v = rewriter.create<LLVM::InsertElementOp>(
1693         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1694 
1695     int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1696     SmallVector<int32_t> zeroValues(width, 0);
1697 
1698     // Shuffle the value across the desired number of elements.
1699     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1700                                                        zeroValues);
1701     return success();
1702   }
1703 };
1704 
1705 /// The Splat operation is lowered to an insertelement + a shufflevector
1706 /// operation. Only splats with vector result types of rank >= 2 are lowered
1707 /// here; the 0-D and 1-D cases are handled by VectorSplatOpLowering.
1708 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1709   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1710 
1711   LogicalResult
1712   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1713                   ConversionPatternRewriter &rewriter) const override {
1714     VectorType resultType = splatOp.getType();
1715     if (resultType.getRank() <= 1)
1716       return failure();
1717 
1718     // First insert it into an undef vector so we can shuffle it.
1719     auto loc = splatOp.getLoc();
1720     auto vectorTypeInfo =
1721         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1722     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1723     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1724     if (!llvmNDVectorTy || !llvm1DVectorTy)
1725       return failure();
1726 
1727     // Construct returned value.
1728     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1729 
1730     // Construct a 1-D vector with the splatted value that we insert in all the
1731     // places within the returned descriptor.
1732     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1733     auto zero = rewriter.create<LLVM::ConstantOp>(
1734         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1735         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1736     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1737                                                      adaptor.getInput(), zero);
1738 
1739     // Shuffle the value across the desired number of elements.
1740     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1741     SmallVector<int32_t> zeroValues(width, 0);
1742     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1743 
1744     // Iterate over the linear index, convert it to coordinates, and insert the
1745     // splatted 1-D vector at each position.
1746     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1747       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1748     });
1749     rewriter.replaceOp(splatOp, desc);
1750     return success();
1751   }
1752 };
1753 
1754 } // namespace
1755 
1756 /// Populate the given list with patterns that convert from Vector to LLVM.
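/// A minimal usage sketch within a conversion pass (assuming a standard
/// dialect-conversion setup; the surrounding pass boilerplate is elided):
/// ```
///   LLVMTypeConverter converter(&getContext());
///   RewritePatternSet patterns(&getContext());
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(getContext());
///   if (failed(applyPartialConversion(getOperation(), target,
///                                     std::move(patterns))))
///     signalPassFailure();
/// ```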
1757 void mlir::populateVectorToLLVMConversionPatterns(
1758     LLVMTypeConverter &converter, RewritePatternSet &patterns,
1759     bool reassociateFPReductions, bool force32BitVectorIndices) {
1760   MLIRContext *ctx = converter.getDialect()->getContext();
1761   patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1762   populateVectorInsertExtractStridedSliceTransforms(patterns);
1763   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1764   patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1765   patterns
1766       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1767            VectorExtractElementOpConversion, VectorExtractOpConversion,
1768            VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1769            VectorInsertOpConversion, VectorPrintOpConversion,
1770            VectorTypeCastOpConversion, VectorScaleOpConversion,
1771            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1772            VectorLoadStoreConversion<vector::MaskedLoadOp,
1773                                      vector::MaskedLoadOpAdaptor>,
1774            VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1775            VectorLoadStoreConversion<vector::MaskedStoreOp,
1776                                      vector::MaskedStoreOpAdaptor>,
1777            VectorGatherOpConversion, VectorScatterOpConversion,
1778            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1779            VectorSplatOpLowering, VectorSplatNdOpLowering,
1780            VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1781            MaskedReductionOpConversion>(converter);
1782   // Transfer ops with rank > 1 are handled by VectorToSCF.
1783   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1784 }
1785 
1786 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1787     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1788   patterns.add<VectorMatmulOpConversion>(converter);
1789   patterns.add<VectorFlatTransposeOpConversion>(converter);
1790 }
1791