xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision be16075bfca4fa62083044b990c760c5d14a2bac)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 template <typename T>
38 static LLVM::LLVMType getPtrToElementType(T containerType,
39                                           LLVMTypeConverter &typeConverter) {
40   return typeConverter.convertType(containerType.getElementType())
41       .template cast<LLVM::LLVMType>()
42       .getPointerTo();
43 }
44 
45 // Helper to reduce vector type by one rank at front.
46 static VectorType reducedVectorTypeFront(VectorType tp) {
47   assert((tp.getRank() > 1) && "unlowerable vector type");
48   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
49 }
50 
51 // Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) {
53   assert((tp.getRank() > 1) && "unlowerable vector type");
54   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
55 }
56 
57 // Helper that picks the proper sequence for inserting.
58 static Value insertOne(ConversionPatternRewriter &rewriter,
59                        LLVMTypeConverter &typeConverter, Location loc,
60                        Value val1, Value val2, Type llvmType, int64_t rank,
61                        int64_t pos) {
62   if (rank == 1) {
63     auto idxType = rewriter.getIndexType();
64     auto constant = rewriter.create<LLVM::ConstantOp>(
65         loc, typeConverter.convertType(idxType),
66         rewriter.getIntegerAttr(idxType, pos));
67     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
68                                                   constant);
69   }
70   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
71                                               rewriter.getI64ArrayAttr(pos));
72 }
73 
74 // Helper that picks the proper sequence for inserting.
75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
76                        Value into, int64_t offset) {
77   auto vectorType = into.getType().cast<VectorType>();
78   if (vectorType.getRank() > 1)
79     return rewriter.create<InsertOp>(loc, from, into, offset);
80   return rewriter.create<vector::InsertElementOp>(
81       loc, vectorType, from, into,
82       rewriter.create<ConstantIndexOp>(loc, offset));
83 }
84 
85 // Helper that picks the proper sequence for extracting.
86 static Value extractOne(ConversionPatternRewriter &rewriter,
87                         LLVMTypeConverter &typeConverter, Location loc,
88                         Value val, Type llvmType, int64_t rank, int64_t pos) {
89   if (rank == 1) {
90     auto idxType = rewriter.getIndexType();
91     auto constant = rewriter.create<LLVM::ConstantOp>(
92         loc, typeConverter.convertType(idxType),
93         rewriter.getIntegerAttr(idxType, pos));
94     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
95                                                    constant);
96   }
97   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
98                                                rewriter.getI64ArrayAttr(pos));
99 }
100 
101 // Helper that picks the proper sequence for extracting.
102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
103                         int64_t offset) {
104   auto vectorType = vector.getType().cast<VectorType>();
105   if (vectorType.getRank() > 1)
106     return rewriter.create<ExtractOp>(loc, vector, offset);
107   return rewriter.create<vector::ExtractElementOp>(
108       loc, vectorType.getElementType(), vector,
109       rewriter.create<ConstantIndexOp>(loc, offset));
110 }
111 
112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing.
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
115                                               unsigned dropFront = 0,
116                                               unsigned dropBack = 0) {
117   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
118   auto range = arrayAttr.getAsRange<IntegerAttr>();
119   SmallVector<int64_t, 4> res;
120   res.reserve(arrayAttr.size() - dropFront - dropBack);
121   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
122        it != eit; ++it)
123     res.push_back((*it).getValue().getSExtValue());
124   return res;
125 }
126 
127 namespace {
128 
129 /// Conversion pattern for a vector.matrix_multiply.
130 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
131 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
132 public:
133   explicit VectorMatmulOpConversion(MLIRContext *context,
134                                     LLVMTypeConverter &typeConverter)
135       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
136                              typeConverter) {}
137 
138   LogicalResult
139   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
140                   ConversionPatternRewriter &rewriter) const override {
141     auto matmulOp = cast<vector::MatmulOp>(op);
142     auto adaptor = vector::MatmulOpOperandAdaptor(operands);
143     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
144         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
145         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
146         matmulOp.rhs_columns());
147     return success();
148   }
149 };
150 
151 class VectorReductionOpConversion : public ConvertToLLVMPattern {
152 public:
153   explicit VectorReductionOpConversion(MLIRContext *context,
154                                        LLVMTypeConverter &typeConverter)
155       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
156                              typeConverter) {}
157 
158   LogicalResult
159   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
160                   ConversionPatternRewriter &rewriter) const override {
161     auto reductionOp = cast<vector::ReductionOp>(op);
162     auto kind = reductionOp.kind();
163     Type eltType = reductionOp.dest().getType();
164     Type llvmType = typeConverter.convertType(eltType);
165     if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
166       // Integer reductions: add/mul/min/max/and/or/xor.
167       if (kind == "add")
168         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
169             op, llvmType, operands[0]);
170       else if (kind == "mul")
171         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
172             op, llvmType, operands[0]);
173       else if (kind == "min")
174         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
175             op, llvmType, operands[0]);
176       else if (kind == "max")
177         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
178             op, llvmType, operands[0]);
179       else if (kind == "and")
180         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
181             op, llvmType, operands[0]);
182       else if (kind == "or")
183         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
184             op, llvmType, operands[0]);
185       else if (kind == "xor")
186         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
187             op, llvmType, operands[0]);
188       else
189         return failure();
190       return success();
191 
192     } else if (eltType.isF32() || eltType.isF64()) {
193       // Floating-point reductions: add/mul/min/max
194       if (kind == "add") {
195         // Optional accumulator (or zero).
196         Value acc = operands.size() > 1 ? operands[1]
197                                         : rewriter.create<LLVM::ConstantOp>(
198                                               op->getLoc(), llvmType,
199                                               rewriter.getZeroAttr(eltType));
200         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
201             op, llvmType, acc, operands[0]);
202       } else if (kind == "mul") {
203         // Optional accumulator (or one).
204         Value acc = operands.size() > 1
205                         ? operands[1]
206                         : rewriter.create<LLVM::ConstantOp>(
207                               op->getLoc(), llvmType,
208                               rewriter.getFloatAttr(eltType, 1.0));
209         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
210             op, llvmType, acc, operands[0]);
211       } else if (kind == "min")
212         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
213             op, llvmType, operands[0]);
214       else if (kind == "max")
215         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
216             op, llvmType, operands[0]);
217       else
218         return failure();
219       return success();
220     }
221     return failure();
222   }
223 };
224 
225 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
226 public:
227   explicit VectorShuffleOpConversion(MLIRContext *context,
228                                      LLVMTypeConverter &typeConverter)
229       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
230                              typeConverter) {}
231 
232   LogicalResult
233   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
234                   ConversionPatternRewriter &rewriter) const override {
235     auto loc = op->getLoc();
236     auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
237     auto shuffleOp = cast<vector::ShuffleOp>(op);
238     auto v1Type = shuffleOp.getV1VectorType();
239     auto v2Type = shuffleOp.getV2VectorType();
240     auto vectorType = shuffleOp.getVectorType();
241     Type llvmType = typeConverter.convertType(vectorType);
242     auto maskArrayAttr = shuffleOp.mask();
243 
244     // Bail if result type cannot be lowered.
245     if (!llvmType)
246       return failure();
247 
248     // Get rank and dimension sizes.
249     int64_t rank = vectorType.getRank();
250     assert(v1Type.getRank() == rank);
251     assert(v2Type.getRank() == rank);
252     int64_t v1Dim = v1Type.getDimSize(0);
253 
254     // For rank 1, where both operands have *exactly* the same vector type,
255     // there is direct shuffle support in LLVM. Use it!
256     if (rank == 1 && v1Type == v2Type) {
257       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
258           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
259       rewriter.replaceOp(op, shuffle);
260       return success();
261     }
262 
263     // For all other cases, insert the individual values individually.
264     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
265     int64_t insPos = 0;
266     for (auto en : llvm::enumerate(maskArrayAttr)) {
267       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
268       Value value = adaptor.v1();
269       if (extPos >= v1Dim) {
270         extPos -= v1Dim;
271         value = adaptor.v2();
272       }
273       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
274                                  rank, extPos);
275       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
276                          llvmType, rank, insPos++);
277     }
278     rewriter.replaceOp(op, insert);
279     return success();
280   }
281 };
282 
283 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
284 public:
285   explicit VectorExtractElementOpConversion(MLIRContext *context,
286                                             LLVMTypeConverter &typeConverter)
287       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
288                              context, typeConverter) {}
289 
290   LogicalResult
291   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
292                   ConversionPatternRewriter &rewriter) const override {
293     auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
294     auto extractEltOp = cast<vector::ExtractElementOp>(op);
295     auto vectorType = extractEltOp.getVectorType();
296     auto llvmType = typeConverter.convertType(vectorType.getElementType());
297 
298     // Bail if result type cannot be lowered.
299     if (!llvmType)
300       return failure();
301 
302     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
303         op, llvmType, adaptor.vector(), adaptor.position());
304     return success();
305   }
306 };
307 
308 class VectorExtractOpConversion : public ConvertToLLVMPattern {
309 public:
310   explicit VectorExtractOpConversion(MLIRContext *context,
311                                      LLVMTypeConverter &typeConverter)
312       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
313                              typeConverter) {}
314 
315   LogicalResult
316   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
317                   ConversionPatternRewriter &rewriter) const override {
318     auto loc = op->getLoc();
319     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
320     auto extractOp = cast<vector::ExtractOp>(op);
321     auto vectorType = extractOp.getVectorType();
322     auto resultType = extractOp.getResult().getType();
323     auto llvmResultType = typeConverter.convertType(resultType);
324     auto positionArrayAttr = extractOp.position();
325 
326     // Bail if result type cannot be lowered.
327     if (!llvmResultType)
328       return failure();
329 
330     // One-shot extraction of vector from array (only requires extractvalue).
331     if (resultType.isa<VectorType>()) {
332       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
333           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
334       rewriter.replaceOp(op, extracted);
335       return success();
336     }
337 
338     // Potential extraction of 1-D vector from array.
339     auto *context = op->getContext();
340     Value extracted = adaptor.vector();
341     auto positionAttrs = positionArrayAttr.getValue();
342     if (positionAttrs.size() > 1) {
343       auto oneDVectorType = reducedVectorTypeBack(vectorType);
344       auto nMinusOnePositionAttrs =
345           ArrayAttr::get(positionAttrs.drop_back(), context);
346       extracted = rewriter.create<LLVM::ExtractValueOp>(
347           loc, typeConverter.convertType(oneDVectorType), extracted,
348           nMinusOnePositionAttrs);
349     }
350 
351     // Remaining extraction of element from 1-D LLVM vector
352     auto position = positionAttrs.back().cast<IntegerAttr>();
353     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
354     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
355     extracted =
356         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
357     rewriter.replaceOp(op, extracted);
358 
359     return success();
360   }
361 };
362 
363 /// Conversion pattern that turns a vector.fma on a 1-D vector
364 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
365 /// This does not match vectors of n >= 2 rank.
366 ///
367 /// Example:
368 /// ```
369 ///  vector.fma %a, %a, %a : vector<8xf32>
370 /// ```
371 /// is converted to:
372 /// ```
373 ///  llvm.intr.fma %va, %va, %va:
374 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
375 ///    -> !llvm<"<8 x float>">
376 /// ```
377 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
378 public:
379   explicit VectorFMAOp1DConversion(MLIRContext *context,
380                                    LLVMTypeConverter &typeConverter)
381       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
382                              typeConverter) {}
383 
384   LogicalResult
385   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
386                   ConversionPatternRewriter &rewriter) const override {
387     auto adaptor = vector::FMAOpOperandAdaptor(operands);
388     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
389     VectorType vType = fmaOp.getVectorType();
390     if (vType.getRank() != 1)
391       return failure();
392     rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
393                                              adaptor.acc());
394     return success();
395   }
396 };
397 
398 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
399 public:
400   explicit VectorInsertElementOpConversion(MLIRContext *context,
401                                            LLVMTypeConverter &typeConverter)
402       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
403                              context, typeConverter) {}
404 
405   LogicalResult
406   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
407                   ConversionPatternRewriter &rewriter) const override {
408     auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
409     auto insertEltOp = cast<vector::InsertElementOp>(op);
410     auto vectorType = insertEltOp.getDestVectorType();
411     auto llvmType = typeConverter.convertType(vectorType);
412 
413     // Bail if result type cannot be lowered.
414     if (!llvmType)
415       return failure();
416 
417     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
418         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
419     return success();
420   }
421 };
422 
/// Conversion pattern for vector.insert. Inserting a subvector needs a single
/// llvm.insertvalue; inserting a scalar requires an extract / insertelement /
/// insert round-trip through the innermost 1-D vector.
class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpOperandAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Scalar source. Potential extraction of 1-D vector from array: when more
    // than one position is given, all but the last select the innermost 1-D
    // vector inside the aggregate; the last position indexes into it.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array: put the updated
    // innermost vector back into the original (unmodified) destination value.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};
489 
490 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
491 ///
492 /// Example:
493 /// ```
494 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
495 /// ```
496 /// is rewritten into:
497 /// ```
498 ///  %r = splat %f0: vector<2x4xf32>
499 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
500 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
501 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
502 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
503 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
504 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
505 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
506 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
507 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
508 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
509 ///  // %r3 holds the final value.
510 /// ```
511 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
512 public:
513   using OpRewritePattern<FMAOp>::OpRewritePattern;
514 
515   LogicalResult matchAndRewrite(FMAOp op,
516                                 PatternRewriter &rewriter) const override {
517     auto vType = op.getVectorType();
518     if (vType.getRank() < 2)
519       return failure();
520 
521     auto loc = op.getLoc();
522     auto elemType = vType.getElementType();
523     Value zero = rewriter.create<ConstantOp>(loc, elemType,
524                                              rewriter.getZeroAttr(elemType));
525     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
526     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
527       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
528       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
529       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
530       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
531       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
532     }
533     rewriter.replaceOp(op, desc);
534     return success();
535   }
536 };
537 
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Same-rank case is handled by the sibling pattern below.
    if (rankDiff == 0)
      return failure();

    // rankRest equals the source rank: the number of trailing offsets that
    // apply inside the extracted subvector.
    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it. Only the leading `rankDiff` offsets select the subvector.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
592 
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source along the most major
// dimension:
//   1. the slice is extracted from the source vector
//   2. if the slice is itself a vector, the corresponding destination
//   subvector is extracted and a new, rank-reduced InsertStridedSliceOp is
//   created to insert the slice into it (to be rewritten recursively by this
//   same pattern)
//   3. the (possibly updated) slice is inserted into the running result at
//   the strided offset.
// The op is replaced by the result accumulated over all slices.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    // Different-rank case is handled by the sibling pattern above.
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    // Identical types: the insertion covers the whole destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Leading offset / size / stride along the most major dimension.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
660 
661 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
662 public:
663   explicit VectorTypeCastOpConversion(MLIRContext *context,
664                                       LLVMTypeConverter &typeConverter)
665       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
666                              typeConverter) {}
667 
668   LogicalResult
669   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
670                   ConversionPatternRewriter &rewriter) const override {
671     auto loc = op->getLoc();
672     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
673     MemRefType sourceMemRefType =
674         castOp.getOperand().getType().cast<MemRefType>();
675     MemRefType targetMemRefType =
676         castOp.getResult().getType().cast<MemRefType>();
677 
678     // Only static shape casts supported atm.
679     if (!sourceMemRefType.hasStaticShape() ||
680         !targetMemRefType.hasStaticShape())
681       return failure();
682 
683     auto llvmSourceDescriptorTy =
684         operands[0].getType().dyn_cast<LLVM::LLVMType>();
685     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
686       return failure();
687     MemRefDescriptor sourceMemRef(operands[0]);
688 
689     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
690                                       .dyn_cast_or_null<LLVM::LLVMType>();
691     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
692       return failure();
693 
694     int64_t offset;
695     SmallVector<int64_t, 4> strides;
696     auto successStrides =
697         getStridesAndOffset(sourceMemRefType, strides, offset);
698     bool isContiguous = (strides.back() == 1);
699     if (isContiguous) {
700       auto sizes = sourceMemRefType.getShape();
701       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
702         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
703           isContiguous = false;
704           break;
705         }
706       }
707     }
708     // Only contiguous source tensors supported atm.
709     if (failed(successStrides) || !isContiguous)
710       return failure();
711 
712     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
713 
714     // Create descriptor.
715     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
716     Type llvmTargetElementTy = desc.getElementType();
717     // Set allocated ptr.
718     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
719     allocated =
720         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
721     desc.setAllocatedPtr(rewriter, loc, allocated);
722     // Set aligned ptr.
723     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
724     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
725     desc.setAlignedPtr(rewriter, loc, ptr);
726     // Fill offset 0.
727     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
728     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
729     desc.setOffset(rewriter, loc, zero);
730 
731     // Fill size and stride descriptors in memref.
732     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
733       int64_t index = indexedSize.index();
734       auto sizeAttr =
735           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
736       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
737       desc.setSize(rewriter, loc, index, size);
738       auto strideAttr =
739           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
740       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
741       desc.setStride(rewriter, loc, index, stride);
742     }
743 
744     rewriter.replaceOp(op, {desc});
745     return success();
746   }
747 };
748 
/// Replaces a vector transfer op (already reduced to a masked 1-D access)
/// with its masked LLVM load/store counterpart. Specialized below for
/// TransferReadOp and TransferWriteOp.
template <typename ConcreteOp>
void replaceTransferOp(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
                       Value mask);
754 
755 template <>
756 void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
757                                        LLVMTypeConverter &typeConverter,
758                                        Location loc, Operation *op,
759                                        ArrayRef<Value> operands, Value dataPtr,
760                                        Value mask) {
761   auto xferOp = cast<TransferReadOp>(op);
762   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
763   VectorType fillType = xferOp.getVectorType();
764   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
765   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
766 
767   auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
768   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
769       op, vecTy, dataPtr, mask, ValueRange{fill},
770       rewriter.getI32IntegerAttr(1));
771 }
772 
773 template <>
774 void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
775                                         LLVMTypeConverter &typeConverter,
776                                         Location loc, Operation *op,
777                                         ArrayRef<Value> operands, Value dataPtr,
778                                         Value mask) {
779   auto adaptor = TransferWriteOpOperandAdaptor(operands);
780   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
781       op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
782 }
783 
784 static TransferReadOpOperandAdaptor
785 getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
786   return TransferReadOpOperandAdaptor(operands);
787 }
788 
789 static TransferWriteOpOperandAdaptor
790 getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
791   return TransferWriteOpOperandAdaptor(operands);
792 }
793 
794 bool isMinorIdentity(AffineMap map, unsigned rank) {
795   if (map.getNumResults() < rank)
796     return false;
797   unsigned startDim = map.getNumDims() - rank;
798   for (unsigned i = 0; i < rank; ++i)
799     if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
800       return false;
801   return true;
802 }
803 
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
/// 1. Bitcast or addrspacecast to vector form.
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 3. Create a mask where offsetVector is compared against memref upper bound.
/// 4. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    // Only 1-D transfers with at least one index are handled here; higher
    // ranks are expected to be progressively lowered elsewhere first.
    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    // Reject transposing/broadcasting permutation maps: the access must be a
    // minor identity over the memref's trailing dimension.
    if (!isMinorIdentity(xferOp.permutation_map(),
                         xferOp.getVectorType().getRank()))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    Type i64Type = rewriter.getIntegerType(64);
    MemRefType memRefType = xferOp.getMemRefType();

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer would always be on address space 0, therefore
    //    addrspacecast shall be used when source/dst memrefs are not on
    //    address space 0.
    // TODO: support alignment when possible.
    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                               adaptor.indices(), rewriter, getModule());
    auto vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    Value vectorDataPtr;
    if (memRefType.getMemorySpace() == 0)
      vectorDataPtr =
          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
    else
      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, vecTy.getPointerTo(), dataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    unsigned vecWidth = vecTy.getVectorNumElements();
    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
    SmallVector<int64_t, 8> indices;
    indices.reserve(vecWidth);
    for (unsigned i = 0; i < vecWidth; ++i)
      indices.push_back(i);
    Value linearIndices = rewriter.create<ConstantOp>(
        loc, vectorCmpType,
        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
    // The comparison below happens on standard-dialect values; cast to LLVM
    // only at the end so AddIOp/CmpIOp can be used.
    linearIndices = rewriter.create<LLVM::DialectCastOp>(
        loc, toLLVMTy(vectorCmpType), linearIndices);

    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
    // `k` dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);

    // 4. Let dim the memref dimension, compute the vector comparison mask:
    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
    Value mask =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
                                                mask);

    // 5. Rewrite as a masked read / write.
    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
                                  vectorDataPtr, mask);

    return success();
  }
};
893 
/// Lowers vector.print to a sequence of calls into a small runtime support
/// library (print_i32/print_i64/print_f32/print_f64 plus bracket, comma and
/// newline helpers), fully unrolling the vector.
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorPrintOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
                             typeConverter) {}

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO(ajcbik): rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto printOp = cast<vector::PrintOp>(op);
    auto adaptor = vector::PrintOpOperandAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter.convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support: signless i32/i64, f32, f64.
    // Scalars are treated as rank-0 "vectors" so the same recursion applies.
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    Operation *printer;
    if (eltType.isSignlessInteger(32))
      printer = getPrintI32(op);
    else if (eltType.isSignlessInteger(64))
      printer = getPrintI64(op);
    else if (eltType.isF32())
      printer = getPrintFloat(op);
    else if (eltType.isF64())
      printer = getPrintDouble(op);
    else
      return failure();

    // Unroll vector into elementary print calls.
    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Recursively unrolls `value` along its leading dimension, emitting an
  // opening bracket, the elements separated by commas, and a closing bracket.
  // At rank 0 the scalar `printer` function is called directly.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      // Each extracted element is either a lower-rank vector (rank > 1) or a
      // scalar of the element type (rank == 1).
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter.convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal =
          extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
  // Returns the existing module-level func if present, otherwise declares a
  // void function with the given name and parameter types at module scope.
  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
                                      params, /*isVarArg=*/false));
  }

  // Helpers for method names.
  Operation *getPrintI32(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i32",
                    LLVM::LLVMType::getInt32Ty(dialect));
  }
  Operation *getPrintI64(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_i64",
                    LLVM::LLVMType::getInt64Ty(dialect));
  }
  Operation *getPrintFloat(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f32",
                    LLVM::LLVMType::getFloatTy(dialect));
  }
  Operation *getPrintDouble(Operation *op) const {
    LLVM::LLVMDialect *dialect = typeConverter.getDialect();
    return getPrint(op, dialect, "print_f64",
                    LLVM::LLVMType::getDoubleTy(dialect));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_open", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_close", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_comma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, typeConverter.getDialect(), "print_newline", {});
  }
};
1028 
/// Progressive lowering of StridedSliceOp to either:
///   1. extractelement + insertelement for the 1-D case
///   2. extract + optional strided_slice + insert for the n-D case.
class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
  using OpRewritePattern<StridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(StridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    // Peel the leading offset/size/stride; the remaining dimensions (if any)
    // are handled by a recursively created StridedSliceOp below.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    // Start from a zero-filled destination and insert the selected elements
    // one by one.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value extracted = extractOne(rewriter, loc, op.vector(), off);
      if (op.offsets().getValue().size() > 1) {
        // More than one dimension left: slice the extracted sub-vector with
        // the leading offset/size/stride dropped.
        extracted = rewriter.create<StridedSliceOp>(
            loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.sizes(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, {res});
    return success();
  }
  /// This pattern creates recursive StridedSliceOp, but the recursion is
  /// bounded as the rank is strictly decreasing.
  bool hasBoundedRewriteRecursion() const final { return true; }
};
1072 
1073 } // namespace
1074 
1075 /// Populate the given list with patterns that convert from Vector to LLVM.
1076 void mlir::populateVectorToLLVMConversionPatterns(
1077     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1078   MLIRContext *ctx = converter.getDialect()->getContext();
1079   // clang-format off
1080   patterns.insert<VectorFMAOpNDRewritePattern,
1081                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1082                   VectorInsertStridedSliceOpSameRankRewritePattern,
1083                   VectorStridedSliceOpConversion>(ctx);
1084   patterns
1085       .insert<VectorReductionOpConversion,
1086               VectorShuffleOpConversion,
1087               VectorExtractElementOpConversion,
1088               VectorExtractOpConversion,
1089               VectorFMAOp1DConversion,
1090               VectorInsertElementOpConversion,
1091               VectorInsertOpConversion,
1092               VectorPrintOpConversion,
1093               VectorTransferConversion<TransferReadOp>,
1094               VectorTransferConversion<TransferWriteOp>,
1095               VectorTypeCastOpConversion>(ctx, converter);
1096   // clang-format on
1097 }
1098 
1099 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1100     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1101   MLIRContext *ctx = converter.getDialect()->getContext();
1102   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1103 }
1104 
namespace {
/// Module pass that lowers vector dialect operations (and remaining standard
/// ops) into the LLVM dialect. Implementation is in runOnOperation below.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  void runOnOperation() override;
};
} // namespace
1111 
1112 void LowerVectorToLLVMPass::runOnOperation() {
1113   // Perform progressive lowering of operations on slices and
1114   // all contraction operations. Also applies folding and DCE.
1115   {
1116     OwningRewritePatternList patterns;
1117     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1118     populateVectorContractLoweringPatterns(patterns, &getContext());
1119     applyPatternsAndFoldGreedily(getOperation(), patterns);
1120   }
1121 
1122   // Convert to the LLVM IR dialect.
1123   LLVMTypeConverter converter(&getContext());
1124   OwningRewritePatternList patterns;
1125   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1126   populateVectorToLLVMConversionPatterns(converter, patterns);
1127   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1128   populateStdToLLVMConversionPatterns(converter, patterns);
1129 
1130   LLVMConversionTarget target(getContext());
1131   target.addDynamicallyLegalOp<FuncOp>(
1132       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1133   if (failed(applyPartialConversion(getOperation(), target, patterns,
1134                                     &converter))) {
1135     signalPassFailure();
1136   }
1137 }
1138 
/// Creates an instance of the vector-to-LLVM lowering pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
  return std::make_unique<LowerVectorToLLVMPass>();
}
1142