xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision a23f190213e16ec0f9075e1a813a046730f73458)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 template <typename T>
38 static LLVM::LLVMType getPtrToElementType(T containerType,
39                                           LLVMTypeConverter &typeConverter) {
40   return typeConverter.convertType(containerType.getElementType())
41       .template cast<LLVM::LLVMType>()
42       .getPointerTo();
43 }
44 
45 // Helper to reduce vector type by one rank at front.
46 static VectorType reducedVectorTypeFront(VectorType tp) {
47   assert((tp.getRank() > 1) && "unlowerable vector type");
48   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
49 }
50 
51 // Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) {
53   assert((tp.getRank() > 1) && "unlowerable vector type");
54   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
55 }
56 
57 // Helper that picks the proper sequence for inserting.
58 static Value insertOne(ConversionPatternRewriter &rewriter,
59                        LLVMTypeConverter &typeConverter, Location loc,
60                        Value val1, Value val2, Type llvmType, int64_t rank,
61                        int64_t pos) {
62   if (rank == 1) {
63     auto idxType = rewriter.getIndexType();
64     auto constant = rewriter.create<LLVM::ConstantOp>(
65         loc, typeConverter.convertType(idxType),
66         rewriter.getIntegerAttr(idxType, pos));
67     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
68                                                   constant);
69   }
70   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
71                                               rewriter.getI64ArrayAttr(pos));
72 }
73 
74 // Helper that picks the proper sequence for inserting.
75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
76                        Value into, int64_t offset) {
77   auto vectorType = into.getType().cast<VectorType>();
78   if (vectorType.getRank() > 1)
79     return rewriter.create<InsertOp>(loc, from, into, offset);
80   return rewriter.create<vector::InsertElementOp>(
81       loc, vectorType, from, into,
82       rewriter.create<ConstantIndexOp>(loc, offset));
83 }
84 
85 // Helper that picks the proper sequence for extracting.
86 static Value extractOne(ConversionPatternRewriter &rewriter,
87                         LLVMTypeConverter &typeConverter, Location loc,
88                         Value val, Type llvmType, int64_t rank, int64_t pos) {
89   if (rank == 1) {
90     auto idxType = rewriter.getIndexType();
91     auto constant = rewriter.create<LLVM::ConstantOp>(
92         loc, typeConverter.convertType(idxType),
93         rewriter.getIntegerAttr(idxType, pos));
94     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
95                                                    constant);
96   }
97   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
98                                                rewriter.getI64ArrayAttr(pos));
99 }
100 
101 // Helper that picks the proper sequence for extracting.
102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
103                         int64_t offset) {
104   auto vectorType = vector.getType().cast<VectorType>();
105   if (vectorType.getRank() > 1)
106     return rewriter.create<ExtractOp>(loc, vector, offset);
107   return rewriter.create<vector::ExtractElementOp>(
108       loc, vectorType.getElementType(), vector,
109       rewriter.create<ConstantIndexOp>(loc, offset));
110 }
111 
112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing.
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
115                                               unsigned dropFront = 0,
116                                               unsigned dropBack = 0) {
117   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
118   auto range = arrayAttr.getAsRange<IntegerAttr>();
119   SmallVector<int64_t, 4> res;
120   res.reserve(arrayAttr.size() - dropFront - dropBack);
121   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
122        it != eit; ++it)
123     res.push_back((*it).getValue().getSExtValue());
124   return res;
125 }
126 
127 namespace {
128 
129 /// Conversion pattern for a vector.matrix_multiply.
130 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
131 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
132 public:
133   explicit VectorMatmulOpConversion(MLIRContext *context,
134                                     LLVMTypeConverter &typeConverter)
135       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
136                              typeConverter) {}
137 
138   LogicalResult
139   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
140                   ConversionPatternRewriter &rewriter) const override {
141     auto matmulOp = cast<vector::MatmulOp>(op);
142     auto adaptor = vector::MatmulOpOperandAdaptor(operands);
143     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
144         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
145         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
146         matmulOp.rhs_columns());
147     return success();
148   }
149 };
150 
/// Conversion pattern for vector.reduction. Lowers the reduction to the
/// matching llvm.intr.experimental.vector.reduce.* intrinsic. Only
/// 32/64-bit signless-integer and f32/f64 element types are handled; any
/// other element type or unknown reduction kind fails to match.
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    // The scalar result type selects the intrinsic family (integer vs FP).
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      // Note that "min"/"max" map to the *signed* smin/smax intrinsics.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isF32() || eltType.isF64()) {
      // Floating-point reductions: add/mul/min/max
      if (kind == "add") {
        // Optional accumulator (or zero). The v2 fadd intrinsic takes the
        // accumulator as its first operand.
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "mul") {
        // Optional accumulator (or one), likewise passed first.
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
            op, llvmType, acc, operands[0]);
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    // Unsupported element type.
    return failure();
  }
};
224 
225 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
226 public:
227   explicit VectorShuffleOpConversion(MLIRContext *context,
228                                      LLVMTypeConverter &typeConverter)
229       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
230                              typeConverter) {}
231 
232   LogicalResult
233   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
234                   ConversionPatternRewriter &rewriter) const override {
235     auto loc = op->getLoc();
236     auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
237     auto shuffleOp = cast<vector::ShuffleOp>(op);
238     auto v1Type = shuffleOp.getV1VectorType();
239     auto v2Type = shuffleOp.getV2VectorType();
240     auto vectorType = shuffleOp.getVectorType();
241     Type llvmType = typeConverter.convertType(vectorType);
242     auto maskArrayAttr = shuffleOp.mask();
243 
244     // Bail if result type cannot be lowered.
245     if (!llvmType)
246       return failure();
247 
248     // Get rank and dimension sizes.
249     int64_t rank = vectorType.getRank();
250     assert(v1Type.getRank() == rank);
251     assert(v2Type.getRank() == rank);
252     int64_t v1Dim = v1Type.getDimSize(0);
253 
254     // For rank 1, where both operands have *exactly* the same vector type,
255     // there is direct shuffle support in LLVM. Use it!
256     if (rank == 1 && v1Type == v2Type) {
257       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
258           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
259       rewriter.replaceOp(op, shuffle);
260       return success();
261     }
262 
263     // For all other cases, insert the individual values individually.
264     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
265     int64_t insPos = 0;
266     for (auto en : llvm::enumerate(maskArrayAttr)) {
267       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
268       Value value = adaptor.v1();
269       if (extPos >= v1Dim) {
270         extPos -= v1Dim;
271         value = adaptor.v2();
272       }
273       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
274                                  rank, extPos);
275       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
276                          llvmType, rank, insPos++);
277     }
278     rewriter.replaceOp(op, insert);
279     return success();
280   }
281 };
282 
283 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
284 public:
285   explicit VectorExtractElementOpConversion(MLIRContext *context,
286                                             LLVMTypeConverter &typeConverter)
287       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
288                              context, typeConverter) {}
289 
290   LogicalResult
291   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
292                   ConversionPatternRewriter &rewriter) const override {
293     auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
294     auto extractEltOp = cast<vector::ExtractElementOp>(op);
295     auto vectorType = extractEltOp.getVectorType();
296     auto llvmType = typeConverter.convertType(vectorType.getElementType());
297 
298     // Bail if result type cannot be lowered.
299     if (!llvmType)
300       return failure();
301 
302     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
303         op, llvmType, adaptor.vector(), adaptor.position());
304     return success();
305   }
306 };
307 
308 class VectorExtractOpConversion : public ConvertToLLVMPattern {
309 public:
310   explicit VectorExtractOpConversion(MLIRContext *context,
311                                      LLVMTypeConverter &typeConverter)
312       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
313                              typeConverter) {}
314 
315   LogicalResult
316   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
317                   ConversionPatternRewriter &rewriter) const override {
318     auto loc = op->getLoc();
319     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
320     auto extractOp = cast<vector::ExtractOp>(op);
321     auto vectorType = extractOp.getVectorType();
322     auto resultType = extractOp.getResult().getType();
323     auto llvmResultType = typeConverter.convertType(resultType);
324     auto positionArrayAttr = extractOp.position();
325 
326     // Bail if result type cannot be lowered.
327     if (!llvmResultType)
328       return failure();
329 
330     // One-shot extraction of vector from array (only requires extractvalue).
331     if (resultType.isa<VectorType>()) {
332       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
333           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
334       rewriter.replaceOp(op, extracted);
335       return success();
336     }
337 
338     // Potential extraction of 1-D vector from array.
339     auto *context = op->getContext();
340     Value extracted = adaptor.vector();
341     auto positionAttrs = positionArrayAttr.getValue();
342     if (positionAttrs.size() > 1) {
343       auto oneDVectorType = reducedVectorTypeBack(vectorType);
344       auto nMinusOnePositionAttrs =
345           ArrayAttr::get(positionAttrs.drop_back(), context);
346       extracted = rewriter.create<LLVM::ExtractValueOp>(
347           loc, typeConverter.convertType(oneDVectorType), extracted,
348           nMinusOnePositionAttrs);
349     }
350 
351     // Remaining extraction of element from 1-D LLVM vector
352     auto position = positionAttrs.back().cast<IntegerAttr>();
353     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
354     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
355     extracted =
356         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
357     rewriter.replaceOp(op, extracted);
358 
359     return success();
360   }
361 };
362 
363 /// Conversion pattern that turns a vector.fma on a 1-D vector
364 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
365 /// This does not match vectors of n >= 2 rank.
366 ///
367 /// Example:
368 /// ```
369 ///  vector.fma %a, %a, %a : vector<8xf32>
370 /// ```
371 /// is converted to:
372 /// ```
373 ///  llvm.intr.fma %va, %va, %va:
374 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
375 ///    -> !llvm<"<8 x float>">
376 /// ```
377 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
378 public:
379   explicit VectorFMAOp1DConversion(MLIRContext *context,
380                                    LLVMTypeConverter &typeConverter)
381       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
382                              typeConverter) {}
383 
384   LogicalResult
385   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386                   ConversionPatternRewriter &rewriter) const override {
387     auto adaptor = vector::FMAOpOperandAdaptor(operands);
388     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
389     VectorType vType = fmaOp.getVectorType();
390     if (vType.getRank() != 1)
391       return failure();
392     rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
393                                              adaptor.acc());
394     return success();
395   }
396 };
397 
398 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
399 public:
400   explicit VectorInsertElementOpConversion(MLIRContext *context,
401                                            LLVMTypeConverter &typeConverter)
402       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
403                              context, typeConverter) {}
404 
405   LogicalResult
406   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
407                   ConversionPatternRewriter &rewriter) const override {
408     auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
409     auto insertEltOp = cast<vector::InsertElementOp>(op);
410     auto vectorType = insertEltOp.getDestVectorType();
411     auto llvmType = typeConverter.convertType(vectorType);
412 
413     // Bail if result type cannot be lowered.
414     if (!llvmType)
415       return failure();
416 
417     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
418         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
419     return success();
420   }
421 };
422 
/// Conversion pattern for vector.insert with static positions. A vector
/// source is inserted into the destination aggregate with a single
/// insertvalue; a scalar source requires digging out the innermost 1-D
/// vector, inserting the element, and putting the 1-D vector back.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array: for positions of size
    // n > 1, peel off the leading n-1 coordinates to reach the innermost 1-D
    // vector; for a single coordinate the destination already is 1-D.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector, using the trailing
    // coordinate as the lane index.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array: write the
    // updated innermost vector back at the same leading coordinates.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};
489 
490 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
491 ///
492 /// Example:
493 /// ```
494 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
495 /// ```
496 /// is rewritten into:
497 /// ```
498 ///  %r = splat %f0: vector<2x4xf32>
499 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
500 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
501 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
502 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
503 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
504 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
505 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
506 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
507 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
508 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
509 ///  // %r3 holds the final value.
510 /// ```
511 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
512 public:
513   using OpRewritePattern<FMAOp>::OpRewritePattern;
514 
515   LogicalResult matchAndRewrite(FMAOp op,
516                                 PatternRewriter &rewriter) const override {
517     auto vType = op.getVectorType();
518     if (vType.getRank() < 2)
519       return failure();
520 
521     auto loc = op.getLoc();
522     auto elemType = vType.getElementType();
523     Value zero = rewriter.create<ConstantOp>(loc, elemType,
524                                              rewriter.getZeroAttr(elemType));
525     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
526     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
527       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
528       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
529       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
530       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
531       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
532     }
533     rewriter.replaceOp(op, desc);
534     return success();
535   }
536 };
537 
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to slice when there are no offsets.
    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // The same-rank case is handled by the sibling pattern.
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it. The first rankDiff offsets locate the subvector; the trailing
    // rankRest offsets/strides apply within it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
592 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The op is reduced rank by rank: for each slice of the
// source along the outermost dimension,
//   1. the slice is extracted from the source
//   2. if the slice is itself a vector, the matching subvector is extracted
//      from the destination and a smaller-rank InsertStridedSlice combines
//      the two (picked up again by this same pattern)
//   3. the (combined) slice is inserted into the result at the strided offset
// The fully assembled result then replaces the op.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to slice when there are no offsets.
    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // The different-rank case is handled by the sibling pattern.
    if (rankDiff != 0)
      return failure();

    // Same type: the source simply replaces the whole destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Outermost offset/size/stride drive the per-slice loop below.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
660 
661 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
662 public:
663   explicit VectorTypeCastOpConversion(MLIRContext *context,
664                                       LLVMTypeConverter &typeConverter)
665       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
666                              typeConverter) {}
667 
668   LogicalResult
669   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
670                   ConversionPatternRewriter &rewriter) const override {
671     auto loc = op->getLoc();
672     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
673     MemRefType sourceMemRefType =
674         castOp.getOperand().getType().cast<MemRefType>();
675     MemRefType targetMemRefType =
676         castOp.getResult().getType().cast<MemRefType>();
677 
678     // Only static shape casts supported atm.
679     if (!sourceMemRefType.hasStaticShape() ||
680         !targetMemRefType.hasStaticShape())
681       return failure();
682 
683     auto llvmSourceDescriptorTy =
684         operands[0].getType().dyn_cast<LLVM::LLVMType>();
685     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
686       return failure();
687     MemRefDescriptor sourceMemRef(operands[0]);
688 
689     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
690                                       .dyn_cast_or_null<LLVM::LLVMType>();
691     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
692       return failure();
693 
694     int64_t offset;
695     SmallVector<int64_t, 4> strides;
696     auto successStrides =
697         getStridesAndOffset(sourceMemRefType, strides, offset);
698     bool isContiguous = (strides.back() == 1);
699     if (isContiguous) {
700       auto sizes = sourceMemRefType.getShape();
701       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
702         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
703           isContiguous = false;
704           break;
705         }
706       }
707     }
708     // Only contiguous source tensors supported atm.
709     if (failed(successStrides) || !isContiguous)
710       return failure();
711 
712     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
713 
714     // Create descriptor.
715     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
716     Type llvmTargetElementTy = desc.getElementType();
717     // Set allocated ptr.
718     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
719     allocated =
720         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
721     desc.setAllocatedPtr(rewriter, loc, allocated);
722     // Set aligned ptr.
723     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
724     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
725     desc.setAlignedPtr(rewriter, loc, ptr);
726     // Fill offset 0.
727     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
728     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
729     desc.setOffset(rewriter, loc, zero);
730 
731     // Fill size and stride descriptors in memref.
732     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
733       int64_t index = indexedSize.index();
734       auto sizeAttr =
735           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
736       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
737       desc.setSize(rewriter, loc, index, size);
738       auto strideAttr =
739           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
740       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
741       desc.setStride(rewriter, loc, index, stride);
742     }
743 
744     rewriter.replaceOp(op, {desc});
745     return success();
746   }
747 };
748 
/// Replaces the given vector transfer `op` with its masked LLVM memory-op
/// counterpart, using `dataPtr` as the (vector) address and `mask` as the
/// per-lane predicate. Specialized below for TransferReadOp/TransferWriteOp.
template <typename ConcreteOp>
void replaceTransferOp(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
                       Value mask);
754 
755 LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
756                                       Type type, LLVM::LLVMType &llvmType,
757                                       unsigned &align) {
758   auto convertedType = typeConverter.convertType(type);
759   if (!convertedType)
760     return failure();
761 
762   llvmType = convertedType.template cast<LLVM::LLVMType>();
763   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
764   align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
765   return success();
766 }
767 
768 template <>
769 void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
770                                        LLVMTypeConverter &typeConverter,
771                                        Location loc, Operation *op,
772                                        ArrayRef<Value> operands, Value dataPtr,
773                                        Value mask) {
774   auto xferOp = cast<TransferReadOp>(op);
775   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
776   VectorType fillType = xferOp.getVectorType();
777   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
778   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
779 
780   LLVM::LLVMType vecTy;
781   unsigned align;
782   if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
783                                         vecTy, align)))
784     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
785         op, vecTy, dataPtr, mask, ValueRange{fill},
786         rewriter.getI32IntegerAttr(align));
787 }
788 
789 template <>
790 void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
791                                         LLVMTypeConverter &typeConverter,
792                                         Location loc, Operation *op,
793                                         ArrayRef<Value> operands, Value dataPtr,
794                                         Value mask) {
795   auto adaptor = TransferWriteOpOperandAdaptor(operands);
796 
797   auto xferOp = cast<TransferWriteOp>(op);
798   LLVM::LLVMType vecTy;
799   unsigned align;
800   if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
801                                         vecTy, align)))
802     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
803         op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
804 }
805 
806 static TransferReadOpOperandAdaptor
807 getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
808   return TransferReadOpOperandAdaptor(operands);
809 }
810 
811 static TransferWriteOpOperandAdaptor
812 getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
813   return TransferWriteOpOperandAdaptor(operands);
814 }
815 
816 bool isMinorIdentity(AffineMap map, unsigned rank) {
817   if (map.getNumResults() < rank)
818     return false;
819   unsigned startDim = map.getNumDims() - rank;
820   for (unsigned i = 0; i < rank; ++i)
821     if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
822       return false;
823   return true;
824 }
825 
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
///
/// The actual masked load/store is emitted by the replaceTransferOp
/// specializations above; this pattern only builds the address and the mask.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here;
    // higher-rank transfers are expected to be lowered progressively first.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // The permutation map must be a minor identity, i.e. the transfer moves
    // data along the innermost memref dimension.
    if (!isMinorIdentity(xferOp.permutation_map(),
                         xferOp.getVectorType().getRank()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    // The constant is built in the standard dialect and then cast to the LLVM
    // dialect so it can be combined with the LLVM-typed operands below.
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
                                  vectorDataPtr, mask);

    return success();
  }
};
915 
/// Lowers vector.print to calls into a small runtime support library,
/// fully unrolling the vector into scalar print calls.
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO(ajcbik): rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpOperandAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support (currently just Float/Double).
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    // Scalars are treated as rank-0 and printed without brackets.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    // Select the runtime printer matching the element type; unsupported
    // element types make the pattern fail to match.
    Operation *printer;
    if (eltType.isSignlessInteger(1))
      printer = getPrintI1(op);
    else if (eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively emits print calls for `value`: rank-0 values print directly,
  // higher ranks print an opening bracket, each element (comma-separated),
  // and a closing bracket.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Peel off the leading dimension; at rank 1 the extracted value is a
      // scalar and reducedType is intentionally null.
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Declares the runtime function at module scope on first use; later calls
  // find the existing symbol.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI1(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i1",
                    LLVM::LLVMType::getInt1Ty(dialect));
  }
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};
1057 
/// Progressive lowering of StridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional strided_slice + insert for the n-D case.
class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(StridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    // Handle the leading (front) dimension here; deeper dimensions are
    // handled by recursively emitting StridedSliceOp below.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    // Start from a zero-filled destination vector and insert the selected
    // elements one by one.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    // `off` walks the source positions (offset, offset+stride, ...); `idx`
    // walks the destination positions (0, 1, ...).
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, op.vector(), off);
      if (op.offsets().getValue().size() > 1) {
        // Recurse on the remaining dimensions by emitting a StridedSliceOp
        // with the front entry of each attribute array dropped.
        extracted = rewriter.create<StridedSliceOp>(
            loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.sizes(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, {res});
    return success();
  }
  /// This pattern creates recursive StridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
1101 
1102 } // namespace
1103 
1104 /// Populate the given list with patterns that convert from Vector to LLVM.
1105 void mlir::populateVectorToLLVMConversionPatterns(
1106     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1107   MLIRContext *ctx = converter.getDialect()->getContext();
1108   // clang-format off
1109   patterns.insert<VectorFMAOpNDRewritePattern,
1110                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1111                   VectorInsertStridedSliceOpSameRankRewritePattern,
1112                   VectorStridedSliceOpConversion>(ctx);
1113   patterns
1114       .insert<VectorReductionOpConversion,
1115               VectorShuffleOpConversion,
1116               VectorExtractElementOpConversion,
1117               VectorExtractOpConversion,
1118               VectorFMAOp1DConversion,
1119               VectorInsertElementOpConversion,
1120               VectorInsertOpConversion,
1121               VectorPrintOpConversion,
1122               VectorTransferConversion<TransferReadOp>,
1123               VectorTransferConversion<TransferWriteOp>,
1124               VectorTypeCastOpConversion>(ctx, converter);
1125   // clang-format on
1126 }
1127 
1128 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1129     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1130   MLIRContext *ctx = converter.getDialect()->getContext();
1131   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1132 }
1133 
namespace {
// Module pass that lowers the Vector dialect (plus intermediate standard ops
// produced along the way) to the LLVM dialect.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
1140 
1141 void LowerVectorToLLVMPass::runOnOperation() {
1142   // Perform progressive lowering of operations on slices and
1143   // all contraction operations. Also applies folding and DCE.
1144   {
1145     OwningRewritePatternList patterns;
1146     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1147     populateVectorContractLoweringPatterns(patterns, &getContext());
1148     applyPatternsAndFoldGreedily(getOperation(), patterns);
1149   }
1150 
1151   // Convert to the LLVM IR dialect.
1152   LLVMTypeConverter converter(&getContext());
1153   OwningRewritePatternList patterns;
1154   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1155   populateVectorToLLVMConversionPatterns(converter, patterns);
1156   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1157   populateStdToLLVMConversionPatterns(converter, patterns);
1158 
1159   LLVMConversionTarget target(getContext());
1160   target.addDynamicallyLegalOp<FuncOp>(
1161       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1162   if (failed(applyPartialConversion(getOperation(), target, patterns,
1163                                     &converter))) {
1164     signalPassFailure();
1165   }
1166 }
1167 
1168 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
1169   return std::make_unique<LowerVectorToLLVMPass>();
1170 }
1171