1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 template <typename T>
38 static LLVM::LLVMType getPtrToElementType(T containerType,
39                                           LLVMTypeConverter &typeConverter) {
40   return typeConverter.convertType(containerType.getElementType())
41       .template cast<LLVM::LLVMType>()
42       .getPointerTo();
43 }
44 
45 // Helper to reduce vector type by one rank at front.
46 static VectorType reducedVectorTypeFront(VectorType tp) {
47   assert((tp.getRank() > 1) && "unlowerable vector type");
48   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
49 }
50 
51 // Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) {
53   assert((tp.getRank() > 1) && "unlowerable vector type");
54   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
55 }
56 
57 // Helper that picks the proper LLVM sequence for inserting (insertelement for rank 1, insertvalue otherwise).
58 static Value insertOne(ConversionPatternRewriter &rewriter,
59                        LLVMTypeConverter &typeConverter, Location loc,
60                        Value val1, Value val2, Type llvmType, int64_t rank,
61                        int64_t pos) {
62   if (rank == 1) {
63     auto idxType = rewriter.getIndexType();
64     auto constant = rewriter.create<LLVM::ConstantOp>(
65         loc, typeConverter.convertType(idxType),
66         rewriter.getIntegerAttr(idxType, pos));
67     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
68                                                   constant);
69   }
70   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
71                                               rewriter.getI64ArrayAttr(pos));
72 }
73 
74 // Helper that picks the proper Vector dialect sequence for inserting (vector.insertelement for rank 1, vector.insert otherwise).
75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
76                        Value into, int64_t offset) {
77   auto vectorType = into.getType().cast<VectorType>();
78   if (vectorType.getRank() > 1)
79     return rewriter.create<InsertOp>(loc, from, into, offset);
80   return rewriter.create<vector::InsertElementOp>(
81       loc, vectorType, from, into,
82       rewriter.create<ConstantIndexOp>(loc, offset));
83 }
84 
85 // Helper that picks the proper LLVM sequence for extracting (extractelement for rank 1, extractvalue otherwise).
86 static Value extractOne(ConversionPatternRewriter &rewriter,
87                         LLVMTypeConverter &typeConverter, Location loc,
88                         Value val, Type llvmType, int64_t rank, int64_t pos) {
89   if (rank == 1) {
90     auto idxType = rewriter.getIndexType();
91     auto constant = rewriter.create<LLVM::ConstantOp>(
92         loc, typeConverter.convertType(idxType),
93         rewriter.getIntegerAttr(idxType, pos));
94     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
95                                                    constant);
96   }
97   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
98                                                rewriter.getI64ArrayAttr(pos));
99 }
100 
101 // Helper that picks the proper Vector dialect sequence for extracting (vector.extractelement for rank 1, vector.extract otherwise).
102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
103                         int64_t offset) {
104   auto vectorType = vector.getType().cast<VectorType>();
105   if (vectorType.getRank() > 1)
106     return rewriter.create<ExtractOp>(loc, vector, offset);
107   return rewriter.create<vector::ExtractElementOp>(
108       loc, vectorType.getElementType(), vector,
109       rewriter.create<ConstantIndexOp>(loc, offset));
110 }
111 
112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing.
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
115                                               unsigned dropFront = 0,
116                                               unsigned dropBack = 0) {
117   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
118   auto range = arrayAttr.getAsRange<IntegerAttr>();
119   SmallVector<int64_t, 4> res;
120   res.reserve(arrayAttr.size() - dropFront - dropBack);
121   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
122        it != eit; ++it)
123     res.push_back((*it).getValue().getSExtValue());
124   return res;
125 }
126 
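// Helper that returns, in `align`, the preferred alignment (in bytes) of the
// memref element type according to the LLVM module's data layout. Fails if the
// element type cannot be converted to an LLVM type.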
127 template <typename TransferOp>
128 LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
129                                          TransferOp xferOp, unsigned &align) {
130   Type elementTy =
131       typeConverter.convertType(xferOp.getMemRefType().getElementType());
132   if (!elementTy)
133     return failure();
134 
135   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
136   align = dataLayout.getPrefTypeAlignment(
137       elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
138   return success();
139 }
140 
141 static LogicalResult
142 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
143                                  LLVMTypeConverter &typeConverter, Location loc,
144                                  TransferReadOp xferOp,
145                                  ArrayRef<Value> operands, Value dataPtr) {
146   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
147   return success();
148 }
149 
150 static LogicalResult
151 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
152                             LLVMTypeConverter &typeConverter, Location loc,
153                             TransferReadOp xferOp, ArrayRef<Value> operands,
154                             Value dataPtr, Value mask) {
155   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
156   VectorType fillType = xferOp.getVectorType();
157   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
158   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
159 
160   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
161   if (!vecTy)
162     return failure();
163 
164   unsigned align;
165   if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
166     return failure();
167 
168   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
169       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
170       rewriter.getI32IntegerAttr(align));
171   return success();
172 }
173 
174 static LogicalResult
175 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
176                                  LLVMTypeConverter &typeConverter, Location loc,
177                                  TransferWriteOp xferOp,
178                                  ArrayRef<Value> operands, Value dataPtr) {
179   auto adaptor = TransferWriteOpAdaptor(operands);
180   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
181   return success();
182 }
183 
184 static LogicalResult
185 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
186                             LLVMTypeConverter &typeConverter, Location loc,
187                             TransferWriteOp xferOp, ArrayRef<Value> operands,
188                             Value dataPtr, Value mask) {
189   unsigned align;
190   if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
191     return failure();
192 
193   auto adaptor = TransferWriteOpAdaptor(operands);
194   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
195       xferOp, adaptor.vector(), dataPtr, mask,
196       rewriter.getI32IntegerAttr(align));
197   return success();
198 }
199 
200 static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
201                                                   ArrayRef<Value> operands) {
202   return TransferReadOpAdaptor(operands);
203 }
204 
205 static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
206                                                    ArrayRef<Value> operands) {
207   return TransferWriteOpAdaptor(operands);
208 }
209 
210 namespace {
211 
212 /// Conversion pattern for a vector.matrix_multiply.
213 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
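///
/// An illustrative sketch (shapes and exact assembly syntax abridged):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
///        : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// becomes an llvm.intr.matrix.multiply with the same shape attributes on the
/// converted !llvm<"<64 x double>"> / !llvm<"<48 x double>"> operands.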
214 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
215 public:
216   explicit VectorMatmulOpConversion(MLIRContext *context,
217                                     LLVMTypeConverter &typeConverter)
218       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
219                              typeConverter) {}
220 
221   LogicalResult
222   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
223                   ConversionPatternRewriter &rewriter) const override {
224     auto matmulOp = cast<vector::MatmulOp>(op);
225     auto adaptor = vector::MatmulOpAdaptor(operands);
226     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
227         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
228         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
229         matmulOp.rhs_columns());
230     return success();
231   }
232 };
233 
234 /// Conversion pattern for a vector.flat_transpose.
235 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
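///
/// An illustrative sketch (shapes and exact assembly syntax abridged):
/// ```
///   %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// becomes an llvm.intr.matrix.transpose with the same rows/columns attributes
/// on the converted !llvm<"<16 x float>"> operand.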
236 class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
237 public:
238   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
239                                            LLVMTypeConverter &typeConverter)
240       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
241                              context, typeConverter) {}
242 
243   LogicalResult
244   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
245                   ConversionPatternRewriter &rewriter) const override {
246     auto transOp = cast<vector::FlatTransposeOp>(op);
247     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
248     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
249         transOp, typeConverter.convertType(transOp.res().getType()),
250         adaptor.matrix(), transOp.rows(), transOp.columns());
251     return success();
252   }
253 };
254 
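/// Conversion pattern for a vector.reduction. Integer reductions
/// (add/mul/min/max/and/or/xor) and floating-point reductions (add/mul/min/max)
/// map onto the corresponding llvm.intr.experimental.vector.reduce.* intrinsic;
/// the floating-point add/mul variants optionally take an accumulator operand.
/// A rough sketch (exact assembly syntax abridged):
/// ```
///   %s = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// becomes llvm.intr.experimental.vector.reduce.add applied to the converted
/// !llvm<"<16 x i32>"> operand.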
255 class VectorReductionOpConversion : public ConvertToLLVMPattern {
256 public:
257   explicit VectorReductionOpConversion(MLIRContext *context,
258                                        LLVMTypeConverter &typeConverter)
259       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
260                              typeConverter) {}
261 
262   LogicalResult
263   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
264                   ConversionPatternRewriter &rewriter) const override {
265     auto reductionOp = cast<vector::ReductionOp>(op);
266     auto kind = reductionOp.kind();
267     Type eltType = reductionOp.dest().getType();
268     Type llvmType = typeConverter.convertType(eltType);
269     if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
270       // Integer reductions: add/mul/min/max/and/or/xor.
271       if (kind == "add")
272         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
273             op, llvmType, operands[0]);
274       else if (kind == "mul")
275         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
276             op, llvmType, operands[0]);
277       else if (kind == "min")
278         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
279             op, llvmType, operands[0]);
280       else if (kind == "max")
281         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
282             op, llvmType, operands[0]);
283       else if (kind == "and")
284         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
285             op, llvmType, operands[0]);
286       else if (kind == "or")
287         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
288             op, llvmType, operands[0]);
289       else if (kind == "xor")
290         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
291             op, llvmType, operands[0]);
292       else
293         return failure();
294       return success();
295 
296     } else if (eltType.isF32() || eltType.isF64()) {
297       // Floating-point reductions: add/mul/min/max
298       if (kind == "add") {
299         // Optional accumulator (or zero).
300         Value acc = operands.size() > 1 ? operands[1]
301                                         : rewriter.create<LLVM::ConstantOp>(
302                                               op->getLoc(), llvmType,
303                                               rewriter.getZeroAttr(eltType));
304         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
305             op, llvmType, acc, operands[0]);
306       } else if (kind == "mul") {
307         // Optional accumulator (or one).
308         Value acc = operands.size() > 1
309                         ? operands[1]
310                         : rewriter.create<LLVM::ConstantOp>(
311                               op->getLoc(), llvmType,
312                               rewriter.getFloatAttr(eltType, 1.0));
313         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
314             op, llvmType, acc, operands[0]);
315       } else if (kind == "min")
316         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
317             op, llvmType, operands[0]);
318       else if (kind == "max")
319         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
320             op, llvmType, operands[0]);
321       else
322         return failure();
323       return success();
324     }
325     return failure();
326   }
327 };
328 
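/// Conversion pattern for a vector.shuffle. Rank-1 shuffles whose operands have
/// exactly the same type map directly onto llvm.shufflevector; every other case
/// is unrolled into one extract/insert pair per mask entry.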
329 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
330 public:
331   explicit VectorShuffleOpConversion(MLIRContext *context,
332                                      LLVMTypeConverter &typeConverter)
333       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
334                              typeConverter) {}
335 
336   LogicalResult
337   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
338                   ConversionPatternRewriter &rewriter) const override {
339     auto loc = op->getLoc();
340     auto adaptor = vector::ShuffleOpAdaptor(operands);
341     auto shuffleOp = cast<vector::ShuffleOp>(op);
342     auto v1Type = shuffleOp.getV1VectorType();
343     auto v2Type = shuffleOp.getV2VectorType();
344     auto vectorType = shuffleOp.getVectorType();
345     Type llvmType = typeConverter.convertType(vectorType);
346     auto maskArrayAttr = shuffleOp.mask();
347 
348     // Bail if result type cannot be lowered.
349     if (!llvmType)
350       return failure();
351 
352     // Get rank and dimension sizes.
353     int64_t rank = vectorType.getRank();
354     assert(v1Type.getRank() == rank);
355     assert(v2Type.getRank() == rank);
356     int64_t v1Dim = v1Type.getDimSize(0);
357 
358     // For rank 1, where both operands have *exactly* the same vector type,
359     // there is direct shuffle support in LLVM. Use it!
360     if (rank == 1 && v1Type == v2Type) {
361       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
362           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
363       rewriter.replaceOp(op, shuffle);
364       return success();
365     }
366 
367     // For all other cases, extract and insert the values one by one.
368     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
369     int64_t insPos = 0;
370     for (auto en : llvm::enumerate(maskArrayAttr)) {
371       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
372       Value value = adaptor.v1();
373       if (extPos >= v1Dim) {
374         extPos -= v1Dim;
375         value = adaptor.v2();
376       }
377       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
378                                  rank, extPos);
379       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
380                          llvmType, rank, insPos++);
381     }
382     rewriter.replaceOp(op, insert);
383     return success();
384   }
385 };
386 
387 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
388 public:
389   explicit VectorExtractElementOpConversion(MLIRContext *context,
390                                             LLVMTypeConverter &typeConverter)
391       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
392                              context, typeConverter) {}
393 
394   LogicalResult
395   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
396                   ConversionPatternRewriter &rewriter) const override {
397     auto adaptor = vector::ExtractElementOpAdaptor(operands);
398     auto extractEltOp = cast<vector::ExtractElementOp>(op);
399     auto vectorType = extractEltOp.getVectorType();
400     auto llvmType = typeConverter.convertType(vectorType.getElementType());
401 
402     // Bail if result type cannot be lowered.
403     if (!llvmType)
404       return failure();
405 
406     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
407         op, llvmType, adaptor.vector(), adaptor.position());
408     return success();
409   }
410 };
411 
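/// Conversion pattern for a vector.extract. Extracting a sub-vector only needs
/// an llvm.extractvalue on the array-of-vectors representation; extracting a
/// scalar additionally requires an llvm.extractelement from the innermost 1-D
/// vector.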
412 class VectorExtractOpConversion : public ConvertToLLVMPattern {
413 public:
414   explicit VectorExtractOpConversion(MLIRContext *context,
415                                      LLVMTypeConverter &typeConverter)
416       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
417                              typeConverter) {}
418 
419   LogicalResult
420   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
421                   ConversionPatternRewriter &rewriter) const override {
422     auto loc = op->getLoc();
423     auto adaptor = vector::ExtractOpAdaptor(operands);
424     auto extractOp = cast<vector::ExtractOp>(op);
425     auto vectorType = extractOp.getVectorType();
426     auto resultType = extractOp.getResult().getType();
427     auto llvmResultType = typeConverter.convertType(resultType);
428     auto positionArrayAttr = extractOp.position();
429 
430     // Bail if result type cannot be lowered.
431     if (!llvmResultType)
432       return failure();
433 
434     // One-shot extraction of vector from array (only requires extractvalue).
435     if (resultType.isa<VectorType>()) {
436       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
437           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
438       rewriter.replaceOp(op, extracted);
439       return success();
440     }
441 
442     // Potential extraction of 1-D vector from array.
443     auto *context = op->getContext();
444     Value extracted = adaptor.vector();
445     auto positionAttrs = positionArrayAttr.getValue();
446     if (positionAttrs.size() > 1) {
447       auto oneDVectorType = reducedVectorTypeBack(vectorType);
448       auto nMinusOnePositionAttrs =
449           ArrayAttr::get(positionAttrs.drop_back(), context);
450       extracted = rewriter.create<LLVM::ExtractValueOp>(
451           loc, typeConverter.convertType(oneDVectorType), extracted,
452           nMinusOnePositionAttrs);
453     }
454 
455     // Remaining extraction of element from 1-D LLVM vector
456     auto position = positionAttrs.back().cast<IntegerAttr>();
457     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
458     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
459     extracted =
460         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
461     rewriter.replaceOp(op, extracted);
462 
463     return success();
464   }
465 };
466 
467 /// Conversion pattern that turns a vector.fma on a 1-D vector
468 /// into an llvm.intr.fma. This is a trivial 1-1 conversion.
469 /// This pattern does not match vectors of rank >= 2.
470 ///
471 /// Example:
472 /// ```
473 ///  vector.fma %a, %a, %a : vector<8xf32>
474 /// ```
475 /// is converted to:
476 /// ```
477 ///  llvm.intr.fma %va, %va, %va:
478 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
479 ///    -> !llvm<"<8 x float>">
480 /// ```
481 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
482 public:
483   explicit VectorFMAOp1DConversion(MLIRContext *context,
484                                    LLVMTypeConverter &typeConverter)
485       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
486                              typeConverter) {}
487 
488   LogicalResult
489   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
490                   ConversionPatternRewriter &rewriter) const override {
491     auto adaptor = vector::FMAOpAdaptor(operands);
492     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
493     VectorType vType = fmaOp.getVectorType();
494     if (vType.getRank() != 1)
495       return failure();
496     rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
497                                              adaptor.acc());
498     return success();
499   }
500 };
501 
502 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
503 public:
504   explicit VectorInsertElementOpConversion(MLIRContext *context,
505                                            LLVMTypeConverter &typeConverter)
506       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
507                              context, typeConverter) {}
508 
509   LogicalResult
510   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
511                   ConversionPatternRewriter &rewriter) const override {
512     auto adaptor = vector::InsertElementOpAdaptor(operands);
513     auto insertEltOp = cast<vector::InsertElementOp>(op);
514     auto vectorType = insertEltOp.getDestVectorType();
515     auto llvmType = typeConverter.convertType(vectorType);
516 
517     // Bail if result type cannot be lowered.
518     if (!llvmType)
519       return failure();
520 
521     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
522         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
523     return success();
524   }
525 };
526 
527 class VectorInsertOpConversion : public ConvertToLLVMPattern {
528 public:
529   explicit VectorInsertOpConversion(MLIRContext *context,
530                                     LLVMTypeConverter &typeConverter)
531       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
532                              typeConverter) {}
533 
534   LogicalResult
535   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
536                   ConversionPatternRewriter &rewriter) const override {
537     auto loc = op->getLoc();
538     auto adaptor = vector::InsertOpAdaptor(operands);
539     auto insertOp = cast<vector::InsertOp>(op);
540     auto sourceType = insertOp.getSourceType();
541     auto destVectorType = insertOp.getDestVectorType();
542     auto llvmResultType = typeConverter.convertType(destVectorType);
543     auto positionArrayAttr = insertOp.position();
544 
545     // Bail if result type cannot be lowered.
546     if (!llvmResultType)
547       return failure();
548 
549     // One-shot insertion of a vector into an array (only requires insertvalue).
550     if (sourceType.isa<VectorType>()) {
551       Value inserted = rewriter.create<LLVM::InsertValueOp>(
552           loc, llvmResultType, adaptor.dest(), adaptor.source(),
553           positionArrayAttr);
554       rewriter.replaceOp(op, inserted);
555       return success();
556     }
557 
558     // Potential extraction of 1-D vector from array.
559     auto *context = op->getContext();
560     Value extracted = adaptor.dest();
561     auto positionAttrs = positionArrayAttr.getValue();
562     auto position = positionAttrs.back().cast<IntegerAttr>();
563     auto oneDVectorType = destVectorType;
564     if (positionAttrs.size() > 1) {
565       oneDVectorType = reducedVectorTypeBack(destVectorType);
566       auto nMinusOnePositionAttrs =
567           ArrayAttr::get(positionAttrs.drop_back(), context);
568       extracted = rewriter.create<LLVM::ExtractValueOp>(
569           loc, typeConverter.convertType(oneDVectorType), extracted,
570           nMinusOnePositionAttrs);
571     }
572 
573     // Insertion of an element into a 1-D LLVM vector.
574     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
575     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
576     Value inserted = rewriter.create<LLVM::InsertElementOp>(
577         loc, typeConverter.convertType(oneDVectorType), extracted,
578         adaptor.source(), constant);
579 
580     // Potential insertion of resulting 1-D vector into array.
581     if (positionAttrs.size() > 1) {
582       auto nMinusOnePositionAttrs =
583           ArrayAttr::get(positionAttrs.drop_back(), context);
584       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
585                                                       adaptor.dest(), inserted,
586                                                       nMinusOnePositionAttrs);
587     }
588 
589     rewriter.replaceOp(op, inserted);
590     return success();
591   }
592 };
593 
594 /// Rank-reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
595 ///
596 /// Example:
597 /// ```
598 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
599 /// ```
600 /// is rewritten into (with %f0 the f32 constant 0.0):
601 /// ```
602 ///  %r = splat %f0 : vector<2x4xf32>
603 ///  %va = vector.extract %a[0] : vector<2x4xf32>
604 ///  %vb = vector.extract %b[0] : vector<2x4xf32>
605 ///  %vc = vector.extract %c[0] : vector<2x4xf32>
606 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
607 ///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
608 ///  %va2 = vector.extract %a[1] : vector<2x4xf32>
609 ///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
610 ///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
611 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
612 ///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
613 ///  // %r3 holds the final value.
614 /// ```
615 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
616 public:
617   using OpRewritePattern<FMAOp>::OpRewritePattern;
618 
619   LogicalResult matchAndRewrite(FMAOp op,
620                                 PatternRewriter &rewriter) const override {
621     auto vType = op.getVectorType();
622     if (vType.getRank() < 2)
623       return failure();
624 
625     auto loc = op.getLoc();
626     auto elemType = vType.getElementType();
627     Value zero = rewriter.create<ConstantOp>(loc, elemType,
628                                              rewriter.getZeroAttr(elemType));
629     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
630     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
631       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
632       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
633       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
634       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
635       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
636     }
637     rewriter.replaceOp(op, desc);
638     return success();
639   }
640 };
641 
642 // When ranks are different, InsertStridedSlice needs to extract a properly
643 // ranked vector from the destination vector into which to insert. This pattern
644 // only takes care of this part and forwards the rest of the conversion to
645 // another pattern that converts InsertStridedSlice for operands of the same
646 // rank.
647 //
648 // RewritePattern for InsertStridedSliceOp where source and destination vectors
649 // have different ranks. In this case:
650 //   1. the proper subvector is extracted from the destination vector
651 //   2. a new InsertStridedSlice op is created to insert the source in the
652 //   destination subvector
653 //   3. the destination subvector is inserted back in the proper place
654 //   4. the op is replaced by the result of step 3.
655 // The new InsertStridedSlice from step 2. will be picked up by a
656 // `VectorInsertStridedSliceOpSameRankRewritePattern`.
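//
// For example (shapes chosen arbitrarily), inserting a vector<4xf32> into a
// vector<2x3x4xf32> at offsets [1, 2, 0]: the vector<4xf32> at destination
// position [1, 2] is extracted, a same-rank InsertStridedSlice at offset [0]
// inserts the source into it, and the result is inserted back at [1, 2].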
657 class VectorInsertStridedSliceOpDifferentRankRewritePattern
658     : public OpRewritePattern<InsertStridedSliceOp> {
659 public:
660   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
661 
662   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
663                                 PatternRewriter &rewriter) const override {
664     auto srcType = op.getSourceVectorType();
665     auto dstType = op.getDestVectorType();
666 
667     if (op.offsets().getValue().empty())
668       return failure();
669 
670     auto loc = op.getLoc();
671     int64_t rankDiff = dstType.getRank() - srcType.getRank();
672     assert(rankDiff >= 0);
673     if (rankDiff == 0)
674       return failure();
675 
676     int64_t rankRest = dstType.getRank() - rankDiff;
677     // Extract / insert the subvector of matching rank and InsertStridedSlice
678     // on it.
679     Value extracted =
680         rewriter.create<ExtractOp>(loc, op.dest(),
681                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
682                                                   /*dropBack=*/rankRest));
683     // A different pattern will kick in for InsertStridedSlice with matching
684     // ranks.
685     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
686         loc, op.source(), extracted,
687         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
688         getI64SubArray(op.strides(), /*dropFront=*/0));
689     rewriter.replaceOpWithNewOp<InsertOp>(
690         op, stridedSliceInnerOp.getResult(), op.dest(),
691         getI64SubArray(op.offsets(), /*dropFront=*/0,
692                        /*dropBack=*/rankRest));
693     return success();
694   }
695 };
696 
697 // RewritePattern for InsertStridedSliceOp where source and destination vectors
698 // have the same rank. In this case, for each slice of the source along its
699 // most major dimension:
700 //   1. the slice is extracted from the source vector
701 //   2. if the slice is itself a vector, the matching subvector is extracted
702 //   from the destination and a lower-rank InsertStridedSlice op is created
703 //   to insert the slice into it
704 //   3. the resulting subvector (or element) is inserted into the result.
705 // The lower-rank InsertStridedSlice from step 2. is picked up again by this pattern.
706 class VectorInsertStridedSliceOpSameRankRewritePattern
707     : public OpRewritePattern<InsertStridedSliceOp> {
708 public:
709   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
710 
711   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
712                                 PatternRewriter &rewriter) const override {
713     auto srcType = op.getSourceVectorType();
714     auto dstType = op.getDestVectorType();
715 
716     if (op.offsets().getValue().empty())
717       return failure();
718 
719     int64_t rankDiff = dstType.getRank() - srcType.getRank();
720     assert(rankDiff >= 0);
721     if (rankDiff != 0)
722       return failure();
723 
724     if (srcType == dstType) {
725       rewriter.replaceOp(op, op.source());
726       return success();
727     }
728 
729     int64_t offset =
730         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
731     int64_t size = srcType.getShape().front();
732     int64_t stride =
733         op.strides().getValue().front().cast<IntegerAttr>().getInt();
734 
735     auto loc = op.getLoc();
736     Value res = op.dest();
737     // For each slice of the source vector along the most major dimension.
738     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
739          off += stride, ++idx) {
740       // 1. extract the proper subvector (or element) from source
741       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
742       if (extractedSource.getType().isa<VectorType>()) {
743         // 2. If we have a vector, extract the matching subvector from the destination.
744         // Otherwise we are at the element level and there is no need to recurse.
745         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
746         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
747         // smaller rank.
748         extractedSource = rewriter.create<InsertStridedSliceOp>(
749             loc, extractedSource, extractedDest,
750             getI64SubArray(op.offsets(), /* dropFront=*/1),
751             getI64SubArray(op.strides(), /* dropFront=*/1));
752       }
753       // 4. Insert the extractedSource into the res vector.
754       res = insertOne(rewriter, loc, extractedSource, res, off);
755     }
756 
757     rewriter.replaceOp(op, res);
758     return success();
759   }
760   /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
761   /// bounded as the rank is strictly decreasing.
762   bool hasBoundedRewriteRecursion() const final { return true; }
763 };
764 
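/// Conversion pattern for a vector.type_cast. Rebuilds a memref descriptor
/// whose element type is the target vector type: the allocated and aligned
/// pointers are bitcast to the new element type, the offset is set to 0, and
/// sizes/strides are set to their statically known values. Only statically
/// shaped, contiguous memrefs are supported.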
765 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
766 public:
767   explicit VectorTypeCastOpConversion(MLIRContext *context,
768                                       LLVMTypeConverter &typeConverter)
769       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
770                              typeConverter) {}
771 
772   LogicalResult
773   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
774                   ConversionPatternRewriter &rewriter) const override {
775     auto loc = op->getLoc();
776     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
777     MemRefType sourceMemRefType =
778         castOp.getOperand().getType().cast<MemRefType>();
779     MemRefType targetMemRefType =
780         castOp.getResult().getType().cast<MemRefType>();
781 
782     // Only static shape casts supported atm.
783     if (!sourceMemRefType.hasStaticShape() ||
784         !targetMemRefType.hasStaticShape())
785       return failure();
786 
787     auto llvmSourceDescriptorTy =
788         operands[0].getType().dyn_cast<LLVM::LLVMType>();
789     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
790       return failure();
791     MemRefDescriptor sourceMemRef(operands[0]);
792 
793     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
794                                       .dyn_cast_or_null<LLVM::LLVMType>();
795     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
796       return failure();
797 
798     int64_t offset;
799     SmallVector<int64_t, 4> strides;
800     auto successStrides =
801         getStridesAndOffset(sourceMemRefType, strides, offset);
802     bool isContiguous = (strides.back() == 1);
803     if (isContiguous) {
804       auto sizes = sourceMemRefType.getShape();
805       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
806         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
807           isContiguous = false;
808           break;
809         }
810       }
811     }
812     // Only contiguous source memrefs supported atm.
813     if (failed(successStrides) || !isContiguous)
814       return failure();
815 
816     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
817 
818     // Create descriptor.
819     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
820     Type llvmTargetElementTy = desc.getElementType();
821     // Set allocated ptr.
822     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
823     allocated =
824         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
825     desc.setAllocatedPtr(rewriter, loc, allocated);
826     // Set aligned ptr.
827     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
828     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
829     desc.setAlignedPtr(rewriter, loc, ptr);
830     // Fill offset 0.
831     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
832     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
833     desc.setOffset(rewriter, loc, zero);
834 
835     // Fill size and stride descriptors in memref.
836     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
837       int64_t index = indexedSize.index();
838       auto sizeAttr =
839           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
840       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
841       desc.setSize(rewriter, loc, index, size);
842       auto strideAttr =
843           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
844       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
845       desc.setStride(rewriter, loc, index, stride);
846     }
847 
848     rewriter.replaceOp(op, {desc});
849     return success();
850   }
851 };
852 
853 /// Conversion pattern that converts a 1-D vector transfer read/write op into a
854 /// sequence of:
855 /// 1. Bitcast or addrspacecast to vector form.
856 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
857 /// 3. Create a mask where offsetVector is compared against memref upper bound.
858 /// 4. Rewrite op as a masked read or write.
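///
/// A rough sketch for a masked 1-D read (types, indices, and exact syntax are
/// illustrative only):
/// ```
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes a bitcast of the element pointer to a vector pointer, an offset
/// vector [ %i + 0 .. %i + 7 ] compared against dim %A to form the mask, and an
/// llvm.intr.masked.load of the vector with the splatted padding value as the
/// pass-through operand.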
859 template <typename ConcreteOp>
860 class VectorTransferConversion : public ConvertToLLVMPattern {
861 public:
862   explicit VectorTransferConversion(MLIRContext *context,
863                                     LLVMTypeConverter &typeConv)
864       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
865                              typeConv) {}
866 
867   LogicalResult
868   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
869                   ConversionPatternRewriter &rewriter) const override {
870     auto xferOp = cast<ConcreteOp>(op);
871     auto adaptor = getTransferOpAdapter(xferOp, operands);
872 
873     if (xferOp.getVectorType().getRank() > 1 ||
874         llvm::size(xferOp.indices()) == 0)
875       return failure();
876     if (xferOp.permutation_map() !=
877         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
878                                        xferOp.getVectorType().getRank(),
879                                        op->getContext()))
880       return failure();
881 
882     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
883 
884     Location loc = op->getLoc();
885     Type i64Type = rewriter.getIntegerType(64);
886     MemRefType memRefType = xferOp.getMemRefType();
887 
888     // 1. Get the source/dst address as an LLVM vector pointer.
889     //    The vector pointer is always in address space 0, therefore an
890     //    addrspacecast is used when the source/dst memrefs are not in
891     //    address space 0.
892     // TODO: support alignment when possible.
893     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
894                                adaptor.indices(), rewriter, getModule());
895     auto vecTy =
896         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
897     Value vectorDataPtr;
898     if (memRefType.getMemorySpace() == 0)
899       vectorDataPtr =
900           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
901     else
902       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
903           loc, vecTy.getPointerTo(), dataPtr);
904 
905     if (!xferOp.isMaskedDim(0))
906       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
907                                               xferOp, operands, vectorDataPtr);
908 
909     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
910     unsigned vecWidth = vecTy.getVectorNumElements();
911     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
912     SmallVector<int64_t, 8> indices;
913     indices.reserve(vecWidth);
914     for (unsigned i = 0; i < vecWidth; ++i)
915       indices.push_back(i);
916     Value linearIndices = rewriter.create<ConstantOp>(
917         loc, vectorCmpType,
918         DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
919     linearIndices = rewriter.create<LLVM::DialectCastOp>(
920         loc, toLLVMTy(vectorCmpType), linearIndices);
921 
922     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
923     // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
924     // `k` dimensions here.
925     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
926     Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
927     offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
928     Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
929     Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
930 
931     // 4. Let dim the memref dimension, compute the vector comparison mask:
932     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
933     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
934     dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
935     dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
936     Value mask =
937         rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
938     mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
939                                                 mask);
940 
941     // 5. Rewrite as a masked read / write.
942     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
943                                        operands, vectorDataPtr, mask);
944   }
945 };
946 
947 class VectorPrintOpConversion : public ConvertToLLVMPattern {
948 public:
949   explicit VectorPrintOpConversion(MLIRContext *context,
950                                    LLVMTypeConverter &typeConverter)
951       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
952                              typeConverter) {}
953 
954   // Proof-of-concept lowering implementation that relies on a small
955   // runtime support library, which only needs to provide a few
956   // printing methods (single value for all data types, opening/closing
957   // bracket, comma, newline). The lowering fully unrolls a vector
958   // in terms of these elementary printing operations. The advantage
959   // of this approach is that the library can remain unaware of all
960   // low-level implementation details of vectors while still supporting
961   // output of any shaped and dimensioned vector. Due to full unrolling,
962   // this approach is less suited for very large vectors though.
963   //
964   // TODO(ajcbik): rely solely on libc in future? something else?
965   //
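  //
  // For example, printing a vector<2x2xf32> conceptually lowers to the call
  // sequence (using the runtime functions declared further below):
  //   print_open, print_open, print_f32, print_comma, print_f32, print_close,
  //   print_comma, print_open, print_f32, print_comma, print_f32, print_close,
  //   print_close, print_newline
  // producing output along the lines of "( ( 0, 1 ), ( 2, 3 ) )".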
966   LogicalResult
967   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
968                   ConversionPatternRewriter &rewriter) const override {
969     auto printOp = cast<vector::PrintOp>(op);
970     auto adaptor = vector::PrintOpAdaptor(operands);
971     Type printType = printOp.getPrintType();
972 
973     if (typeConverter.convertType(printType) == nullptr)
974       return failure();
975 
976     // Make sure element type has runtime support (currently i1/i32/i64/f32/f64).
977     VectorType vectorType = printType.dyn_cast<VectorType>();
978     Type eltType = vectorType ? vectorType.getElementType() : printType;
979     int64_t rank = vectorType ? vectorType.getRank() : 0;
980     Operation *printer;
981     if (eltType.isSignlessInteger(1))
982       printer = getPrintI1(op);
983     else if (eltType.isSignlessInteger(32))
984       printer = getPrintI32(op);
985     else if (eltType.isSignlessInteger(64))
986       printer = getPrintI64(op);
987     else if (eltType.isF32())
988       printer = getPrintFloat(op);
989     else if (eltType.isF64())
990       printer = getPrintDouble(op);
991     else
992       return failure();
993 
994     // Unroll vector into elementary print calls.
995     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
996     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
997     rewriter.eraseOp(op);
998     return success();
999   }
1000 
1001 private:
1002   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1003                  Value value, VectorType vectorType, Operation *printer,
1004                  int64_t rank) const {
1005     Location loc = op->getLoc();
1006     if (rank == 0) {
1007       emitCall(rewriter, loc, printer, value);
1008       return;
1009     }
1010 
1011     emitCall(rewriter, loc, getPrintOpen(op));
1012     Operation *printComma = getPrintComma(op);
1013     int64_t dim = vectorType.getDimSize(0);
1014     for (int64_t d = 0; d < dim; ++d) {
1015       auto reducedType =
1016           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1017       auto llvmType = typeConverter.convertType(
1018           rank > 1 ? reducedType : vectorType.getElementType());
1019       Value nestedVal =
1020           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1021       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
1022       if (d != dim - 1)
1023         emitCall(rewriter, loc, printComma);
1024     }
1025     emitCall(rewriter, loc, getPrintClose(op));
1026   }
1027 
1028   // Helper to emit a call.
1029   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1030                        Operation *ref, ValueRange params = ValueRange()) {
1031     rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
1032                                   rewriter.getSymbolRefAttr(ref), params);
1033   }
1034 
1035   // Helper for printer method declaration (first hit) and lookup.
1036   static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
1037                              StringRef name, ArrayRef<LLVM::LLVMType> params) {
1038     auto module = op->getParentOfType<ModuleOp>();
1039     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1040     if (func)
1041       return func;
1042     OpBuilder moduleBuilder(module.getBodyRegion());
1043     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1044         op->getLoc(), name,
1045         LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
1046                                       params, /*isVarArg=*/false));
1047   }
1048 
1049   // Helpers for method names.
1050   Operation *getPrintI1(Operation *op) const {
1051     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1052     return getPrint(op, dialect, "print_i1",
1053                     LLVM::LLVMType::getInt1Ty(dialect));
1054   }
1055   Operation *getPrintI32(Operation *op) const {
1056     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1057     return getPrint(op, dialect, "print_i32",
1058                     LLVM::LLVMType::getInt32Ty(dialect));
1059   }
1060   Operation *getPrintI64(Operation *op) const {
1061     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1062     return getPrint(op, dialect, "print_i64",
1063                     LLVM::LLVMType::getInt64Ty(dialect));
1064   }
1065   Operation *getPrintFloat(Operation *op) const {
1066     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1067     return getPrint(op, dialect, "print_f32",
1068                     LLVM::LLVMType::getFloatTy(dialect));
1069   }
1070   Operation *getPrintDouble(Operation *op) const {
1071     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1072     return getPrint(op, dialect, "print_f64",
1073                     LLVM::LLVMType::getDoubleTy(dialect));
1074   }
1075   Operation *getPrintOpen(Operation *op) const {
1076     return getPrint(op, typeConverter.getDialect(), "print_open", {});
1077   }
1078   Operation *getPrintClose(Operation *op) const {
1079     return getPrint(op, typeConverter.getDialect(), "print_close", {});
1080   }
1081   Operation *getPrintComma(Operation *op) const {
1082     return getPrint(op, typeConverter.getDialect(), "print_comma", {});
1083   }
1084   Operation *getPrintNewline(Operation *op) const {
1085     return getPrint(op, typeConverter.getDialect(), "print_newline", {});
1086   }
1087 };
1088 
1089 /// Progressive lowering of ExtractStridedSliceOp to either:
1090 ///   1. extractelement + insertelement for the 1-D case
1091 ///   2. extract + optional extract_strided_slice + insert for the n-D case.
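///
/// A rough sketch for the 1-D case (exact assembly syntax abridged):
/// ```
///   %1 = vector.extract_strided_slice %0
///        {offsets = [2], sizes = [2], strides = [1]}
///      : vector<4xf32> to vector<2xf32>
/// ```
/// becomes a zero splat of vector<2xf32> into which elements 2 and 3 of %0 are
/// inserted one by one; for higher ranks each outermost position is extracted,
/// recursively sliced, and inserted into the result.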
1092 class VectorStridedSliceOpConversion
1093     : public OpRewritePattern<ExtractStridedSliceOp> {
1094 public:
1095   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1096 
1097   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1098                                 PatternRewriter &rewriter) const override {
1099     auto dstType = op.getResult().getType().cast<VectorType>();
1100 
1101     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1102 
1103     int64_t offset =
1104         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1105     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1106     int64_t stride =
1107         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1108 
1109     auto loc = op.getLoc();
1110     auto elemType = dstType.getElementType();
1111     assert(elemType.isSignlessIntOrIndexOrFloat());
1112     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1113                                              rewriter.getZeroAttr(elemType));
1114     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1115     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1116          off += stride, ++idx) {
1117       Value extracted = extractOne(rewriter, loc, op.vector(), off);
1118       if (op.offsets().getValue().size() > 1) {
1119         extracted = rewriter.create<ExtractStridedSliceOp>(
1120             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
1121             getI64SubArray(op.sizes(), /* dropFront=*/1),
1122             getI64SubArray(op.strides(), /* dropFront=*/1));
1123       }
1124       res = insertOne(rewriter, loc, extracted, res, idx);
1125     }
1126     rewriter.replaceOp(op, {res});
1127     return success();
1128   }
1129   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1130   /// bounded as the rank is strictly decreasing.
1131   bool hasBoundedRewriteRecursion() const final { return true; }
1132 };
1133 
1134 } // namespace
1135 
1136 /// Populate the given list with patterns that convert from Vector to LLVM.
1137 void mlir::populateVectorToLLVMConversionPatterns(
1138     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1139   MLIRContext *ctx = converter.getDialect()->getContext();
1140   // clang-format off
1141   patterns.insert<VectorFMAOpNDRewritePattern,
1142                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1143                   VectorInsertStridedSliceOpSameRankRewritePattern,
1144                   VectorStridedSliceOpConversion>(ctx);
1145   patterns
1146       .insert<VectorReductionOpConversion,
1147               VectorShuffleOpConversion,
1148               VectorExtractElementOpConversion,
1149               VectorExtractOpConversion,
1150               VectorFMAOp1DConversion,
1151               VectorInsertElementOpConversion,
1152               VectorInsertOpConversion,
1153               VectorPrintOpConversion,
1154               VectorTransferConversion<TransferReadOp>,
1155               VectorTransferConversion<TransferWriteOp>,
1156               VectorTypeCastOpConversion>(ctx, converter);
1157   // clang-format on
1158 }
1159 
1160 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1161     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1162   MLIRContext *ctx = converter.getDialect()->getContext();
1163   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1164   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1165 }
1166 
1167 namespace {
1168 struct LowerVectorToLLVMPass
1169     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
1170   void runOnOperation() override;
1171 };
1172 } // namespace
1173 
1174 void LowerVectorToLLVMPass::runOnOperation() {
1175   // Perform progressive lowering of operations on slices and
1176   // all contraction operations. Also applies folding and DCE.
1177   {
1178     OwningRewritePatternList patterns;
1179     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1180     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1181     populateVectorContractLoweringPatterns(patterns, &getContext());
1182     applyPatternsAndFoldGreedily(getOperation(), patterns);
1183   }
1184 
1185   // Convert to the LLVM IR dialect.
1186   LLVMTypeConverter converter(&getContext());
1187   OwningRewritePatternList patterns;
1188   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1189   populateVectorToLLVMConversionPatterns(converter, patterns);
1190   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1191   populateStdToLLVMConversionPatterns(converter, patterns);
1192 
1193   LLVMConversionTarget target(getContext());
1194   target.addDynamicallyLegalOp<FuncOp>(
1195       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1196   if (failed(applyPartialConversion(getOperation(), target, patterns,
1197                                     &converter))) {
1198     signalPassFailure();
1199   }
1200 }
1201 
1202 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
1203   return std::make_unique<LowerVectorToLLVMPass>();
1204 }
1205