xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 0de60b550b727fa3a0202a9ab5ca30520e291dd5)
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
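// E.g. (illustrative): vector<2x3x4xf32> reduces to vector<3x4xf32>.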
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
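// E.g. (illustrative): vector<2x3x4xf32> reduces to vector<4xf32>.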
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
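// E.g. (illustrative): arrayAttr = [0, 1, 2, 3] with dropFront = 1 and
// dropBack = 1 yields {1, 2}.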
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns the preferred data layout alignment of the element type
// of the memref accessed by an operation.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  LLVM::LLVMDialect *dialect = typeConverter.getDialect();
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     dialect->getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                      Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value memref, MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value memref, MemRefType memRefType, Type type,
                         Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a vector of pointers given a memref base and an index
// vector.
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, Value indices, MemRefType memRefType,
                             VectorType vType, Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
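///
/// A sketch (operand/result types abbreviated; exact printed form may vary):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 4 : i32}
///        : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// becomes an llvm.intr.matrix.multiply carrying the same shape attributes.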
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
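///
/// A sketch (exact printed form may vary):
/// ```
///   %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// becomes an llvm.intr.matrix.transpose with the same rows/columns attributes.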
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
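///
/// A sketch (exact printed form may vary):
/// ```
///   %ld = vector.maskedload %base, %mask, %pass
///         : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.load through a pointer bitcast to the vector
/// type, with the alignment resolved from the data layout.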
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
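///
/// A sketch (exact printed form may vary):
/// ```
///   vector.maskedstore %base, %mask, %value
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.store with the resolved alignment.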
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
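///
/// A sketch (exact printed form may vary):
/// ```
///   %g = vector.gather %base, %indices, %mask, %pass
///        : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>)
///          -> vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.gather over a vector of GEP'd pointers.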
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    ValueRange v = (llvm::size(adaptor.pass_thru()) == 0) ? ValueRange({})
                                                          : adaptor.pass_thru();
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), v,
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
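///
/// A sketch (exact printed form may vary):
/// ```
///   vector.scatter %base, %indices, %mask, %value
///     : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<?xf32>
/// ```
/// becomes an llvm.intr.masked.scatter over a vector of GEP'd pointers.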
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
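///
/// A sketch (exact printed form may vary):
/// ```
///   %e = vector.expandload %base, %mask, %pass
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.expandload (which takes no alignment).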
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
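///
/// A sketch (exact printed form may vary):
/// ```
///   vector.compressstore %base, %mask, %value
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.compressstore (which takes no alignment).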
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
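///
/// A sketch of the integer "add" case (exact printed form may vary):
/// ```
///   %s = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// becomes llvm.intr.experimental.vector.reduce.add; the floating-point "add"
/// and "mul" variants take an optional accumulator operand and honor the
/// reassociateFPReductions flag.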
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFP)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFP) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

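/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of operands with
/// identical vector types maps directly onto llvm.shufflevector; all other
/// cases are unrolled into per-position extract/insert pairs.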
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, we reduce the rank along the leading
// dimension:
//   1. for each slice of the source along the leading dimension, the matching
//      subvector (or scalar element) is extracted from both source and
//      destination
//   2. when the extracted values are vectors, a new rank-reduced
//      InsertStridedSlice op inserts the source subvector into the destination
//      subvector; that op is handled again by this very pattern
//   3. the result is inserted back into the proper position of the
//      accumulated destination vector.
// The recursion in step 2. is bounded since the rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the leading dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides =
        getStridesAndOffset(sourceMemRefType, strides, offset);
    bool isContiguous = (strides.back() == 1);
    if (isContiguous) {
      auto sizes = sourceMemRefType.getShape();
      for (int index = 0, e = strides.size() - 1; index < e; ++index) {
        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
          isContiguous = false;
          break;
        }
      }
    }
    // Only contiguous source memrefs supported atm.
    if (failed(successStrides) || !isContiguous)
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
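///
/// A sketch of the masked 1-D read case (exact printed form may vary):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes a bitcast of the element pointer to a vector pointer, a mask
/// computed as [ %i .. %i + 7 ] < dim(%A, 0), and an llvm.intr.masked.load
/// with a splat of %pad as the pass-through value.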
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, so an
    //    addrspacecast is needed when the source/dst memref is not in
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter);
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO: when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim be the memref dimension; compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                       operands, vectorDataPtr, mask);
  }
};
1230 
1231 class VectorPrintOpConversion : public ConvertToLLVMPattern {
1232 public:
1233   explicit VectorPrintOpConversion(MLIRContext *context,
1234                                    LLVMTypeConverter &typeConverter)
1235       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1236                              typeConverter) {}
1237 
  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (a single-value printer per data type, plus
  // opening/closing bracket, comma, and newline). The lowering fully
  // unrolls a vector in terms of these elementary printing operations.
  // The advantage of this approach is that the library can remain
  // unaware of all low-level implementation details of vectors while
  // still supporting output of vectors of any shape and rank. Due to
  // the full unrolling, however, this approach is less suited for very
  // large vectors.
  //
  // TODO: rely solely on libc in the future? something else?
  //
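  // As a concrete illustration (a sketch of the emitted call sequence,
  // assuming the runtime methods named below): printing a vector<2xf32>
  // unrolls into
  //
  //   print_open();
  //   print_f32(v[0]); print_comma();
  //   print_f32(v[1]);
  //   print_close(); print_newline();
  //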
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure the element type has runtime support (currently i1, i32,
    // i64, f32, and f64).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
        // Convert i1 (bool) to i32 so we can use the print_i32 method.
        // This avoids the need for a print_i1 method with an unclear ABI.
        auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
        auto trueVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(1));
        auto falseVal = rewriter.create<ConstantOp>(
            loc, i32Type, rewriter.getI32IntegerAttr(0));
        value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

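    // For rank > 0, recurse: print an opening bracket, then the sub-vectors
    // or elements separated by commas, then a closing bracket. E.g. a
    // vector<2x2xf32> prints as "( ( v00, v01 ), ( v10, v11 ) )" (a sketch
    // of the output format; element names are for illustration only).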
    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
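  // On the first hit this inserts a declaration at module scope, e.g.
  // (a sketch of the generated form):
  //
  //   llvm.func @print_f32(!llvm.float)
  //
  // Subsequent calls simply return the existing symbol.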
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(
            LLVM::LLVMType::getVoidTy(op->getContext()), params,
            /*isVarArg=*/false));
  }

  // Helpers to look up (or declare on first use) the printer functions.
  Operation *getPrintI32(Operation *op) const {
    return getPrint(op, "print_i32",
                    LLVM::LLVMType::getInt32Ty(op->getContext()));
  }
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "print_i64",
                    LLVM::LLVMType::getInt64Ty(op->getContext()));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "print_f32",
                    LLVM::LLVMType::getFloatTy(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "print_f64",
                    LLVM::LLVMType::getDoubleTy(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "print_newline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. a direct shuffle when there is a single offset, or
///   2. extract + lower-rank strided_slice + insert for the n-D case.
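///
/// For example (a sketch; the operands are hypothetical):
///
///   vector.extract_strided_slice %v
///       {offsets = [2], sizes = [3], strides = [1]}
///     : vector<8xf32> to vector<3xf32>
///
/// is rewritten under case 1. into the single shuffle
///
///   vector.shuffle %v, %v [2, 3, 4] : vector<8xf32>, vector<8xf32>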
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset can be lowered more efficiently to a direct shuffle.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Otherwise, extract/insert around a lower-rank extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive ExtractStridedSliceOps, but the recursion
  /// is bounded because the rank strictly decreases.
  bool hasBoundedRewriteRecursion() const final { return true; }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTransferConversion<TransferReadOp>,
              VectorTransferConversion<TransferWriteOp>,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

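// The pass below wires the pattern sets above into a module pass; assuming
// the standard pass registration, it is what runs when the conversion is
// invoked from the command line, e.g.:
//
//   mlir-opt -convert-vector-to-llvm input.mlir
//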
namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
    signalPassFailure();
  }
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
