//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}
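
// Applied to an illustrative shape: reducedVectorTypeFront turns
// vector<2x3x4xf32> into vector<3x4xf32>, while reducedVectorTypeBack turns
// vector<2x3x4xf32> into vector<4xf32>.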

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
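// For example (illustration only): for an ArrayAttr holding [0, 1, 2, 3],
// dropFront = 1 and dropBack = 1 yield [1, 2].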
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns the data layout alignment of an operation with a memref
// operand.
template <typename T>
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        T op, unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now; generalize.
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
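///
/// A sketch of the lowering (shapes and attribute values are illustrative):
/// ```
///   %C = vector.matrix_multiply %A, %B
///     { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///     : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// becomes an llvm.intr.matrix.multiply on the converted 1-D LLVM vectors.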
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
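///
/// A sketch of the lowering (illustrative sizes):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///     : vector<16xf32> -> vector<16xf32>
/// ```
/// becomes an llvm.intr.matrix.transpose with the same row/column attributes.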
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
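///
/// A sketch of the lowering (assumed assembly syntax, illustrative shapes):
/// ```
///   %0 = vector.maskedload %base, %mask, %pass_thru
///     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.load through a pointer obtained by bitcasting
/// the memref base pointer to the vector type.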
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
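///
/// A sketch of the lowering (assumed assembly syntax, illustrative shapes):
/// ```
///   %g = vector.gather %base, %indices, %mask, %pass_thru
///     : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>)
///       -> vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.gather on a vector of pointers computed by a
/// single GEP of the base with the index vector.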
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
                                                          : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
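///
/// A sketch for the integer case (illustrative shapes):
/// ```
///   %1 = vector.reduction "add", %0 : vector<16xi32> into i32
/// ```
/// maps to the corresponding llvm.intr.experimental.vector.reduce.* intrinsic
/// (here: add). Floating-point "add" and "mul" accept an optional accumulator
/// operand and use the v2 variants, with reassociation controlled by
/// `reassociateFPReductions`.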
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

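/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of two operands
/// with identical vector types maps directly to llvm.shufflevector; all other
/// cases are unrolled into per-position extract/insert pairs.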
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

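/// Conversion pattern for a vector.extractelement; maps 1-1 onto
/// llvm.extractelement.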
class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

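/// Conversion pattern for a vector.extract. Extraction of an inner vector is
/// a single llvm.extractvalue; extraction of a scalar additionally requires a
/// final llvm.extractelement from the innermost 1-D vector.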
class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of an element from the 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

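/// Conversion pattern for a vector.insertelement; maps 1-1 onto
/// llvm.insertelement.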
class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

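/// Conversion pattern for a vector.insert. Insertion of an inner vector is a
/// single llvm.insertvalue; insertion of a scalar goes through an
/// extractvalue / insertelement / insertvalue sequence on the innermost 1-D
/// vector.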
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank-reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
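//
// For illustration (assumed shapes and offsets; sketch only), an op such as:
//
//   %2 = vector.insert_strided_slice %0, %1
//       {offsets = [0, 0, 2], strides = [1, 1]}
//     : vector<2x4xf32> into vector<16x4x8xf32>
//
// first extracts the vector<4x8xf32> subvector %1[0], inserts %0 into it with
// a rank-matched InsertStridedSlice at offsets [0, 2], and inserts the result
// back at position [0] of %1.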
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract the subvector of matching rank from the destination and apply
    // the InsertStridedSlice on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source along the most major
// dimension:
//   1. the corresponding subvector (or element) is extracted from the source
//   2. if it is a subvector, a new InsertStridedSlice op of rank n-1 is
//   created to insert it into the matching destination subvector
//   3. the result is inserted back into the destination at the strided offset.
// The lower-rank InsertStridedSlice ops created in step 2. are picked up again
// by this same pattern; the recursion is bounded because the rank strictly
// decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from the source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

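/// Conversion pattern for a vector.type_cast on a statically shaped,
/// contiguous memref: the memref descriptor is rewritten in place, with the
/// pointers bitcast to the target element type and the sizes/strides refilled,
/// so no data is moved.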
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts are supported for now.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    // Only contiguous source memrefs are supported for now. Check the strides
    // only after getStridesAndOffset has succeeded, so that `strides` is
    // guaranteed to be populated.
    if (failed(successStrides))
      return failure();
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    if (!isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
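///
/// A sketch for the masked read case (assumed assembly syntax; the
/// permutation_map attribute is elided for brevity):
/// ```
///   %v = vector.transfer_read %A[%i], %f0 : memref<?xf32>, vector<8xf32>
/// ```
/// becomes a splat of the padding %f0, an offset vector [%i .. %i+7] compared
/// against dim %A to form the mask, and an llvm.intr.masked.load.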
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector element type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
1169     }
1170 
1171     // 1. Get the source/dst address as an LLVM vector pointer.
1172     //    The vector pointer would always be on address space 0, therefore
1173     //    addrspacecast shall be used when source/dst memrefs are not on
1174     //    address space 0.
1175     // TODO: support alignment when possible.
1176     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
1177                                adaptor.indices(), rewriter);
1178     auto vecTy =
1179         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1180     Value vectorDataPtr;
1181     if (memRefType.getMemorySpace() == 0)
1182       vectorDataPtr =
1183           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1184     else
1185       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1186           loc, vecTy.getPointerTo(), dataPtr);
1187 
1188     if (!xferOp.isMaskedDim(0))
1189       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
1190                                               xferOp, operands, vectorDataPtr);
1191 
1192     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1193     unsigned vecWidth = vecTy.getVectorNumElements();
1194     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
1195     SmallVector<int64_t, 8> indices;
1196     indices.reserve(vecWidth);
1197     for (unsigned i = 0; i < vecWidth; ++i)
1198       indices.push_back(i);
1199     Value linearIndices = rewriter.create<ConstantOp>(
1200         loc, vectorCmpType,
1201         DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
1202     linearIndices = rewriter.create<LLVM::DialectCastOp>(
1203         loc, toLLVMTy(vectorCmpType), linearIndices);
1204 
1205     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1206     // TODO: when the leaf transfer rank is k > 1 we need the last
1207     // `k` dimensions here.
1208     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
1209     Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
1210     offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
1211     Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
1212     Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
1213 
1214     // 4. Let dim the memref dimension, compute the vector comparison mask:
1215     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1216     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1217     dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
1218     dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
1219     Value mask =
1220         rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
1221     mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
1222                                                 mask);
1223 
1224     // 5. Rewrite as a masked read / write.
1225     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1226                                        operands, vectorDataPtr, mask);
1227   }
1228 };
1229 
1230 class VectorPrintOpConversion : public ConvertToLLVMPattern {
1231 public:
1232   explicit VectorPrintOpConversion(MLIRContext *context,
1233                                    LLVMTypeConverter &typeConverter)
1234       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1235                              typeConverter) {}
1236 
  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few printing
  // methods (a single-value print for each supported data type, plus
  // opening/closing bracket, comma, and newline). The lowering fully
  // unrolls a vector in terms of these elementary printing operations. The
  // advantage of this approach is that the library can remain unaware of
  // all low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
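  // As an illustration, printing a vector<2x2xf32> that holds
  // [ [ 1, 2 ], [ 3, 4 ] ] unrolls into the call sequence
  //   print_open, print_open, print_f32(1), print_comma, print_f32(2),
  //   print_close, print_comma, print_open, print_f32(3), print_comma,
  //   print_f32(4), print_close, print_close, print_newline
  // which renders as "( ( 1, 2 ), ( 3, 4 ) )" followed by a newline.
  //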
  // TODO: rely solely on libc in the future? Something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support (currently i1, i32,
    // i64, f32, and f64).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method lookup; declares the method on first use.
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(
            LLVM::LLVMType::getVoidTy(op->getContext()), params,
            /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    return getPrint(op, "print_i32",
                    LLVM::LLVMType::getInt32Ty(op->getContext()));
  }
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "print_i64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "print_f32",
                    LLVM::LLVMType::getFloatTy(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "print_f64",
                    LLVM::LLVMType::getDoubleTy(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "print_newline", {});
  }
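  // For reference, a matching runtime support library could be as simple as
  // the following sketch (illustrative only, assuming C linkage, <cstdio>,
  // and <cinttypes>; any implementation with these signatures works):
  //
  //   extern "C" void print_i32(int32_t i) { printf("%" PRId32, i); }
  //   extern "C" void print_i64(int64_t l) { printf("%" PRId64, l); }
  //   extern "C" void print_f32(float f) { printf("%g", f); }
  //   extern "C" void print_f64(double d) { printf("%lg", d); }
  //   extern "C" void print_open() { printf("( "); }
  //   extern "C" void print_close() { printf(" )"); }
  //   extern "C" void print_comma() { printf(", "); }
  //   extern "C" void print_newline() { printf("\n"); }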
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. a direct shuffle when there is a single offset (1-D case), or
///   2. extract + lower-rank strided_slice + insert for the n-D case.
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset can be lowered more efficiently to a direct shuffle.
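    // For example (hypothetical values): offsets = [2], sizes = [3], and
    // strides = [2] on a vector<8xf32> produce the shuffle mask [2, 4, 6],
    // i.e. every other element starting at index 2.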
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower-ranked extract strided slice op.
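    // For example (hypothetical values): offsets = [1, 2], sizes = [2, 2],
    // strides = [1, 1] on a vector<4x8xf32> extracts rows 1 and 2, takes a
    // 1-D strided slice (offsets = [2], sizes = [2]) of each, and inserts
    // the results into a zeroed vector<2x2xf32> at positions 0 and 1.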
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive ExtractStridedSliceOp operations, but the
  /// recursion is bounded because the rank strictly decreases.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
    signalPassFailure();
  }
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
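
// As a usage note, this pass is typically exercised via mlir-opt, e.g.
// (with a hypothetical input file):
//   mlir-opt -convert-vector-to-llvm input.mlir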