xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Utils/Utils.h"
17 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
23 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24 #include "mlir/IR/BuiltinAttributes.h"
25 #include "mlir/IR/BuiltinTypeInterfaces.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39   assert((tp.getRank() > 1) && "unlowerable vector type");
40   return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41                          tp.getScalableDims().take_back());
42 }
43 
44 // Helper that picks the proper sequence for inserting.
45 static Value insertOne(ConversionPatternRewriter &rewriter,
46                        const LLVMTypeConverter &typeConverter, Location loc,
47                        Value val1, Value val2, Type llvmType, int64_t rank,
48                        int64_t pos) {
49   assert(rank > 0 && "0-D vector corner case should have been handled already");
50   if (rank == 1) {
51     auto idxType = rewriter.getIndexType();
52     auto constant = rewriter.create<LLVM::ConstantOp>(
53         loc, typeConverter.convertType(idxType),
54         rewriter.getIntegerAttr(idxType, pos));
55     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56                                                   constant);
57   }
58   return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59 }
60 
61 // Helper that picks the proper sequence for extracting.
62 static Value extractOne(ConversionPatternRewriter &rewriter,
63                         const LLVMTypeConverter &typeConverter, Location loc,
64                         Value val, Type llvmType, int64_t rank, int64_t pos) {
65   if (rank <= 1) {
66     auto idxType = rewriter.getIndexType();
67     auto constant = rewriter.create<LLVM::ConstantOp>(
68         loc, typeConverter.convertType(idxType),
69         rewriter.getIntegerAttr(idxType, pos));
70     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71                                                    constant);
72   }
73   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74 }
75 
76 // Helper that returns data layout alignment of a memref.
77 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
78                                  MemRefType memrefType, unsigned &align) {
79   Type elementTy = typeConverter.convertType(memrefType.getElementType());
80   if (!elementTy)
81     return failure();
82 
83   // TODO: this should use the MLIR data layout when it becomes available and
84   // stop depending on translation.
85   llvm::LLVMContext llvmContext;
86   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88   return success();
89 }
90 
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93                                            const LLVMTypeConverter &converter) {
94   if (!memRefType.isLastDimUnitStride())
95     return failure();
96   if (failed(converter.getMemRefAddressSpace(memRefType)))
97     return failure();
98   return success();
99 }
100 
101 // Add an index vector component to a base pointer.
102 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
103                             const LLVMTypeConverter &typeConverter,
104                             MemRefType memRefType, Value llvmMemref, Value base,
105                             Value index, VectorType vectorType) {
106   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107          "unsupported memref type");
108   assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
109   auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
110   auto ptrsType =
111       LLVM::getVectorType(pType, vectorType.getDimSize(0),
112                           /*isScalable=*/vectorType.getScalableDims()[0]);
113   return rewriter.create<LLVM::GEPOp>(
114       loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
115       base, index);
116 }
117 
118 /// Convert `foldResult` into a Value. Integer attribute is converted to
119 /// an LLVM constant op.
120 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
121                             OpFoldResult foldResult) {
122   if (auto attr = foldResult.dyn_cast<Attribute>()) {
123     auto intAttr = cast<IntegerAttr>(attr);
124     return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
125   }
126 
127   return cast<Value>(foldResult);
128 }
129 
130 namespace {
131 
132 /// Trivial Vector to LLVM conversions
133 using VectorScaleOpConversion =
134     OneToOneConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
135 
136 /// Conversion pattern for a vector.bitcast.
137 class VectorBitCastOpConversion
138     : public ConvertOpToLLVMPattern<vector::BitCastOp> {
139 public:
140   using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
141 
142   LogicalResult
143   matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
144                   ConversionPatternRewriter &rewriter) const override {
145     // Only 0-D and 1-D vectors can be lowered to LLVM.
146     VectorType resultTy = bitCastOp.getResultVectorType();
147     if (resultTy.getRank() > 1)
148       return failure();
149     Type newResultTy = typeConverter->convertType(resultTy);
150     rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
151                                                  adaptor.getOperands()[0]);
152     return success();
153   }
154 };
155 
156 /// Conversion pattern for a vector.matrix_multiply.
157 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
158 class VectorMatmulOpConversion
159     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
160 public:
161   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
162 
163   LogicalResult
164   matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
165                   ConversionPatternRewriter &rewriter) const override {
166     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
167         matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168         adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169         matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
170     return success();
171   }
172 };
173 
174 /// Conversion pattern for a vector.flat_transpose.
175 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
176 class VectorFlatTransposeOpConversion
177     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
178 public:
179   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
180 
181   LogicalResult
182   matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
183                   ConversionPatternRewriter &rewriter) const override {
184     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
185         transOp, typeConverter->convertType(transOp.getRes().getType()),
186         adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
187     return success();
188   }
189 };
190 
191 /// Overloaded utility that replaces a vector.load, vector.store,
192 /// vector.maskedload and vector.maskedstore with their respective LLVM
193 /// couterparts.
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195                                  vector::LoadOpAdaptor adaptor,
196                                  VectorType vectorTy, Value ptr, unsigned align,
197                                  ConversionPatternRewriter &rewriter) {
198   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
199                                             /*volatile_=*/false,
200                                             loadOp.getNontemporal());
201 }
202 
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204                                  vector::MaskedLoadOpAdaptor adaptor,
205                                  VectorType vectorTy, Value ptr, unsigned align,
206                                  ConversionPatternRewriter &rewriter) {
207   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
208       loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209 }
210 
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212                                  vector::StoreOpAdaptor adaptor,
213                                  VectorType vectorTy, Value ptr, unsigned align,
214                                  ConversionPatternRewriter &rewriter) {
215   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
216                                              ptr, align, /*volatile_=*/false,
217                                              storeOp.getNontemporal());
218 }
219 
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221                                  vector::MaskedStoreOpAdaptor adaptor,
222                                  VectorType vectorTy, Value ptr, unsigned align,
223                                  ConversionPatternRewriter &rewriter) {
224   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
225       storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
226 }
227 
228 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
229 /// vector.maskedstore.
230 template <class LoadOrStoreOp>
231 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
232 public:
233   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
234 
235   LogicalResult
236   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237                   typename LoadOrStoreOp::Adaptor adaptor,
238                   ConversionPatternRewriter &rewriter) const override {
239     // Only 1-D vectors can be lowered to LLVM.
240     VectorType vectorTy = loadOrStoreOp.getVectorType();
241     if (vectorTy.getRank() > 1)
242       return failure();
243 
244     auto loc = loadOrStoreOp->getLoc();
245     MemRefType memRefTy = loadOrStoreOp.getMemRefType();
246 
247     // Resolve alignment.
248     unsigned align;
249     if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
250       return failure();
251 
252     // Resolve address.
253     auto vtype = cast<VectorType>(
254         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
255     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
256                                                adaptor.getIndices(), rewriter);
257     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
258                          rewriter);
259     return success();
260   }
261 };
262 
/// Conversion pattern for a vector.gather.
///
/// The scalar base address is resolved from the memref descriptor and the
/// indices, then offset lane-wise by the index vector (see getIndexedPtrs).
/// If the converted index vector is not an LLVM array (the 1-D case), the op
/// becomes a single `LLVM::masked_gather`. Otherwise (the n-D case, an LLVM
/// array) the work is split into 1-D gathers by
/// LLVM::detail::handleMultidimensionalVectors.
/// The pattern fails for memrefs rejected by isMemRefTypeSupported and for
/// element types whose alignment cannot be computed.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only memref bases are handled here. This is only checked in
    // assert-enabled builds.
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return failure();

    auto loc = gather->getLoc();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Scalar address of the first element addressed by the indices; the
    // per-lane offsets from the index vector are applied on top of it.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    // Converted memref descriptor; getIndexedPtrs only reads its element
    // pointer type.
    Value base = adaptor.getBase();

    auto llvmNDVectorTy = adaptor.getIndexVec().getType();
    // Handle the simple case of 1-D vector.
    if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
      auto vType = gather.getVectorType();
      // Resolve address.
      Value ptrs =
          getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                         base, ptr, adaptor.getIndexVec(), vType);
      // Replace with the gather intrinsic.
      rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
          gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
          adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
      return success();
    }

    // n-D case: the callback receives one 1-D slice at a time. Its operands
    // are ordered as in `vectorOperands` below: index vector, mask,
    // pass-through.
    // NOTE: this local intentionally shadows the `typeConverter` member so
    // the lambda can capture the LLVM-specific converter by reference.
    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
    auto callback = [align, memRefType, base, ptr, loc, &rewriter,
                     &typeConverter](Type llvm1DVectorTy,
                                     ValueRange vectorOperands) {
      // Resolve address.
      Value ptrs = getIndexedPtrs(
          rewriter, loc, typeConverter, memRefType, base, ptr,
          /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
      // Create the gather intrinsic.
      return rewriter.create<LLVM::masked_gather>(
          loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
          /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
    };
    SmallVector<Value> vectorOperands = {
        adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
    return LLVM::detail::handleMultidimensionalVectors(
        gather, vectorOperands, *getTypeConverter(), callback, rewriter);
  }
};
323 
324 /// Conversion pattern for a vector.scatter.
325 class VectorScatterOpConversion
326     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
327 public:
328   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
329 
330   LogicalResult
331   matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
332                   ConversionPatternRewriter &rewriter) const override {
333     auto loc = scatter->getLoc();
334     MemRefType memRefType = scatter.getMemRefType();
335 
336     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
337       return failure();
338 
339     // Resolve alignment.
340     unsigned align;
341     if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
342       return failure();
343 
344     // Resolve address.
345     VectorType vType = scatter.getVectorType();
346     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
347                                      adaptor.getIndices(), rewriter);
348     Value ptrs =
349         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350                        adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
351 
352     // Replace with the scatter intrinsic.
353     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
354         scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
355         rewriter.getI32IntegerAttr(align));
356     return success();
357   }
358 };
359 
360 /// Conversion pattern for a vector.expandload.
361 class VectorExpandLoadOpConversion
362     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
363 public:
364   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
365 
366   LogicalResult
367   matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
368                   ConversionPatternRewriter &rewriter) const override {
369     auto loc = expand->getLoc();
370     MemRefType memRefType = expand.getMemRefType();
371 
372     // Resolve address.
373     auto vtype = typeConverter->convertType(expand.getVectorType());
374     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375                                      adaptor.getIndices(), rewriter);
376 
377     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
378         expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
379     return success();
380   }
381 };
382 
383 /// Conversion pattern for a vector.compressstore.
384 class VectorCompressStoreOpConversion
385     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
386 public:
387   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
388 
389   LogicalResult
390   matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
391                   ConversionPatternRewriter &rewriter) const override {
392     auto loc = compress->getLoc();
393     MemRefType memRefType = compress.getMemRefType();
394 
395     // Resolve address.
396     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397                                      adaptor.getIndices(), rewriter);
398 
399     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
400         compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
401     return success();
402   }
403 };
404 
/// Reduction neutral classes for overloading.
///
/// Empty tag types. Each one selects the createReductionNeutralValue overload
/// that materializes the corresponding constant.
class ReductionNeutralZero {};    // zero of the type
class ReductionNeutralIntOne {};  // integer 1
class ReductionNeutralFPOne {};   // floating-point 1.0
class ReductionNeutralAllOnes {}; // integer with all bits set
class ReductionNeutralSIntMin {}; // signed integer minimum
class ReductionNeutralUIntMin {}; // unsigned integer minimum (0)
class ReductionNeutralSIntMax {}; // signed integer maximum
class ReductionNeutralUIntMax {}; // unsigned integer maximum
class ReductionNeutralFPMin {};   // positive quiet NaN
class ReductionNeutralFPMax {};   // negative quiet NaN
416 
417 /// Create the reduction neutral zero value.
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
419                                          ConversionPatternRewriter &rewriter,
420                                          Location loc, Type llvmType) {
421   return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
422                                            rewriter.getZeroAttr(llvmType));
423 }
424 
425 /// Create the reduction neutral integer one value.
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
427                                          ConversionPatternRewriter &rewriter,
428                                          Location loc, Type llvmType) {
429   return rewriter.create<LLVM::ConstantOp>(
430       loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
431 }
432 
433 /// Create the reduction neutral fp one value.
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
435                                          ConversionPatternRewriter &rewriter,
436                                          Location loc, Type llvmType) {
437   return rewriter.create<LLVM::ConstantOp>(
438       loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
439 }
440 
441 /// Create the reduction neutral all-ones value.
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
443                                          ConversionPatternRewriter &rewriter,
444                                          Location loc, Type llvmType) {
445   return rewriter.create<LLVM::ConstantOp>(
446       loc, llvmType,
447       rewriter.getIntegerAttr(
448           llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
449 }
450 
451 /// Create the reduction neutral signed int minimum value.
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
453                                          ConversionPatternRewriter &rewriter,
454                                          Location loc, Type llvmType) {
455   return rewriter.create<LLVM::ConstantOp>(
456       loc, llvmType,
457       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
458                                             llvmType.getIntOrFloatBitWidth())));
459 }
460 
461 /// Create the reduction neutral unsigned int minimum value.
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
463                                          ConversionPatternRewriter &rewriter,
464                                          Location loc, Type llvmType) {
465   return rewriter.create<LLVM::ConstantOp>(
466       loc, llvmType,
467       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
468                                             llvmType.getIntOrFloatBitWidth())));
469 }
470 
471 /// Create the reduction neutral signed int maximum value.
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
473                                          ConversionPatternRewriter &rewriter,
474                                          Location loc, Type llvmType) {
475   return rewriter.create<LLVM::ConstantOp>(
476       loc, llvmType,
477       rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
478                                             llvmType.getIntOrFloatBitWidth())));
479 }
480 
481 /// Create the reduction neutral unsigned int maximum value.
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
483                                          ConversionPatternRewriter &rewriter,
484                                          Location loc, Type llvmType) {
485   return rewriter.create<LLVM::ConstantOp>(
486       loc, llvmType,
487       rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
488                                             llvmType.getIntOrFloatBitWidth())));
489 }
490 
491 /// Create the reduction neutral fp minimum value.
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
493                                          ConversionPatternRewriter &rewriter,
494                                          Location loc, Type llvmType) {
495   auto floatType = cast<FloatType>(llvmType);
496   return rewriter.create<LLVM::ConstantOp>(
497       loc, llvmType,
498       rewriter.getFloatAttr(
499           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
500                                            /*Negative=*/false)));
501 }
502 
503 /// Create the reduction neutral fp maximum value.
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
505                                          ConversionPatternRewriter &rewriter,
506                                          Location loc, Type llvmType) {
507   auto floatType = cast<FloatType>(llvmType);
508   return rewriter.create<LLVM::ConstantOp>(
509       loc, llvmType,
510       rewriter.getFloatAttr(
511           llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
512                                            /*Negative=*/true)));
513 }
514 
515 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
516 /// returns a new accumulator value using `ReductionNeutral`.
517 template <class ReductionNeutral>
518 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
519                                     Location loc, Type llvmType,
520                                     Value accumulator) {
521   if (accumulator)
522     return accumulator;
523 
524   return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
525                                      llvmType);
526 }
527 
528 /// Creates a value with the 1-D vector shape provided in `llvmType`.
529 /// This is used as effective vector length by some intrinsics supporting
530 /// dynamic vector lengths at runtime.
531 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
532                                      Location loc, Type llvmType) {
533   VectorType vType = cast<VectorType>(llvmType);
534   auto vShape = vType.getShape();
535   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
536 
537   Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
538       loc, rewriter.getI32Type(),
539       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
540 
541   if (!vType.getScalableDims()[0])
542     return baseVecLength;
543 
544   // For a scalable vector type, create and return `vScale * baseVecLength`.
545   Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
546   vScale =
547       rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
548   Value scalableVecLength =
549       rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
550   return scalableVecLength;
551 }
552 
553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
554 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
555 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
556 /// non-null.
557 template <class LLVMRedIntrinOp, class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
559     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
560     Value vectorOperand, Value accumulator) {
561 
562   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
563 
564   if (accumulator)
565     result = rewriter.create<ScalarOp>(loc, accumulator, result);
566   return result;
567 }
568 
569 /// Helper method to lower a `vector.reduction` operation that performs
570 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
571 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
572 /// the accumulator value if non-null.
573 template <class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
575     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
576     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
577   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
578   if (accumulator) {
579     Value cmp =
580         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581     result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
582   }
583   return result;
584 }
585 
586 namespace {
587 template <typename Source>
588 struct VectorToScalarMapper;
589 template <>
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591   using Type = LLVM::MaximumOp;
592 };
593 template <>
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595   using Type = LLVM::MinimumOp;
596 };
597 template <>
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599   using Type = LLVM::MaxNumOp;
600 };
601 template <>
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603   using Type = LLVM::MinNumOp;
604 };
605 } // namespace
606 
607 template <class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
609     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
610     Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
611   Value result =
612       rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
613 
614   if (accumulator) {
615     result =
616         rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617             loc, result, accumulator);
618   }
619 
620   return result;
621 }
622 
/// Mask neutral tag classes for overloading. Each one selects the
/// getMaskNeutralValue overload that provides the value substituted into
/// masked-off lanes of an `fmaximum` / `fminimum` reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
626 
627 /// Get the mask neutral floating point maximum value
628 static llvm::APFloat
629 getMaskNeutralValue(MaskNeutralFMaximum,
630                     const llvm::fltSemantics &floatSemantics) {
631   return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
632 }
633 /// Get the mask neutral floating point minimum value
634 static llvm::APFloat
635 getMaskNeutralValue(MaskNeutralFMinimum,
636                     const llvm::fltSemantics &floatSemantics) {
637   return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
638 }
639 
640 /// Create the mask neutral floating point MLIR vector constant
641 template <typename MaskNeutral>
642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
643                                     Location loc, Type llvmType,
644                                     Type vectorType) {
645   const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646   auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
647   auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
648   return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
649 }
650 
651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
653 /// `fmaximum`/`fminimum`.
654 /// More information: https://github.com/llvm/llvm-project/issues/64940
655 template <class LLVMRedIntrinOp, class MaskNeutral>
656 static Value
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
658                                 Location loc, Type llvmType,
659                                 Value vectorOperand, Value accumulator,
660                                 Value mask, LLVM::FastmathFlagsAttr fmf) {
661   const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662       rewriter, loc, llvmType, vectorOperand.getType());
663   const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
664       loc, mask, vectorOperand, vectorMaskNeutral);
665   return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666       rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
667 }
668 
669 template <class LLVMRedIntrinOp, class ReductionNeutral>
670 static Value
671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
672                              Type llvmType, Value vectorOperand,
673                              Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675                                                          llvmType, accumulator);
676   return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
677                                           /*startValue=*/accumulator,
678                                           vectorOperand, fmf);
679 }
680 
681 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
682 /// that requires a start value. This start value format spans across fp
683 /// reductions without mask and all the masked reduction intrinsics.
684 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
685 static Value
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
687                                        Location loc, Type llvmType,
688                                        Value vectorOperand, Value accumulator) {
689   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690                                                          llvmType, accumulator);
691   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
692                                             /*startValue=*/accumulator,
693                                             vectorOperand);
694 }
695 
696 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
698     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
699     Value vectorOperand, Value accumulator, Value mask) {
700   accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701                                                          llvmType, accumulator);
702   Value vectorLength =
703       createVectorLengthValue(rewriter, loc, vectorOperand.getType());
704   return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
705                                             /*startValue=*/accumulator,
706                                             vectorOperand, mask, vectorLength);
707 }
708 
709 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
710           class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
712     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
713     Value vectorOperand, Value accumulator, Value mask) {
714   if (llvmType.isIntOrIndex())
715     return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716                                                   IntReductionNeutral>(
717         rewriter, loc, llvmType, vectorOperand, accumulator, mask);
718 
719   // FP dispatch.
720   return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
721                                                 FPReductionNeutral>(
722       rewriter, loc, llvmType, vectorOperand, accumulator, mask);
723 }
724 
725 /// Conversion pattern for all vector reductions.
726 class VectorReductionOpConversion
727     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
728 public:
729   explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
730                                        bool reassociateFPRed)
731       : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
732         reassociateFPReductions(reassociateFPRed) {}
733 
734   LogicalResult
735   matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
736                   ConversionPatternRewriter &rewriter) const override {
737     auto kind = reductionOp.getKind();
738     Type eltType = reductionOp.getDest().getType();
739     Type llvmType = typeConverter->convertType(eltType);
740     Value operand = adaptor.getVector();
741     Value acc = adaptor.getAcc();
742     Location loc = reductionOp.getLoc();
743 
744     if (eltType.isIntOrIndex()) {
745       // Integer reductions: add/mul/min/max/and/or/xor.
746       Value result;
747       switch (kind) {
748       case vector::CombiningKind::ADD:
749         result =
750             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
751                                                        LLVM::AddOp>(
752                 rewriter, loc, llvmType, operand, acc);
753         break;
754       case vector::CombiningKind::MUL:
755         result =
756             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
757                                                        LLVM::MulOp>(
758                 rewriter, loc, llvmType, operand, acc);
759         break;
760       case vector::CombiningKind::MINUI:
761         result = createIntegerReductionComparisonOpLowering<
762             LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763                                       LLVM::ICmpPredicate::ule);
764         break;
765       case vector::CombiningKind::MINSI:
766         result = createIntegerReductionComparisonOpLowering<
767             LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768                                       LLVM::ICmpPredicate::sle);
769         break;
770       case vector::CombiningKind::MAXUI:
771         result = createIntegerReductionComparisonOpLowering<
772             LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773                                       LLVM::ICmpPredicate::uge);
774         break;
775       case vector::CombiningKind::MAXSI:
776         result = createIntegerReductionComparisonOpLowering<
777             LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778                                       LLVM::ICmpPredicate::sge);
779         break;
780       case vector::CombiningKind::AND:
781         result =
782             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
783                                                        LLVM::AndOp>(
784                 rewriter, loc, llvmType, operand, acc);
785         break;
786       case vector::CombiningKind::OR:
787         result =
788             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
789                                                        LLVM::OrOp>(
790                 rewriter, loc, llvmType, operand, acc);
791         break;
792       case vector::CombiningKind::XOR:
793         result =
794             createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
795                                                        LLVM::XOrOp>(
796                 rewriter, loc, llvmType, operand, acc);
797         break;
798       default:
799         return failure();
800       }
801       rewriter.replaceOp(reductionOp, result);
802 
803       return success();
804     }
805 
806     if (!isa<FloatType>(eltType))
807       return failure();
808 
809     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
810     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
811         reductionOp.getContext(),
812         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
813     fmf = LLVM::FastmathFlagsAttr::get(
814         reductionOp.getContext(),
815         fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816                                                   : LLVM::FastmathFlags::none));
817 
818     // Floating-point reductions: add/mul/min/max
819     Value result;
820     if (kind == vector::CombiningKind::ADD) {
821       result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822                                             ReductionNeutralZero>(
823           rewriter, loc, llvmType, operand, acc, fmf);
824     } else if (kind == vector::CombiningKind::MUL) {
825       result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826                                             ReductionNeutralFPOne>(
827           rewriter, loc, llvmType, operand, acc, fmf);
828     } else if (kind == vector::CombiningKind::MINIMUMF) {
829       result =
830           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831               rewriter, loc, llvmType, operand, acc, fmf);
832     } else if (kind == vector::CombiningKind::MAXIMUMF) {
833       result =
834           createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835               rewriter, loc, llvmType, operand, acc, fmf);
836     } else if (kind == vector::CombiningKind::MINNUMF) {
837       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838           rewriter, loc, llvmType, operand, acc, fmf);
839     } else if (kind == vector::CombiningKind::MAXNUMF) {
840       result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841           rewriter, loc, llvmType, operand, acc, fmf);
842     } else
843       return failure();
844 
845     rewriter.replaceOp(reductionOp, result);
846     return success();
847   }
848 
849 private:
850   const bool reassociateFPReductions;
851 };
852 
853 /// Base class to convert a `vector.mask` operation while matching traits
854 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
855 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
856 /// method performs a second match against the maskable operation `MaskedOp`.
857 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
858 /// implemented by the concrete conversion classes. This method can match
859 /// against specific traits of the `vector.mask` and the maskable operation. It
860 /// must replace the `vector.mask` operation.
861 template <class MaskedOp>
862 class VectorMaskOpConversionBase
863     : public ConvertOpToLLVMPattern<vector::MaskOp> {
864 public:
865   using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
866 
867   LogicalResult
868   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
869                   ConversionPatternRewriter &rewriter) const final {
870     // Match against the maskable operation kind.
871     auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
872     if (!maskedOp)
873       return failure();
874     return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
875   }
876 
877 protected:
878   virtual LogicalResult
879   matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880                             vector::MaskableOpInterface maskableOp,
881                             ConversionPatternRewriter &rewriter) const = 0;
882 };
883 
884 class MaskedReductionOpConversion
885     : public VectorMaskOpConversionBase<vector::ReductionOp> {
886 
887 public:
888   using VectorMaskOpConversionBase<
889       vector::ReductionOp>::VectorMaskOpConversionBase;
890 
891   LogicalResult matchAndRewriteMaskableOp(
892       vector::MaskOp maskOp, MaskableOpInterface maskableOp,
893       ConversionPatternRewriter &rewriter) const override {
894     auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895     auto kind = reductionOp.getKind();
896     Type eltType = reductionOp.getDest().getType();
897     Type llvmType = typeConverter->convertType(eltType);
898     Value operand = reductionOp.getVector();
899     Value acc = reductionOp.getAcc();
900     Location loc = reductionOp.getLoc();
901 
902     arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
903     LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
904         reductionOp.getContext(),
905         convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
906 
907     Value result;
908     switch (kind) {
909     case vector::CombiningKind::ADD:
910       result = lowerPredicatedReductionWithStartValue<
911           LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912           ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
913                                 maskOp.getMask());
914       break;
915     case vector::CombiningKind::MUL:
916       result = lowerPredicatedReductionWithStartValue<
917           LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918           ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
919                                  maskOp.getMask());
920       break;
921     case vector::CombiningKind::MINUI:
922       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923                                                       ReductionNeutralUIntMax>(
924           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
925       break;
926     case vector::CombiningKind::MINSI:
927       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928                                                       ReductionNeutralSIntMax>(
929           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
930       break;
931     case vector::CombiningKind::MAXUI:
932       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933                                                       ReductionNeutralUIntMin>(
934           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
935       break;
936     case vector::CombiningKind::MAXSI:
937       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938                                                       ReductionNeutralSIntMin>(
939           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
940       break;
941     case vector::CombiningKind::AND:
942       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943                                                       ReductionNeutralAllOnes>(
944           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
945       break;
946     case vector::CombiningKind::OR:
947       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948                                                       ReductionNeutralZero>(
949           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
950       break;
951     case vector::CombiningKind::XOR:
952       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953                                                       ReductionNeutralZero>(
954           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
955       break;
956     case vector::CombiningKind::MINNUMF:
957       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958                                                       ReductionNeutralFPMax>(
959           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
960       break;
961     case vector::CombiningKind::MAXNUMF:
962       result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963                                                       ReductionNeutralFPMin>(
964           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
965       break;
966     case CombiningKind::MAXIMUMF:
967       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968                                                MaskNeutralFMaximum>(
969           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
970       break;
971     case CombiningKind::MINIMUMF:
972       result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973                                                MaskNeutralFMinimum>(
974           rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
975       break;
976     }
977 
978     // Replace `vector.mask` operation altogether.
979     rewriter.replaceOp(maskOp, result);
980     return success();
981   }
982 };
983 
984 class VectorShuffleOpConversion
985     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
986 public:
987   using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
988 
989   LogicalResult
990   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
991                   ConversionPatternRewriter &rewriter) const override {
992     auto loc = shuffleOp->getLoc();
993     auto v1Type = shuffleOp.getV1VectorType();
994     auto v2Type = shuffleOp.getV2VectorType();
995     auto vectorType = shuffleOp.getResultVectorType();
996     Type llvmType = typeConverter->convertType(vectorType);
997     ArrayRef<int64_t> mask = shuffleOp.getMask();
998 
999     // Bail if result type cannot be lowered.
1000     if (!llvmType)
1001       return failure();
1002 
1003     // Get rank and dimension sizes.
1004     int64_t rank = vectorType.getRank();
1005 #ifndef NDEBUG
1006     bool wellFormed0DCase =
1007         v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008     bool wellFormedNDCase =
1009         v1Type.getRank() == rank && v2Type.getRank() == rank;
1010     assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1011 #endif
1012 
1013     // For rank 0 and 1, where both operands have *exactly* the same vector
1014     // type, there is direct shuffle support in LLVM. Use it!
1015     if (rank <= 1 && v1Type == v2Type) {
1016       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1017           loc, adaptor.getV1(), adaptor.getV2(),
1018           llvm::to_vector_of<int32_t>(mask));
1019       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1020       return success();
1021     }
1022 
1023     // For all other cases, insert the individual values individually.
1024     int64_t v1Dim = v1Type.getDimSize(0);
1025     Type eltType;
1026     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027       eltType = arrayType.getElementType();
1028     else
1029       eltType = cast<VectorType>(llvmType).getElementType();
1030     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1031     int64_t insPos = 0;
1032     for (int64_t extPos : mask) {
1033       Value value = adaptor.getV1();
1034       if (extPos >= v1Dim) {
1035         extPos -= v1Dim;
1036         value = adaptor.getV2();
1037       }
1038       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1039                                  eltType, rank, extPos);
1040       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041                          llvmType, rank, insPos++);
1042     }
1043     rewriter.replaceOp(shuffleOp, insert);
1044     return success();
1045   }
1046 };
1047 
1048 class VectorExtractElementOpConversion
1049     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1050 public:
1051   using ConvertOpToLLVMPattern<
1052       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1053 
1054   LogicalResult
1055   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1056                   ConversionPatternRewriter &rewriter) const override {
1057     auto vectorType = extractEltOp.getSourceVectorType();
1058     auto llvmType = typeConverter->convertType(vectorType.getElementType());
1059 
1060     // Bail if result type cannot be lowered.
1061     if (!llvmType)
1062       return failure();
1063 
1064     if (vectorType.getRank() == 0) {
1065       Location loc = extractEltOp.getLoc();
1066       auto idxType = rewriter.getIndexType();
1067       auto zero = rewriter.create<LLVM::ConstantOp>(
1068           loc, typeConverter->convertType(idxType),
1069           rewriter.getIntegerAttr(idxType, 0));
1070       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1071           extractEltOp, llvmType, adaptor.getVector(), zero);
1072       return success();
1073     }
1074 
1075     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1076         extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1077     return success();
1078   }
1079 };
1080 
1081 class VectorExtractOpConversion
1082     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1083 public:
1084   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1085 
1086   LogicalResult
1087   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1088                   ConversionPatternRewriter &rewriter) const override {
1089     auto loc = extractOp->getLoc();
1090     auto resultType = extractOp.getResult().getType();
1091     auto llvmResultType = typeConverter->convertType(resultType);
1092     // Bail if result type cannot be lowered.
1093     if (!llvmResultType)
1094       return failure();
1095 
1096     SmallVector<OpFoldResult> positionVec = getMixedValues(
1097         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1098 
1099     // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100     // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101     // from a N-d vector extract to a nested aggregate vector extract in two
1102     // steps:
1103     //  - Extract a member from the nested aggregate. The result can be
1104     //    a lower rank nested aggregate or a vector (1-D). This is done using
1105     //    `llvm.extractvalue`.
1106     //  - Extract a scalar out of the vector if needed. This is done using
1107     //   `llvm.extractelement`.
1108 
1109     // Determine if we need to extract a member out of the aggregate. We
1110     // always need to extract a member if the input rank >= 2.
1111     bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1112     // Determine if we need to extract a scalar as the result. We extract
1113     // a scalar if the extract is full rank, i.e., the number of indices is
1114     // equal to source vector rank.
1115     bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1116                           extractOp.getSourceVectorType().getRank();
1117 
1118     // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119     // need to add a position for this change.
1120     if (extractOp.getSourceVectorType().getRank() == 0) {
1121       Type idxType = typeConverter->convertType(rewriter.getIndexType());
1122       positionVec.push_back(rewriter.getZeroAttr(idxType));
1123     }
1124 
1125     Value extracted = adaptor.getVector();
1126     if (extractsAggregate) {
1127       ArrayRef<OpFoldResult> position(positionVec);
1128       if (extractsScalar) {
1129         // If we are extracting a scalar from the extracted member, we drop
1130         // the last index, which will be used to extract the scalar out of the
1131         // vector.
1132         position = position.drop_back();
1133       }
1134       // llvm.extractvalue does not support dynamic dimensions.
1135       if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1136         return failure();
1137       }
1138       extracted = rewriter.create<LLVM::ExtractValueOp>(
1139           loc, extracted, getAsIntegers(position));
1140     }
1141 
1142     if (extractsScalar) {
1143       extracted = rewriter.create<LLVM::ExtractElementOp>(
1144           loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
1145     }
1146 
1147     rewriter.replaceOp(extractOp, extracted);
1148     return success();
1149   }
1150 };
1151 
1152 /// Conversion pattern that turns a vector.fma on a 1-D vector
1153 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1154 /// This does not match vectors of n >= 2 rank.
1155 ///
1156 /// Example:
1157 /// ```
1158 ///  vector.fma %a, %a, %a : vector<8xf32>
1159 /// ```
1160 /// is converted to:
1161 /// ```
1162 ///  llvm.intr.fmuladd %va, %va, %va:
1163 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1164 ///    -> !llvm."<8 x f32>">
1165 /// ```
1166 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1167 public:
1168   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1169 
1170   LogicalResult
1171   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1172                   ConversionPatternRewriter &rewriter) const override {
1173     VectorType vType = fmaOp.getVectorType();
1174     if (vType.getRank() > 1)
1175       return failure();
1176 
1177     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1178         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1179     return success();
1180   }
1181 };
1182 
1183 class VectorInsertElementOpConversion
1184     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1185 public:
1186   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
1187 
1188   LogicalResult
1189   matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1190                   ConversionPatternRewriter &rewriter) const override {
1191     auto vectorType = insertEltOp.getDestVectorType();
1192     auto llvmType = typeConverter->convertType(vectorType);
1193 
1194     // Bail if result type cannot be lowered.
1195     if (!llvmType)
1196       return failure();
1197 
1198     if (vectorType.getRank() == 0) {
1199       Location loc = insertEltOp.getLoc();
1200       auto idxType = rewriter.getIndexType();
1201       auto zero = rewriter.create<LLVM::ConstantOp>(
1202           loc, typeConverter->convertType(idxType),
1203           rewriter.getIntegerAttr(idxType, 0));
1204       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1205           insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1206       return success();
1207     }
1208 
1209     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1210         insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1211         adaptor.getPosition());
1212     return success();
1213   }
1214 };
1215 
1216 class VectorInsertOpConversion
1217     : public ConvertOpToLLVMPattern<vector::InsertOp> {
1218 public:
1219   using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1220 
1221   LogicalResult
1222   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1223                   ConversionPatternRewriter &rewriter) const override {
1224     auto loc = insertOp->getLoc();
1225     auto sourceType = insertOp.getSourceType();
1226     auto destVectorType = insertOp.getDestVectorType();
1227     auto llvmResultType = typeConverter->convertType(destVectorType);
1228     // Bail if result type cannot be lowered.
1229     if (!llvmResultType)
1230       return failure();
1231 
1232     SmallVector<OpFoldResult> positionVec = getMixedValues(
1233         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1234 
1235     // Overwrite entire vector with value. Should be handled by folder, but
1236     // just to be safe.
1237     ArrayRef<OpFoldResult> position(positionVec);
1238     if (position.empty()) {
1239       rewriter.replaceOp(insertOp, adaptor.getSource());
1240       return success();
1241     }
1242 
1243     // One-shot insertion of a vector into an array (only requires insertvalue).
1244     if (isa<VectorType>(sourceType)) {
1245       if (insertOp.hasDynamicPosition())
1246         return failure();
1247 
1248       Value inserted = rewriter.create<LLVM::InsertValueOp>(
1249           loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1250       rewriter.replaceOp(insertOp, inserted);
1251       return success();
1252     }
1253 
1254     // Potential extraction of 1-D vector from array.
1255     Value extracted = adaptor.getDest();
1256     auto oneDVectorType = destVectorType;
1257     if (position.size() > 1) {
1258       if (insertOp.hasDynamicPosition())
1259         return failure();
1260 
1261       oneDVectorType = reducedVectorTypeBack(destVectorType);
1262       extracted = rewriter.create<LLVM::ExtractValueOp>(
1263           loc, extracted, getAsIntegers(position.drop_back()));
1264     }
1265 
1266     // Insertion of an element into a 1-D LLVM vector.
1267     Value inserted = rewriter.create<LLVM::InsertElementOp>(
1268         loc, typeConverter->convertType(oneDVectorType), extracted,
1269         adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1270 
1271     // Potential insertion of resulting 1-D vector into array.
1272     if (position.size() > 1) {
1273       if (insertOp.hasDynamicPosition())
1274         return failure();
1275 
1276       inserted = rewriter.create<LLVM::InsertValueOp>(
1277           loc, adaptor.getDest(), inserted,
1278           getAsIntegers(position.drop_back()));
1279     }
1280 
1281     rewriter.replaceOp(insertOp, inserted);
1282     return success();
1283   }
1284 };
1285 
1286 /// Lower vector.scalable.insert ops to LLVM vector.insert
1287 struct VectorScalableInsertOpLowering
1288     : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1289   using ConvertOpToLLVMPattern<
1290       vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1291 
1292   LogicalResult
1293   matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1294                   ConversionPatternRewriter &rewriter) const override {
1295     rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1296         insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1297     return success();
1298   }
1299 };
1300 
1301 /// Lower vector.scalable.extract ops to LLVM vector.extract
1302 struct VectorScalableExtractOpLowering
1303     : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1304   using ConvertOpToLLVMPattern<
1305       vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1306 
1307   LogicalResult
1308   matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1309                   ConversionPatternRewriter &rewriter) const override {
1310     rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1311         extOp, typeConverter->convertType(extOp.getResultVectorType()),
1312         adaptor.getSource(), adaptor.getPos());
1313     return success();
1314   }
1315 };
1316 
1317 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1318 ///
1319 /// Example:
1320 /// ```
1321 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
1322 /// ```
1323 /// is rewritten into:
1324 /// ```
1325 ///  %r = splat %f0: vector<2x4xf32>
1326 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
1327 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1328 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1329 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1330 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1331 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1332 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1333 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1334 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1335 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1336 ///  // %r3 holds the final value.
1337 /// ```
1338 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1339 public:
1340   using OpRewritePattern<FMAOp>::OpRewritePattern;
1341 
1342   void initialize() {
1343     // This pattern recursively unpacks one dimension at a time. The recursion
1344     // bounded as the rank is strictly decreasing.
1345     setHasBoundedRewriteRecursion();
1346   }
1347 
1348   LogicalResult matchAndRewrite(FMAOp op,
1349                                 PatternRewriter &rewriter) const override {
1350     auto vType = op.getVectorType();
1351     if (vType.getRank() < 2)
1352       return failure();
1353 
1354     auto loc = op.getLoc();
1355     auto elemType = vType.getElementType();
1356     Value zero = rewriter.create<arith::ConstantOp>(
1357         loc, elemType, rewriter.getZeroAttr(elemType));
1358     Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1359     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1360       Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1361       Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1362       Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1363       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1364       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1365     }
1366     rewriter.replaceOp(op, desc);
1367     return success();
1368   }
1369 };
1370 
1371 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1372 /// static layout.
1373 static std::optional<SmallVector<int64_t, 4>>
1374 computeContiguousStrides(MemRefType memRefType) {
1375   int64_t offset;
1376   SmallVector<int64_t, 4> strides;
1377   if (failed(memRefType.getStridesAndOffset(strides, offset)))
1378     return std::nullopt;
1379   if (!strides.empty() && strides.back() != 1)
1380     return std::nullopt;
1381   // If no layout or identity layout, this is contiguous by definition.
1382   if (memRefType.getLayout().isIdentity())
1383     return strides;
1384 
1385   // Otherwise, we must determine contiguity form shapes. This can only ever
1386   // work in static cases because MemRefType is underspecified to represent
1387   // contiguous dynamic shapes in other ways than with just empty/identity
1388   // layout.
1389   auto sizes = memRefType.getShape();
1390   for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1391     if (ShapedType::isDynamic(sizes[index + 1]) ||
1392         ShapedType::isDynamic(strides[index]) ||
1393         ShapedType::isDynamic(strides[index + 1]))
1394       return std::nullopt;
1395     if (strides[index] != strides[index + 1] * sizes[index + 1])
1396       return std::nullopt;
1397   }
1398   return strides;
1399 }
1400 
1401 class VectorTypeCastOpConversion
1402     : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1403 public:
1404   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1405 
1406   LogicalResult
1407   matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1408                   ConversionPatternRewriter &rewriter) const override {
1409     auto loc = castOp->getLoc();
1410     MemRefType sourceMemRefType =
1411         cast<MemRefType>(castOp.getOperand().getType());
1412     MemRefType targetMemRefType = castOp.getType();
1413 
1414     // Only static shape casts supported atm.
1415     if (!sourceMemRefType.hasStaticShape() ||
1416         !targetMemRefType.hasStaticShape())
1417       return failure();
1418 
1419     auto llvmSourceDescriptorTy =
1420         dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1421     if (!llvmSourceDescriptorTy)
1422       return failure();
1423     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1424 
1425     auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1426         typeConverter->convertType(targetMemRefType));
1427     if (!llvmTargetDescriptorTy)
1428       return failure();
1429 
1430     // Only contiguous source buffers supported atm.
1431     auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1432     if (!sourceStrides)
1433       return failure();
1434     auto targetStrides = computeContiguousStrides(targetMemRefType);
1435     if (!targetStrides)
1436       return failure();
1437     // Only support static strides for now, regardless of contiguity.
1438     if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1439       return failure();
1440 
1441     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1442 
1443     // Create descriptor.
1444     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1445     // Set allocated ptr.
1446     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1447     desc.setAllocatedPtr(rewriter, loc, allocated);
1448 
1449     // Set aligned ptr.
1450     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1451     desc.setAlignedPtr(rewriter, loc, ptr);
1452     // Fill offset 0.
1453     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1454     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1455     desc.setOffset(rewriter, loc, zero);
1456 
1457     // Fill size and stride descriptors in memref.
1458     for (const auto &indexedSize :
1459          llvm::enumerate(targetMemRefType.getShape())) {
1460       int64_t index = indexedSize.index();
1461       auto sizeAttr =
1462           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1463       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1464       desc.setSize(rewriter, loc, index, size);
1465       auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1466                                                 (*targetStrides)[index]);
1467       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1468       desc.setStride(rewriter, loc, index, stride);
1469     }
1470 
1471     rewriter.replaceOp(castOp, {desc});
1472     return success();
1473   }
1474 };
1475 
1476 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1477 /// Non-scalable versions of this operation are handled in Vector Transforms.
1478 class VectorCreateMaskOpConversion
1479     : public OpConversionPattern<vector::CreateMaskOp> {
1480 public:
1481   explicit VectorCreateMaskOpConversion(MLIRContext *context,
1482                                         bool enableIndexOpt)
1483       : OpConversionPattern<vector::CreateMaskOp>(context),
1484         force32BitVectorIndices(enableIndexOpt) {}
1485 
1486   LogicalResult
1487   matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1488                   ConversionPatternRewriter &rewriter) const override {
1489     auto dstType = op.getType();
1490     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1491       return failure();
1492     IntegerType idxType =
1493         force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1494     auto loc = op->getLoc();
1495     Value indices = rewriter.create<LLVM::StepVectorOp>(
1496         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1497                                  /*isScalable=*/true));
1498     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1499                                                  adaptor.getOperands()[0]);
1500     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1501     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1502                                                 indices, bounds);
1503     rewriter.replaceOp(op, comp);
1504     return success();
1505   }
1506 
1507 private:
1508   const bool force32BitVectorIndices;
1509 };
1510 
1511 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1512 public:
1513   using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
1514 
1515   // Lowering implementation that relies on a small runtime support library,
1516   // which only needs to provide a few printing methods (single value for all
1517   // data types, opening/closing bracket, comma, newline). The lowering splits
1518   // the vector into elementary printing operations. The advantage of this
1519   // approach is that the library can remain unaware of all low-level
1520   // implementation details of vectors while still supporting output of any
1521   // shaped and dimensioned vector.
1522   //
1523   // Note: This lowering only handles scalars, n-D vectors are broken into
1524   // printing scalars in loops in VectorToSCF.
1525   //
1526   // TODO: rely solely on libc in future? something else?
1527   //
1528   LogicalResult
1529   matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1530                   ConversionPatternRewriter &rewriter) const override {
1531     auto parent = printOp->getParentOfType<ModuleOp>();
1532     if (!parent)
1533       return failure();
1534 
1535     auto loc = printOp->getLoc();
1536 
1537     if (auto value = adaptor.getSource()) {
1538       Type printType = printOp.getPrintType();
1539       if (isa<VectorType>(printType)) {
1540         // Vectors should be broken into elementary print ops in VectorToSCF.
1541         return failure();
1542       }
1543       if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1544         return failure();
1545     }
1546 
1547     auto punct = printOp.getPunctuation();
1548     if (auto stringLiteral = printOp.getStringLiteral()) {
1549       auto createResult =
1550           LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1551                                    *stringLiteral, *getTypeConverter(),
1552                                    /*addNewline=*/false);
1553       if (createResult.failed())
1554         return failure();
1555 
1556     } else if (punct != PrintPunctuation::NoPunctuation) {
1557       FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1558         switch (punct) {
1559         case PrintPunctuation::Close:
1560           return LLVM::lookupOrCreatePrintCloseFn(parent);
1561         case PrintPunctuation::Open:
1562           return LLVM::lookupOrCreatePrintOpenFn(parent);
1563         case PrintPunctuation::Comma:
1564           return LLVM::lookupOrCreatePrintCommaFn(parent);
1565         case PrintPunctuation::NewLine:
1566           return LLVM::lookupOrCreatePrintNewlineFn(parent);
1567         default:
1568           llvm_unreachable("unexpected punctuation");
1569         }
1570       }();
1571       if (failed(op))
1572         return failure();
1573       emitCall(rewriter, printOp->getLoc(), op.value());
1574     }
1575 
1576     rewriter.eraseOp(printOp);
1577     return success();
1578   }
1579 
1580 private:
1581   enum class PrintConversion {
1582     // clang-format off
1583     None,
1584     ZeroExt64,
1585     SignExt64,
1586     Bitcast16
1587     // clang-format on
1588   };
1589 
1590   LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1591                                 ModuleOp parent, Location loc, Type printType,
1592                                 Value value) const {
1593     if (typeConverter->convertType(printType) == nullptr)
1594       return failure();
1595 
1596     // Make sure element type has runtime support.
1597     PrintConversion conversion = PrintConversion::None;
1598     FailureOr<Operation *> printer;
1599     if (printType.isF32()) {
1600       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1601     } else if (printType.isF64()) {
1602       printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1603     } else if (printType.isF16()) {
1604       conversion = PrintConversion::Bitcast16; // bits!
1605       printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1606     } else if (printType.isBF16()) {
1607       conversion = PrintConversion::Bitcast16; // bits!
1608       printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1609     } else if (printType.isIndex()) {
1610       printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1611     } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1612       // Integers need a zero or sign extension on the operand
1613       // (depending on the source type) as well as a signed or
1614       // unsigned print method. Up to 64-bit is supported.
1615       unsigned width = intTy.getWidth();
1616       if (intTy.isUnsigned()) {
1617         if (width <= 64) {
1618           if (width < 64)
1619             conversion = PrintConversion::ZeroExt64;
1620           printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1621         } else {
1622           return failure();
1623         }
1624       } else {
1625         assert(intTy.isSignless() || intTy.isSigned());
1626         if (width <= 64) {
1627           // Note that we *always* zero extend booleans (1-bit integers),
1628           // so that true/false is printed as 1/0 rather than -1/0.
1629           if (width == 1)
1630             conversion = PrintConversion::ZeroExt64;
1631           else if (width < 64)
1632             conversion = PrintConversion::SignExt64;
1633           printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1634         } else {
1635           return failure();
1636         }
1637       }
1638     } else {
1639       return failure();
1640     }
1641     if (failed(printer))
1642       return failure();
1643 
1644     switch (conversion) {
1645     case PrintConversion::ZeroExt64:
1646       value = rewriter.create<arith::ExtUIOp>(
1647           loc, IntegerType::get(rewriter.getContext(), 64), value);
1648       break;
1649     case PrintConversion::SignExt64:
1650       value = rewriter.create<arith::ExtSIOp>(
1651           loc, IntegerType::get(rewriter.getContext(), 64), value);
1652       break;
1653     case PrintConversion::Bitcast16:
1654       value = rewriter.create<LLVM::BitcastOp>(
1655           loc, IntegerType::get(rewriter.getContext(), 16), value);
1656       break;
1657     case PrintConversion::None:
1658       break;
1659     }
1660     emitCall(rewriter, loc, printer.value(), value);
1661     return success();
1662   }
1663 
1664   // Helper to emit a call.
1665   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1666                        Operation *ref, ValueRange params = ValueRange()) {
1667     rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1668                                   params);
1669   }
1670 };
1671 
1672 /// The Splat operation is lowered to an insertelement + a shufflevector
1673 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1674 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1675   using ConvertOpToLLVMPattern<vector::SplatOp>::ConvertOpToLLVMPattern;
1676 
1677   LogicalResult
1678   matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1679                   ConversionPatternRewriter &rewriter) const override {
1680     VectorType resultType = cast<VectorType>(splatOp.getType());
1681     if (resultType.getRank() > 1)
1682       return failure();
1683 
1684     // First insert it into an undef vector so we can shuffle it.
1685     auto vectorType = typeConverter->convertType(splatOp.getType());
1686     Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1687     auto zero = rewriter.create<LLVM::ConstantOp>(
1688         splatOp.getLoc(),
1689         typeConverter->convertType(rewriter.getIntegerType(32)),
1690         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1691 
1692     // For 0-d vector, we simply do `insertelement`.
1693     if (resultType.getRank() == 0) {
1694       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1695           splatOp, vectorType, undef, adaptor.getInput(), zero);
1696       return success();
1697     }
1698 
1699     // For 1-d vector, we additionally do a `vectorshuffle`.
1700     auto v = rewriter.create<LLVM::InsertElementOp>(
1701         splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1702 
1703     int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1704     SmallVector<int32_t> zeroValues(width, 0);
1705 
1706     // Shuffle the value across the desired number of elements.
1707     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1708                                                        zeroValues);
1709     return success();
1710   }
1711 };
1712 
1713 /// The Splat operation is lowered to an insertelement + a shufflevector
1714 /// operation. Splat to only 2+-d vector result types are lowered by the
1715 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1716 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1717   using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;
1718 
1719   LogicalResult
1720   matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1721                   ConversionPatternRewriter &rewriter) const override {
1722     VectorType resultType = splatOp.getType();
1723     if (resultType.getRank() <= 1)
1724       return failure();
1725 
1726     // First insert it into an undef vector so we can shuffle it.
1727     auto loc = splatOp.getLoc();
1728     auto vectorTypeInfo =
1729         LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1730     auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1731     auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1732     if (!llvmNDVectorTy || !llvm1DVectorTy)
1733       return failure();
1734 
1735     // Construct returned value.
1736     Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1737 
1738     // Construct a 1-D vector with the splatted value that we insert in all the
1739     // places within the returned descriptor.
1740     Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1741     auto zero = rewriter.create<LLVM::ConstantOp>(
1742         loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1743         rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1744     Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1745                                                      adaptor.getInput(), zero);
1746 
1747     // Shuffle the value across the desired number of elements.
1748     int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1749     SmallVector<int32_t> zeroValues(width, 0);
1750     v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1751 
1752     // Iterate of linear index, convert to coords space and insert splatted 1-D
1753     // vector in each position.
1754     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1755       desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1756     });
1757     rewriter.replaceOp(splatOp, desc);
1758     return success();
1759   }
1760 };
1761 
1762 /// Conversion pattern for a `vector.interleave`.
1763 /// This supports fixed-sized vectors and scalable vectors.
1764 struct VectorInterleaveOpLowering
1765     : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1766   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1767 
1768   LogicalResult
1769   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1770                   ConversionPatternRewriter &rewriter) const override {
1771     VectorType resultType = interleaveOp.getResultVectorType();
1772     // n-D interleaves should have been lowered already.
1773     if (resultType.getRank() != 1)
1774       return rewriter.notifyMatchFailure(interleaveOp,
1775                                          "InterleaveOp not rank 1");
1776     // If the result is rank 1, then this directly maps to LLVM.
1777     if (resultType.isScalable()) {
1778       rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1779           interleaveOp, typeConverter->convertType(resultType),
1780           adaptor.getLhs(), adaptor.getRhs());
1781       return success();
1782     }
1783     // Lower fixed-size interleaves to a shufflevector. While the
1784     // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1785     // langref still recommends fixed-vectors use shufflevector, see:
1786     // https://llvm.org/docs/LangRef.html#id876.
1787     int64_t resultVectorSize = resultType.getNumElements();
1788     SmallVector<int32_t> interleaveShuffleMask;
1789     interleaveShuffleMask.reserve(resultVectorSize);
1790     for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1791       interleaveShuffleMask.push_back(i);
1792       interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1793     }
1794     rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1795         interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1796         interleaveShuffleMask);
1797     return success();
1798   }
1799 };
1800 
1801 /// Conversion pattern for a `vector.deinterleave`.
1802 /// This supports fixed-sized vectors and scalable vectors.
1803 struct VectorDeinterleaveOpLowering
1804     : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1805   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1806 
1807   LogicalResult
1808   matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1809                   ConversionPatternRewriter &rewriter) const override {
1810     VectorType resultType = deinterleaveOp.getResultVectorType();
1811     VectorType sourceType = deinterleaveOp.getSourceVectorType();
1812     auto loc = deinterleaveOp.getLoc();
1813 
1814     // Note: n-D deinterleave operations should be lowered to the 1-D before
1815     // converting to LLVM.
1816     if (resultType.getRank() != 1)
1817       return rewriter.notifyMatchFailure(deinterleaveOp,
1818                                          "DeinterleaveOp not rank 1");
1819 
1820     if (resultType.isScalable()) {
1821       auto llvmTypeConverter = this->getTypeConverter();
1822       auto deinterleaveResults = deinterleaveOp.getResultTypes();
1823       auto packedOpResults =
1824           llvmTypeConverter->packOperationResults(deinterleaveResults);
1825       auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1826           loc, packedOpResults, adaptor.getSource());
1827 
1828       auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1829           loc, intrinsic->getResult(0), 0);
1830       auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1831           loc, intrinsic->getResult(0), 1);
1832 
1833       rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1834       return success();
1835     }
1836     // Lower fixed-size deinterleave to two shufflevectors. While the
1837     // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1838     // langref still recommends fixed-vectors use shufflevector, see:
1839     // https://llvm.org/docs/LangRef.html#id889.
1840     int64_t resultVectorSize = resultType.getNumElements();
1841     SmallVector<int32_t> evenShuffleMask;
1842     SmallVector<int32_t> oddShuffleMask;
1843 
1844     evenShuffleMask.reserve(resultVectorSize);
1845     oddShuffleMask.reserve(resultVectorSize);
1846 
1847     for (int i = 0; i < sourceType.getNumElements(); ++i) {
1848       if (i % 2 == 0)
1849         evenShuffleMask.push_back(i);
1850       else
1851         oddShuffleMask.push_back(i);
1852     }
1853 
1854     auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1855     auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1856         loc, adaptor.getSource(), poison, evenShuffleMask);
1857     auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1858         loc, adaptor.getSource(), poison, oddShuffleMask);
1859 
1860     rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1861     return success();
1862   }
1863 };
1864 
1865 /// Conversion pattern for a `vector.from_elements`.
1866 struct VectorFromElementsLowering
1867     : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1868   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1869 
1870   LogicalResult
1871   matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1872                   ConversionPatternRewriter &rewriter) const override {
1873     Location loc = fromElementsOp.getLoc();
1874     VectorType vectorType = fromElementsOp.getType();
1875     // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1876     // Such ops should be handled in the same way as vector.insert.
1877     if (vectorType.getRank() > 1)
1878       return rewriter.notifyMatchFailure(fromElementsOp,
1879                                          "rank > 1 vectors are not supported");
1880     Type llvmType = typeConverter->convertType(vectorType);
1881     Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1882     for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1883       result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1884     rewriter.replaceOp(fromElementsOp, result);
1885     return success();
1886   }
1887 };
1888 
1889 /// Conversion pattern for vector.step.
1890 struct VectorScalableStepOpLowering
1891     : public ConvertOpToLLVMPattern<vector::StepOp> {
1892   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1893 
1894   LogicalResult
1895   matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1896                   ConversionPatternRewriter &rewriter) const override {
1897     auto resultType = cast<VectorType>(stepOp.getType());
1898     if (!resultType.isScalable()) {
1899       return failure();
1900     }
1901     Type llvmType = typeConverter->convertType(stepOp.getType());
1902     rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1903     return success();
1904   }
1905 };
1906 
1907 } // namespace
1908 
1909 void mlir::vector::populateVectorRankReducingFMAPattern(
1910     RewritePatternSet &patterns) {
1911   patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
1912 }
1913 
1914 /// Populate the given list with patterns that convert from Vector to LLVM.
1915 void mlir::populateVectorToLLVMConversionPatterns(
1916     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1917     bool reassociateFPReductions, bool force32BitVectorIndices) {
1918   // This function populates only ConversionPatterns, not RewritePatterns.
1919   MLIRContext *ctx = converter.getDialect()->getContext();
1920   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1921   patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
1922   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1923                VectorExtractElementOpConversion, VectorExtractOpConversion,
1924                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1925                VectorInsertOpConversion, VectorPrintOpConversion,
1926                VectorTypeCastOpConversion, VectorScaleOpConversion,
1927                VectorLoadStoreConversion<vector::LoadOp>,
1928                VectorLoadStoreConversion<vector::MaskedLoadOp>,
1929                VectorLoadStoreConversion<vector::StoreOp>,
1930                VectorLoadStoreConversion<vector::MaskedStoreOp>,
1931                VectorGatherOpConversion, VectorScatterOpConversion,
1932                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1933                VectorSplatOpLowering, VectorSplatNdOpLowering,
1934                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1935                MaskedReductionOpConversion, VectorInterleaveOpLowering,
1936                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1937                VectorScalableStepOpLowering>(converter);
1938 }
1939 
1940 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1941     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1942   patterns.add<VectorMatmulOpConversion>(converter);
1943   patterns.add<VectorFlatTransposeOpConversion>(converter);
1944 }
1945