1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 template <typename T>
38 static LLVM::LLVMType getPtrToElementType(T containerType,
39                                           LLVMTypeConverter &typeConverter) {
40   return typeConverter.convertType(containerType.getElementType())
41       .template cast<LLVM::LLVMType>()
42       .getPointerTo();
43 }
44 
45 // Helper to reduce vector type by one rank at front.
46 static VectorType reducedVectorTypeFront(VectorType tp) {
47   assert((tp.getRank() > 1) && "unlowerable vector type");
48   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
49 }
50 
// Helper to reduce vector type to its innermost (back) dimension only.
52 static VectorType reducedVectorTypeBack(VectorType tp) {
53   assert((tp.getRank() > 1) && "unlowerable vector type");
54   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
55 }
56 
// Helper that picks the proper sequence for inserting: llvm.insertelement for
// a 1-D vector, llvm.insertvalue otherwise.
58 static Value insertOne(ConversionPatternRewriter &rewriter,
59                        LLVMTypeConverter &typeConverter, Location loc,
60                        Value val1, Value val2, Type llvmType, int64_t rank,
61                        int64_t pos) {
62   if (rank == 1) {
63     auto idxType = rewriter.getIndexType();
64     auto constant = rewriter.create<LLVM::ConstantOp>(
65         loc, typeConverter.convertType(idxType),
66         rewriter.getIntegerAttr(idxType, pos));
67     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
68                                                   constant);
69   }
70   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
71                                               rewriter.getI64ArrayAttr(pos));
72 }
73 
// Helper that picks the proper sequence for inserting at the Vector dialect
// level: vector.insert for rank > 1, vector.insertelement otherwise.
75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
76                        Value into, int64_t offset) {
77   auto vectorType = into.getType().cast<VectorType>();
78   if (vectorType.getRank() > 1)
79     return rewriter.create<InsertOp>(loc, from, into, offset);
80   return rewriter.create<vector::InsertElementOp>(
81       loc, vectorType, from, into,
82       rewriter.create<ConstantIndexOp>(loc, offset));
83 }
84 
// Helper that picks the proper sequence for extracting: llvm.extractelement
// for a 1-D vector, llvm.extractvalue otherwise.
86 static Value extractOne(ConversionPatternRewriter &rewriter,
87                         LLVMTypeConverter &typeConverter, Location loc,
88                         Value val, Type llvmType, int64_t rank, int64_t pos) {
89   if (rank == 1) {
90     auto idxType = rewriter.getIndexType();
91     auto constant = rewriter.create<LLVM::ConstantOp>(
92         loc, typeConverter.convertType(idxType),
93         rewriter.getIntegerAttr(idxType, pos));
94     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
95                                                    constant);
96   }
97   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
98                                                rewriter.getI64ArrayAttr(pos));
99 }
100 
// Helper that picks the proper sequence for extracting at the Vector dialect
// level: vector.extract for rank > 1, vector.extractelement otherwise.
102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
103                         int64_t offset) {
104   auto vectorType = vector.getType().cast<VectorType>();
105   if (vectorType.getRank() > 1)
106     return rewriter.create<ExtractOp>(loc, vector, offset);
107   return rewriter.create<vector::ExtractElementOp>(
108       loc, vectorType.getElementType(), vector,
109       rewriter.create<ConstantIndexOp>(loc, offset));
110 }
111 
112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing.
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
115                                               unsigned dropFront = 0,
116                                               unsigned dropBack = 0) {
117   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
118   auto range = arrayAttr.getAsRange<IntegerAttr>();
119   SmallVector<int64_t, 4> res;
120   res.reserve(arrayAttr.size() - dropFront - dropBack);
121   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
122        it != eit; ++it)
123     res.push_back((*it).getValue().getSExtValue());
124   return res;
125 }
126 
127 namespace {
128 
129 /// Conversion pattern for a vector.matrix_multiply.
130 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
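///
/// An illustrative sketch (shapes and attribute values are made up for the
/// example, not taken from a test):
/// ```
///  %c = vector.matrix_multiply %a, %b
///         {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
///       : (vector<64xf32>, vector<48xf32>) -> vector<12xf32>
/// ```
/// becomes an llvm.intr.matrix.multiply carrying the same shape attributes.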
131 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
132 public:
133   explicit VectorMatmulOpConversion(MLIRContext *context,
134                                     LLVMTypeConverter &typeConverter)
135       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
136                              typeConverter) {}
137 
138   LogicalResult
139   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
140                   ConversionPatternRewriter &rewriter) const override {
141     auto matmulOp = cast<vector::MatmulOp>(op);
142     auto adaptor = vector::MatmulOpOperandAdaptor(operands);
143     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
144         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
145         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
146         matmulOp.rhs_columns());
147     return success();
148   }
149 };
150 
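/// Conversion pattern for a vector.reduction. As a rough sketch (the IR below
/// is illustrative, not taken from a test), an integer reduction such as
/// ```
///  %0 = vector.reduction "add", %v : vector<16xi32> into i32
/// ```
/// maps onto the matching LLVM reduction intrinsic (here
/// llvm.intr.experimental.vector.reduce.add), while floating-point add/mul use
/// the v2 intrinsics with an accumulator that defaults to 0.0 resp. 1.0.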
151 class VectorReductionOpConversion : public ConvertToLLVMPattern {
152 public:
153   explicit VectorReductionOpConversion(MLIRContext *context,
154                                        LLVMTypeConverter &typeConverter)
155       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
156                              typeConverter) {}
157 
158   LogicalResult
159   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
160                   ConversionPatternRewriter &rewriter) const override {
161     auto reductionOp = cast<vector::ReductionOp>(op);
162     auto kind = reductionOp.kind();
163     Type eltType = reductionOp.dest().getType();
164     Type llvmType = typeConverter.convertType(eltType);
165     if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
166       // Integer reductions: add/mul/min/max/and/or/xor.
167       if (kind == "add")
168         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
169             op, llvmType, operands[0]);
170       else if (kind == "mul")
171         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
172             op, llvmType, operands[0]);
173       else if (kind == "min")
174         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
175             op, llvmType, operands[0]);
176       else if (kind == "max")
177         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
178             op, llvmType, operands[0]);
179       else if (kind == "and")
180         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
181             op, llvmType, operands[0]);
182       else if (kind == "or")
183         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
184             op, llvmType, operands[0]);
185       else if (kind == "xor")
186         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
187             op, llvmType, operands[0]);
188       else
189         return failure();
190       return success();
191 
192     } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max.
194       if (kind == "add") {
195         // Optional accumulator (or zero).
196         Value acc = operands.size() > 1 ? operands[1]
197                                         : rewriter.create<LLVM::ConstantOp>(
198                                               op->getLoc(), llvmType,
199                                               rewriter.getZeroAttr(eltType));
200         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
201             op, llvmType, acc, operands[0]);
202       } else if (kind == "mul") {
203         // Optional accumulator (or one).
204         Value acc = operands.size() > 1
205                         ? operands[1]
206                         : rewriter.create<LLVM::ConstantOp>(
207                               op->getLoc(), llvmType,
208                               rewriter.getFloatAttr(eltType, 1.0));
209         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
210             op, llvmType, acc, operands[0]);
211       } else if (kind == "min")
212         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
213             op, llvmType, operands[0]);
214       else if (kind == "max")
215         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
216             op, llvmType, operands[0]);
217       else
218         return failure();
219       return success();
220     }
221     return failure();
222   }
223 };
224 
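/// Conversion pattern for vector.shuffle. A sketch of the input (types and
/// mask are illustrative):
/// ```
///  %0 = vector.shuffle %a, %b [0, 2, 3] : vector<2xf32>, vector<2xf32>
/// ```
/// The rank-1 case with identical operand types maps directly onto
/// llvm.shufflevector; every other case is lowered to a sequence of
/// extract/insert operations starting from an llvm.mlir.undef value.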
225 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
226 public:
227   explicit VectorShuffleOpConversion(MLIRContext *context,
228                                      LLVMTypeConverter &typeConverter)
229       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
230                              typeConverter) {}
231 
232   LogicalResult
233   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
234                   ConversionPatternRewriter &rewriter) const override {
235     auto loc = op->getLoc();
236     auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
237     auto shuffleOp = cast<vector::ShuffleOp>(op);
238     auto v1Type = shuffleOp.getV1VectorType();
239     auto v2Type = shuffleOp.getV2VectorType();
240     auto vectorType = shuffleOp.getVectorType();
241     Type llvmType = typeConverter.convertType(vectorType);
242     auto maskArrayAttr = shuffleOp.mask();
243 
244     // Bail if result type cannot be lowered.
245     if (!llvmType)
246       return failure();
247 
248     // Get rank and dimension sizes.
249     int64_t rank = vectorType.getRank();
250     assert(v1Type.getRank() == rank);
251     assert(v2Type.getRank() == rank);
252     int64_t v1Dim = v1Type.getDimSize(0);
253 
254     // For rank 1, where both operands have *exactly* the same vector type,
255     // there is direct shuffle support in LLVM. Use it!
256     if (rank == 1 && v1Type == v2Type) {
257       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
258           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
259       rewriter.replaceOp(op, shuffle);
260       return success();
261     }
262 
    // For all other cases, extract each element selected by the mask and
    // insert it into the result vector one position at a time.
264     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
265     int64_t insPos = 0;
266     for (auto en : llvm::enumerate(maskArrayAttr)) {
267       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
268       Value value = adaptor.v1();
269       if (extPos >= v1Dim) {
270         extPos -= v1Dim;
271         value = adaptor.v2();
272       }
273       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
274                                  rank, extPos);
275       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
276                          llvmType, rank, insPos++);
277     }
278     rewriter.replaceOp(op, insert);
279     return success();
280   }
281 };
282 
283 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
284 public:
285   explicit VectorExtractElementOpConversion(MLIRContext *context,
286                                             LLVMTypeConverter &typeConverter)
287       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
288                              context, typeConverter) {}
289 
290   LogicalResult
291   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
292                   ConversionPatternRewriter &rewriter) const override {
293     auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
294     auto extractEltOp = cast<vector::ExtractElementOp>(op);
295     auto vectorType = extractEltOp.getVectorType();
296     auto llvmType = typeConverter.convertType(vectorType.getElementType());
297 
298     // Bail if result type cannot be lowered.
299     if (!llvmType)
300       return failure();
301 
302     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
303         op, llvmType, adaptor.vector(), adaptor.position());
304     return success();
305   }
306 };
307 
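/// Conversion pattern for vector.extract. A sketch (positions and types are
/// illustrative):
/// ```
///  %1 = vector.extract %0[3] : vector<4x8xf32>      // yields vector<8xf32>
///  %2 = vector.extract %0[3, 5] : vector<4x8xf32>   // yields f32
/// ```
/// Extracting a sub-vector needs a single llvm.extractvalue; extracting a
/// scalar additionally needs an llvm.extractelement on the innermost 1-D
/// vector.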
308 class VectorExtractOpConversion : public ConvertToLLVMPattern {
309 public:
310   explicit VectorExtractOpConversion(MLIRContext *context,
311                                      LLVMTypeConverter &typeConverter)
312       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
313                              typeConverter) {}
314 
315   LogicalResult
316   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
317                   ConversionPatternRewriter &rewriter) const override {
318     auto loc = op->getLoc();
319     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
320     auto extractOp = cast<vector::ExtractOp>(op);
321     auto vectorType = extractOp.getVectorType();
322     auto resultType = extractOp.getResult().getType();
323     auto llvmResultType = typeConverter.convertType(resultType);
324     auto positionArrayAttr = extractOp.position();
325 
326     // Bail if result type cannot be lowered.
327     if (!llvmResultType)
328       return failure();
329 
330     // One-shot extraction of vector from array (only requires extractvalue).
331     if (resultType.isa<VectorType>()) {
332       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
333           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
334       rewriter.replaceOp(op, extracted);
335       return success();
336     }
337 
338     // Potential extraction of 1-D vector from array.
339     auto *context = op->getContext();
340     Value extracted = adaptor.vector();
341     auto positionAttrs = positionArrayAttr.getValue();
342     if (positionAttrs.size() > 1) {
343       auto oneDVectorType = reducedVectorTypeBack(vectorType);
344       auto nMinusOnePositionAttrs =
345           ArrayAttr::get(positionAttrs.drop_back(), context);
346       extracted = rewriter.create<LLVM::ExtractValueOp>(
347           loc, typeConverter.convertType(oneDVectorType), extracted,
348           nMinusOnePositionAttrs);
349     }
350 
    // Remaining extraction of an element from the 1-D LLVM vector.
352     auto position = positionAttrs.back().cast<IntegerAttr>();
353     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
354     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
355     extracted =
356         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
357     rewriter.replaceOp(op, extracted);
358 
359     return success();
360   }
361 };
362 
/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
/// It does not match vectors of rank >= 2.
366 ///
367 /// Example:
368 /// ```
369 ///  vector.fma %a, %a, %a : vector<8xf32>
370 /// ```
371 /// is converted to:
372 /// ```
373 ///  llvm.intr.fma %va, %va, %va:
374 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
375 ///    -> !llvm<"<8 x float>">
376 /// ```
377 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
378 public:
379   explicit VectorFMAOp1DConversion(MLIRContext *context,
380                                    LLVMTypeConverter &typeConverter)
381       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
382                              typeConverter) {}
383 
384   LogicalResult
385   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386                   ConversionPatternRewriter &rewriter) const override {
387     auto adaptor = vector::FMAOpOperandAdaptor(operands);
388     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
389     VectorType vType = fmaOp.getVectorType();
390     if (vType.getRank() != 1)
391       return failure();
392     rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
393                                              adaptor.acc());
394     return success();
395   }
396 };
397 
398 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
399 public:
400   explicit VectorInsertElementOpConversion(MLIRContext *context,
401                                            LLVMTypeConverter &typeConverter)
402       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
403                              context, typeConverter) {}
404 
405   LogicalResult
406   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
407                   ConversionPatternRewriter &rewriter) const override {
408     auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
409     auto insertEltOp = cast<vector::InsertElementOp>(op);
410     auto vectorType = insertEltOp.getDestVectorType();
411     auto llvmType = typeConverter.convertType(vectorType);
412 
413     // Bail if result type cannot be lowered.
414     if (!llvmType)
415       return failure();
416 
417     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
418         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
419     return success();
420   }
421 };
422 
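/// Conversion pattern for vector.insert. A sketch (positions and types are
/// illustrative):
/// ```
///  %1 = vector.insert %s, %d[3] : vector<8xf32> into vector<4x8xf32>
/// ```
/// Inserting a sub-vector needs a single llvm.insertvalue; inserting a scalar
/// additionally needs llvm.extractvalue + llvm.insertelement + llvm.insertvalue
/// on the innermost 1-D vector.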
423 class VectorInsertOpConversion : public ConvertToLLVMPattern {
424 public:
425   explicit VectorInsertOpConversion(MLIRContext *context,
426                                     LLVMTypeConverter &typeConverter)
427       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
428                              typeConverter) {}
429 
430   LogicalResult
431   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
432                   ConversionPatternRewriter &rewriter) const override {
433     auto loc = op->getLoc();
434     auto adaptor = vector::InsertOpOperandAdaptor(operands);
435     auto insertOp = cast<vector::InsertOp>(op);
436     auto sourceType = insertOp.getSourceType();
437     auto destVectorType = insertOp.getDestVectorType();
438     auto llvmResultType = typeConverter.convertType(destVectorType);
439     auto positionArrayAttr = insertOp.position();
440 
441     // Bail if result type cannot be lowered.
442     if (!llvmResultType)
443       return failure();
444 
445     // One-shot insertion of a vector into an array (only requires insertvalue).
446     if (sourceType.isa<VectorType>()) {
447       Value inserted = rewriter.create<LLVM::InsertValueOp>(
448           loc, llvmResultType, adaptor.dest(), adaptor.source(),
449           positionArrayAttr);
450       rewriter.replaceOp(op, inserted);
451       return success();
452     }
453 
454     // Potential extraction of 1-D vector from array.
455     auto *context = op->getContext();
456     Value extracted = adaptor.dest();
457     auto positionAttrs = positionArrayAttr.getValue();
458     auto position = positionAttrs.back().cast<IntegerAttr>();
459     auto oneDVectorType = destVectorType;
460     if (positionAttrs.size() > 1) {
461       oneDVectorType = reducedVectorTypeBack(destVectorType);
462       auto nMinusOnePositionAttrs =
463           ArrayAttr::get(positionAttrs.drop_back(), context);
464       extracted = rewriter.create<LLVM::ExtractValueOp>(
465           loc, typeConverter.convertType(oneDVectorType), extracted,
466           nMinusOnePositionAttrs);
467     }
468 
469     // Insertion of an element into a 1-D LLVM vector.
470     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
471     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
472     Value inserted = rewriter.create<LLVM::InsertElementOp>(
473         loc, typeConverter.convertType(oneDVectorType), extracted,
474         adaptor.source(), constant);
475 
476     // Potential insertion of resulting 1-D vector into array.
477     if (positionAttrs.size() > 1) {
478       auto nMinusOnePositionAttrs =
479           ArrayAttr::get(positionAttrs.drop_back(), context);
480       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
481                                                       adaptor.dest(), inserted,
482                                                       nMinusOnePositionAttrs);
483     }
484 
485     rewriter.replaceOp(op, inserted);
486     return success();
487   }
488 };
489 
490 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
491 ///
492 /// Example:
493 /// ```
494 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
495 /// ```
496 /// is rewritten into:
497 /// ```
498 ///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
509 ///  // %r3 holds the final value.
510 /// ```
511 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
512 public:
513   using OpRewritePattern<FMAOp>::OpRewritePattern;
514 
515   LogicalResult matchAndRewrite(FMAOp op,
516                                 PatternRewriter &rewriter) const override {
517     auto vType = op.getVectorType();
518     if (vType.getRank() < 2)
519       return failure();
520 
521     auto loc = op.getLoc();
522     auto elemType = vType.getElementType();
523     Value zero = rewriter.create<ConstantOp>(loc, elemType,
524                                              rewriter.getZeroAttr(elemType));
525     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
526     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
527       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
528       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
529       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
530       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
531       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
532     }
533     rewriter.replaceOp(op, desc);
534     return success();
535   }
536 };
537 
538 // When ranks are different, InsertStridedSlice needs to extract a properly
539 // ranked vector from the destination vector into which to insert. This pattern
540 // only takes care of this part and forwards the rest of the conversion to
541 // another pattern that converts InsertStridedSlice for operands of the same
542 // rank.
543 //
544 // RewritePattern for InsertStridedSliceOp where source and destination vectors
545 // have different ranks. In this case:
546 //   1. the proper subvector is extracted from the destination vector
547 //   2. a new InsertStridedSlice op is created to insert the source in the
548 //   destination subvector
549 //   3. the destination subvector is inserted back in the proper place
550 //   4. the op is replaced by the result of step 3.
551 // The new InsertStridedSlice from step 2. will be picked up by a
552 // `VectorInsertStridedSliceOpSameRankRewritePattern`.
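//
// As an illustrative sketch (types and offsets are made up for the example):
//
//   %2 = vector.insert_strided_slice %0, %1
//          {offsets = [0, 0, 2], strides = [1, 1]}
//        : vector<2x4xf32> into vector<16x4x8xf32>
//
// first extracts the vector<4x8xf32> at offset [0] from %1, creates a new
// insert_strided_slice of matching rank into it with offsets [0, 2], and then
// inserts the result back into %1 at [0].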
553 class VectorInsertStridedSliceOpDifferentRankRewritePattern
554     : public OpRewritePattern<InsertStridedSliceOp> {
555 public:
556   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
557 
558   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
559                                 PatternRewriter &rewriter) const override {
560     auto srcType = op.getSourceVectorType();
561     auto dstType = op.getDestVectorType();
562 
563     if (op.offsets().getValue().empty())
564       return failure();
565 
566     auto loc = op.getLoc();
567     int64_t rankDiff = dstType.getRank() - srcType.getRank();
568     assert(rankDiff >= 0);
569     if (rankDiff == 0)
570       return failure();
571 
572     int64_t rankRest = dstType.getRank() - rankDiff;
573     // Extract / insert the subvector of matching rank and InsertStridedSlice
574     // on it.
575     Value extracted =
576         rewriter.create<ExtractOp>(loc, op.dest(),
577                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
579     // A different pattern will kick in for InsertStridedSlice with matching
580     // ranks.
581     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
582         loc, op.source(), extracted,
583         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
584         getI64SubArray(op.strides(), /*dropFront=*/0));
585     rewriter.replaceOpWithNewOp<InsertOp>(
586         op, stridedSliceInnerOp.getResult(), op.dest(),
587         getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
589     return success();
590   }
591 };
592 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is reduced one rank at a time:
//   1. each subvector (or element) of the source along the outermost dimension
//   is extracted, together with the corresponding subvector of the destination
//   2. a new, lower-rank InsertStridedSlice op inserts the source subvector
//   into the destination subvector
//   3. the updated subvector is inserted back into the destination at the
//   strided offset
//   4. the op is replaced by the fully updated destination vector.
// The lower-rank InsertStridedSlice from step 2. is picked up again by this
// same `VectorInsertStridedSliceOpSameRankRewritePattern`.
602 class VectorInsertStridedSliceOpSameRankRewritePattern
603     : public OpRewritePattern<InsertStridedSliceOp> {
604 public:
605   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
606 
607   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
608                                 PatternRewriter &rewriter) const override {
609     auto srcType = op.getSourceVectorType();
610     auto dstType = op.getDestVectorType();
611 
612     if (op.offsets().getValue().empty())
613       return failure();
614 
615     int64_t rankDiff = dstType.getRank() - srcType.getRank();
616     assert(rankDiff >= 0);
617     if (rankDiff != 0)
618       return failure();
619 
620     if (srcType == dstType) {
621       rewriter.replaceOp(op, op.source());
622       return success();
623     }
624 
625     int64_t offset =
626         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
627     int64_t size = srcType.getShape().front();
628     int64_t stride =
629         op.strides().getValue().front().cast<IntegerAttr>().getInt();
630 
631     auto loc = op.getLoc();
632     Value res = op.dest();
633     // For each slice of the source vector along the most major dimension.
634     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
635          off += stride, ++idx) {
636       // 1. extract the proper subvector (or element) from source
637       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
638       if (extractedSource.getType().isa<VectorType>()) {
639         // 2. If we have a vector, extract the proper subvector from destination
640         // Otherwise we are at the element level and no need to recurse.
641         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
642         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
643         // smaller rank.
644         extractedSource = rewriter.create<InsertStridedSliceOp>(
645             loc, extractedSource, extractedDest,
646             getI64SubArray(op.offsets(), /* dropFront=*/1),
647             getI64SubArray(op.strides(), /* dropFront=*/1));
648       }
649       // 4. Insert the extractedSource into the res vector.
650       res = insertOne(rewriter, loc, extractedSource, res, off);
651     }
652 
653     rewriter.replaceOp(op, res);
654     return success();
655   }
656   /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
657   /// bounded as the rank is strictly decreasing.
658   bool hasBoundedRewriteRecursion() const final { return true; }
659 };
660 
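/// Conversion pattern for vector.type_cast, e.g. (illustrative):
/// ```
///  %1 = vector.type_cast %0 : memref<2x4xf32> to memref<vector<2x4xf32>>
/// ```
/// Only statically shaped, contiguous memrefs are supported. The lowering only
/// rewrites the memref descriptor: the allocated/aligned pointers are bitcast
/// to the target element type and the offset is reset to zero.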
661 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
662 public:
663   explicit VectorTypeCastOpConversion(MLIRContext *context,
664                                       LLVMTypeConverter &typeConverter)
665       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
666                              typeConverter) {}
667 
668   LogicalResult
669   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
670                   ConversionPatternRewriter &rewriter) const override {
671     auto loc = op->getLoc();
672     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
673     MemRefType sourceMemRefType =
674         castOp.getOperand().getType().cast<MemRefType>();
675     MemRefType targetMemRefType =
676         castOp.getResult().getType().cast<MemRefType>();
677 
678     // Only static shape casts supported atm.
679     if (!sourceMemRefType.hasStaticShape() ||
680         !targetMemRefType.hasStaticShape())
681       return failure();
682 
683     auto llvmSourceDescriptorTy =
684         operands[0].getType().dyn_cast<LLVM::LLVMType>();
685     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
686       return failure();
687     MemRefDescriptor sourceMemRef(operands[0]);
688 
689     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
690                                       .dyn_cast_or_null<LLVM::LLVMType>();
691     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
692       return failure();
693 
694     int64_t offset;
695     SmallVector<int64_t, 4> strides;
696     auto successStrides =
697         getStridesAndOffset(sourceMemRefType, strides, offset);
698     bool isContiguous = (strides.back() == 1);
699     if (isContiguous) {
700       auto sizes = sourceMemRefType.getShape();
      // Check contiguity of all remaining dimensions, i.e. that each stride
      // equals the next stride times the next size.
      for (int index = 0, e = strides.size() - 1; index < e; ++index) {
702         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
703           isContiguous = false;
704           break;
705         }
706       }
707     }
    // Only contiguous source memrefs supported atm.
709     if (failed(successStrides) || !isContiguous)
710       return failure();
711 
712     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
713 
714     // Create descriptor.
715     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
716     Type llvmTargetElementTy = desc.getElementType();
717     // Set allocated ptr.
718     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
719     allocated =
720         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
721     desc.setAllocatedPtr(rewriter, loc, allocated);
722     // Set aligned ptr.
723     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
724     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
725     desc.setAlignedPtr(rewriter, loc, ptr);
726     // Fill offset 0.
727     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
728     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
729     desc.setOffset(rewriter, loc, zero);
730 
731     // Fill size and stride descriptors in memref.
732     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
733       int64_t index = indexedSize.index();
734       auto sizeAttr =
735           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
736       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
737       desc.setSize(rewriter, loc, index, size);
738       auto strideAttr =
739           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
740       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
741       desc.setStride(rewriter, loc, index, stride);
742     }
743 
744     rewriter.replaceOp(op, {desc});
745     return success();
746   }
747 };
748 
749 LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
750                                       Type type, LLVM::LLVMType &llvmType,
751                                       unsigned &align) {
752   auto convertedType = typeConverter.convertType(type);
753   if (!convertedType)
754     return failure();
755 
756   llvmType = convertedType.template cast<LLVM::LLVMType>();
757   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
758   align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
759   return success();
760 }
761 
762 LogicalResult
763 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
764                                  LLVMTypeConverter &typeConverter, Location loc,
765                                  TransferReadOp xferOp,
766                                  ArrayRef<Value> operands, Value dataPtr) {
767   LLVM::LLVMType vecTy;
768   unsigned align;
769   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
770                                      vecTy, align)))
771     return failure();
772   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
773   return success();
774 }
775 
776 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
777                                           LLVMTypeConverter &typeConverter,
778                                           Location loc, TransferReadOp xferOp,
779                                           ArrayRef<Value> operands,
780                                           Value dataPtr, Value mask) {
781   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
782   VectorType fillType = xferOp.getVectorType();
783   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
784   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
785 
786   LLVM::LLVMType vecTy;
787   unsigned align;
788   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
789                                      vecTy, align)))
790     return failure();
791 
792   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
793       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
794       rewriter.getI32IntegerAttr(align));
795   return success();
796 }
797 
798 LogicalResult
799 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
800                                  LLVMTypeConverter &typeConverter, Location loc,
801                                  TransferWriteOp xferOp,
802                                  ArrayRef<Value> operands, Value dataPtr) {
803   auto adaptor = TransferWriteOpOperandAdaptor(operands);
804   LLVM::LLVMType vecTy;
805   unsigned align;
806   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
807                                      vecTy, align)))
808     return failure();
809   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
810   return success();
811 }
812 
813 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
814                                           LLVMTypeConverter &typeConverter,
815                                           Location loc, TransferWriteOp xferOp,
816                                           ArrayRef<Value> operands,
817                                           Value dataPtr, Value mask) {
818   auto adaptor = TransferWriteOpOperandAdaptor(operands);
819   LLVM::LLVMType vecTy;
820   unsigned align;
821   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
822                                      vecTy, align)))
823     return failure();
824 
825   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
826       xferOp, adaptor.vector(), dataPtr, mask,
827       rewriter.getI32IntegerAttr(align));
828   return success();
829 }
830 
831 static TransferReadOpOperandAdaptor
832 getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
833   return TransferReadOpOperandAdaptor(operands);
834 }
835 
836 static TransferWriteOpOperandAdaptor
837 getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
838   return TransferWriteOpOperandAdaptor(operands);
839 }
840 
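// Returns true if `map` maps its `rank` most-minor dimensions as a (minor)
// identity, e.g. (illustrative) with rank = 1, (d0, d1) -> (d1) passes while
// (d0, d1) -> (d0) does not.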
841 bool isMinorIdentity(AffineMap map, unsigned rank) {
842   if (map.getNumResults() < rank)
843     return false;
844   unsigned startDim = map.getNumDims() - rank;
845   for (unsigned i = 0; i < rank; ++i)
846     if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
847       return false;
848   return true;
849 }
850 
/// Conversion pattern that converts a 1-D vector transfer read/write op into a
852 /// sequence of:
853 /// 1. Bitcast or addrspacecast to vector form.
854 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
855 /// 3. Create a mask where offsetVector is compared against memref upper bound.
856 /// 4. Rewrite op as a masked read or write.
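///
/// A rough 1-D sketch of the masked case (the IR below is illustrative, not
/// taken from a test):
/// ```
///  %f = vector.transfer_read %A[%i], %pad
///         {permutation_map = affine_map<(d0) -> (d0)>}
///       : memref<?xf32>, vector<8xf32>
/// ```
/// becomes a bitcast of the element pointer to an LLVM vector pointer plus an
/// llvm.intr.masked.load whose mask compares [%i .. %i+7] against the memref
/// size, with a splat of %pad as the pass-through value.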
857 template <typename ConcreteOp>
858 class VectorTransferConversion : public ConvertToLLVMPattern {
859 public:
860   explicit VectorTransferConversion(MLIRContext *context,
861                                     LLVMTypeConverter &typeConv)
862       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
863                              typeConv) {}
864 
865   LogicalResult
866   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
867                   ConversionPatternRewriter &rewriter) const override {
868     auto xferOp = cast<ConcreteOp>(op);
869     auto adaptor = getTransferOpAdapter(xferOp, operands);
870 
871     if (xferOp.getVectorType().getRank() > 1 ||
872         llvm::size(xferOp.indices()) == 0)
873       return failure();
874     if (!isMinorIdentity(xferOp.permutation_map(),
875                          xferOp.getVectorType().getRank()))
876       return failure();
877 
878     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
879 
880     Location loc = op->getLoc();
881     Type i64Type = rewriter.getIntegerType(64);
882     MemRefType memRefType = xferOp.getMemRefType();
883 
884     // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, so an addrspacecast
    //    is needed when the source/dst memref is not in address space 0.
888     // TODO: support alignment when possible.
889     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
890                                adaptor.indices(), rewriter, getModule());
891     auto vecTy =
892         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
893     Value vectorDataPtr;
894     if (memRefType.getMemorySpace() == 0)
895       vectorDataPtr =
896           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
897     else
898       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
899           loc, vecTy.getPointerTo(), dataPtr);
900 
901     if (!xferOp.isMaskedDim(0))
902       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
903                                               xferOp, operands, vectorDataPtr);
904 
905     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
906     unsigned vecWidth = vecTy.getVectorNumElements();
907     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
908     SmallVector<int64_t, 8> indices;
909     indices.reserve(vecWidth);
910     for (unsigned i = 0; i < vecWidth; ++i)
911       indices.push_back(i);
912     Value linearIndices = rewriter.create<ConstantOp>(
913         loc, vectorCmpType,
914         DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
915     linearIndices = rewriter.create<LLVM::DialectCastOp>(
916         loc, toLLVMTy(vectorCmpType), linearIndices);
917 
918     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
919     // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
920     // `k` dimensions here.
921     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
922     Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
923     offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
924     Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
925     Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
926 
    // 4. Let dim be the memref dimension; compute the vector comparison mask:
928     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
929     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
930     dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
931     dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
932     Value mask =
933         rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
934     mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
935                                                 mask);
936 
937     // 5. Rewrite as a masked read / write.
938     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
939                                        operands, vectorDataPtr, mask);
940   }
941 };
942 
943 class VectorPrintOpConversion : public ConvertToLLVMPattern {
944 public:
945   explicit VectorPrintOpConversion(MLIRContext *context,
946                                    LLVMTypeConverter &typeConverter)
947       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
948                              typeConverter) {}
949 
950   // Proof-of-concept lowering implementation that relies on a small
951   // runtime support library, which only needs to provide a few
952   // printing methods (single value for all data types, opening/closing
953   // bracket, comma, newline). The lowering fully unrolls a vector
954   // in terms of these elementary printing operations. The advantage
955   // of this approach is that the library can remain unaware of all
956   // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to full unrolling,
958   // this approach is less suited for very large vectors though.
959   //
960   // TODO(ajcbik): rely solely on libc in future? something else?
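  //
  // For instance, a vector<2xf32> roughly unrolls into the call sequence
  // (sketch): print_open, print_f32, print_comma, print_f32, print_close,
  // print_newline.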
961   //
962   LogicalResult
963   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
964                   ConversionPatternRewriter &rewriter) const override {
965     auto printOp = cast<vector::PrintOp>(op);
966     auto adaptor = vector::PrintOpOperandAdaptor(operands);
967     Type printType = printOp.getPrintType();
968 
969     if (typeConverter.convertType(printType) == nullptr)
970       return failure();
971 
    // Make sure the element type has runtime support (currently i1, i32, i64,
    // f32, and f64).
973     VectorType vectorType = printType.dyn_cast<VectorType>();
974     Type eltType = vectorType ? vectorType.getElementType() : printType;
975     int64_t rank = vectorType ? vectorType.getRank() : 0;
976     Operation *printer;
977     if (eltType.isSignlessInteger(1))
978       printer = getPrintI1(op);
979     else if (eltType.isSignlessInteger(32))
980       printer = getPrintI32(op);
981     else if (eltType.isSignlessInteger(64))
982       printer = getPrintI64(op);
983     else if (eltType.isF32())
984       printer = getPrintFloat(op);
985     else if (eltType.isF64())
986       printer = getPrintDouble(op);
987     else
988       return failure();
989 
990     // Unroll vector into elementary print calls.
991     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
992     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
993     rewriter.eraseOp(op);
994     return success();
995   }
996 
997 private:
998   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
999                  Value value, VectorType vectorType, Operation *printer,
1000                  int64_t rank) const {
1001     Location loc = op->getLoc();
1002     if (rank == 0) {
1003       emitCall(rewriter, loc, printer, value);
1004       return;
1005     }
1006 
1007     emitCall(rewriter, loc, getPrintOpen(op));
1008     Operation *printComma = getPrintComma(op);
1009     int64_t dim = vectorType.getDimSize(0);
1010     for (int64_t d = 0; d < dim; ++d) {
1011       auto reducedType =
1012           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1013       auto llvmType = typeConverter.convertType(
1014           rank > 1 ? reducedType : vectorType.getElementType());
1015       Value nestedVal =
1016           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1017       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
1018       if (d != dim - 1)
1019         emitCall(rewriter, loc, printComma);
1020     }
1021     emitCall(rewriter, loc, getPrintClose(op));
1022   }
1023 
1024   // Helper to emit a call.
1025   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1026                        Operation *ref, ValueRange params = ValueRange()) {
1027     rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
1028                                   rewriter.getSymbolRefAttr(ref), params);
1029   }
1030 
1031   // Helper for printer method declaration (first hit) and lookup.
1032   static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
1033                              StringRef name, ArrayRef<LLVM::LLVMType> params) {
1034     auto module = op->getParentOfType<ModuleOp>();
1035     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1036     if (func)
1037       return func;
1038     OpBuilder moduleBuilder(module.getBodyRegion());
1039     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1040         op->getLoc(), name,
1041         LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
1042                                       params, /*isVarArg=*/false));
1043   }
1044 
1045   // Helpers for method names.
1046   Operation *getPrintI1(Operation *op) const {
1047     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1048     return getPrint(op, dialect, "print_i1",
1049                     LLVM::LLVMType::getInt1Ty(dialect));
1050   }
1051   Operation *getPrintI32(Operation *op) const {
1052     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1053     return getPrint(op, dialect, "print_i32",
1054                     LLVM::LLVMType::getInt32Ty(dialect));
1055   }
1056   Operation *getPrintI64(Operation *op) const {
1057     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1058     return getPrint(op, dialect, "print_i64",
1059                     LLVM::LLVMType::getInt64Ty(dialect));
1060   }
1061   Operation *getPrintFloat(Operation *op) const {
1062     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1063     return getPrint(op, dialect, "print_f32",
1064                     LLVM::LLVMType::getFloatTy(dialect));
1065   }
1066   Operation *getPrintDouble(Operation *op) const {
1067     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1068     return getPrint(op, dialect, "print_f64",
1069                     LLVM::LLVMType::getDoubleTy(dialect));
1070   }
1071   Operation *getPrintOpen(Operation *op) const {
1072     return getPrint(op, typeConverter.getDialect(), "print_open", {});
1073   }
1074   Operation *getPrintClose(Operation *op) const {
1075     return getPrint(op, typeConverter.getDialect(), "print_close", {});
1076   }
1077   Operation *getPrintComma(Operation *op) const {
1078     return getPrint(op, typeConverter.getDialect(), "print_comma", {});
1079   }
1080   Operation *getPrintNewline(Operation *op) const {
1081     return getPrint(op, typeConverter.getDialect(), "print_newline", {});
1082   }
1083 };
1084 
1085 /// Progressive lowering of ExtractStridedSliceOp to either:
1086 ///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional extract_strided_slice + insert for the n-D case.
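///
/// For example (offsets, sizes, and strides are illustrative):
/// ```
///  %1 = vector.extract_strided_slice %0
///         {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
///       : vector<4x8x16xf32> to vector<2x4x16xf32>
/// ```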
1088 class VectorStridedSliceOpConversion
1089     : public OpRewritePattern<ExtractStridedSliceOp> {
1090 public:
1091   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1092 
1093   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1094                                 PatternRewriter &rewriter) const override {
1095     auto dstType = op.getResult().getType().cast<VectorType>();
1096 
1097     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1098 
1099     int64_t offset =
1100         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1101     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1102     int64_t stride =
1103         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1104 
1105     auto loc = op.getLoc();
1106     auto elemType = dstType.getElementType();
1107     assert(elemType.isSignlessIntOrIndexOrFloat());
1108     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1109                                              rewriter.getZeroAttr(elemType));
1110     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1111     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1112          off += stride, ++idx) {
1113       Value extracted = extractOne(rewriter, loc, op.vector(), off);
1114       if (op.offsets().getValue().size() > 1) {
1115         extracted = rewriter.create<ExtractStridedSliceOp>(
1116             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
1117             getI64SubArray(op.sizes(), /* dropFront=*/1),
1118             getI64SubArray(op.strides(), /* dropFront=*/1));
1119       }
1120       res = insertOne(rewriter, loc, extracted, res, idx);
1121     }
1122     rewriter.replaceOp(op, {res});
1123     return success();
1124   }
1125   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1126   /// bounded as the rank is strictly decreasing.
1127   bool hasBoundedRewriteRecursion() const final { return true; }
1128 };
1129 
1130 } // namespace
1131 
1132 /// Populate the given list with patterns that convert from Vector to LLVM.
1133 void mlir::populateVectorToLLVMConversionPatterns(
1134     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1135   MLIRContext *ctx = converter.getDialect()->getContext();
1136   // clang-format off
1137   patterns.insert<VectorFMAOpNDRewritePattern,
1138                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1139                   VectorInsertStridedSliceOpSameRankRewritePattern,
1140                   VectorStridedSliceOpConversion>(ctx);
1141   patterns
1142       .insert<VectorReductionOpConversion,
1143               VectorShuffleOpConversion,
1144               VectorExtractElementOpConversion,
1145               VectorExtractOpConversion,
1146               VectorFMAOp1DConversion,
1147               VectorInsertElementOpConversion,
1148               VectorInsertOpConversion,
1149               VectorPrintOpConversion,
1150               VectorTransferConversion<TransferReadOp>,
1151               VectorTransferConversion<TransferWriteOp>,
1152               VectorTypeCastOpConversion>(ctx, converter);
1153   // clang-format on
1154 }
1155 
1156 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1157     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1158   MLIRContext *ctx = converter.getDialect()->getContext();
1159   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1160 }
1161 
1162 namespace {
1163 struct LowerVectorToLLVMPass
1164     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
1165   void runOnOperation() override;
1166 };
1167 } // namespace
1168 
1169 void LowerVectorToLLVMPass::runOnOperation() {
1170   // Perform progressive lowering of operations on slices and
1171   // all contraction operations. Also applies folding and DCE.
1172   {
1173     OwningRewritePatternList patterns;
1174     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1175     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1176     populateVectorContractLoweringPatterns(patterns, &getContext());
1177     applyPatternsAndFoldGreedily(getOperation(), patterns);
1178   }
1179 
1180   // Convert to the LLVM IR dialect.
1181   LLVMTypeConverter converter(&getContext());
1182   OwningRewritePatternList patterns;
1183   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1184   populateVectorToLLVMConversionPatterns(converter, patterns);
1185   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1186   populateStdToLLVMConversionPatterns(converter, patterns);
1187 
1188   LLVMConversionTarget target(getContext());
1189   target.addDynamicallyLegalOp<FuncOp>(
1190       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1191   if (failed(applyPartialConversion(getOperation(), target, patterns,
1192                                     &converter))) {
1193     signalPassFailure();
1194   }
1195 }
1196 
1197 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
1198   return std::make_unique<LowerVectorToLLVMPass>();
1199 }
1200