xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision c295a65da496f5e982402e8f83e417659c7dd166)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/IR/Type.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 template <typename T>
38 static LLVM::LLVMType getPtrToElementType(T containerType,
39                                           LLVMTypeConverter &typeConverter) {
40   return typeConverter.convertType(containerType.getElementType())
41       .template cast<LLVM::LLVMType>()
42       .getPointerTo();
43 }
44 
45 // Helper to reduce vector type by one rank at front.
46 static VectorType reducedVectorTypeFront(VectorType tp) {
47   assert((tp.getRank() > 1) && "unlowerable vector type");
48   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
49 }
50 
51 // Helper to reduce vector type by *all* but one rank at back.
52 static VectorType reducedVectorTypeBack(VectorType tp) {
53   assert((tp.getRank() > 1) && "unlowerable vector type");
54   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
55 }
56 
57 // Helper that picks the proper sequence for inserting.
58 static Value insertOne(ConversionPatternRewriter &rewriter,
59                        LLVMTypeConverter &typeConverter, Location loc,
60                        Value val1, Value val2, Type llvmType, int64_t rank,
61                        int64_t pos) {
62   if (rank == 1) {
63     auto idxType = rewriter.getIndexType();
64     auto constant = rewriter.create<LLVM::ConstantOp>(
65         loc, typeConverter.convertType(idxType),
66         rewriter.getIntegerAttr(idxType, pos));
67     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
68                                                   constant);
69   }
70   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
71                                               rewriter.getI64ArrayAttr(pos));
72 }
73 
74 // Helper that picks the proper sequence for inserting.
75 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
76                        Value into, int64_t offset) {
77   auto vectorType = into.getType().cast<VectorType>();
78   if (vectorType.getRank() > 1)
79     return rewriter.create<InsertOp>(loc, from, into, offset);
80   return rewriter.create<vector::InsertElementOp>(
81       loc, vectorType, from, into,
82       rewriter.create<ConstantIndexOp>(loc, offset));
83 }
84 
85 // Helper that picks the proper sequence for extracting.
86 static Value extractOne(ConversionPatternRewriter &rewriter,
87                         LLVMTypeConverter &typeConverter, Location loc,
88                         Value val, Type llvmType, int64_t rank, int64_t pos) {
89   if (rank == 1) {
90     auto idxType = rewriter.getIndexType();
91     auto constant = rewriter.create<LLVM::ConstantOp>(
92         loc, typeConverter.convertType(idxType),
93         rewriter.getIntegerAttr(idxType, pos));
94     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
95                                                    constant);
96   }
97   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
98                                                rewriter.getI64ArrayAttr(pos));
99 }
100 
101 // Helper that picks the proper sequence for extracting.
102 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
103                         int64_t offset) {
104   auto vectorType = vector.getType().cast<VectorType>();
105   if (vectorType.getRank() > 1)
106     return rewriter.create<ExtractOp>(loc, vector, offset);
107   return rewriter.create<vector::ExtractElementOp>(
108       loc, vectorType.getElementType(), vector,
109       rewriter.create<ConstantIndexOp>(loc, offset));
110 }
111 
112 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
113 // TODO(rriddle): Better support for attribute subtype forwarding + slicing.
114 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
115                                               unsigned dropFront = 0,
116                                               unsigned dropBack = 0) {
117   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
118   auto range = arrayAttr.getAsRange<IntegerAttr>();
119   SmallVector<int64_t, 4> res;
120   res.reserve(arrayAttr.size() - dropFront - dropBack);
121   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
122        it != eit; ++it)
123     res.push_back((*it).getValue().getSExtValue());
124   return res;
125 }
126 
127 namespace {
128 
129 /// Conversion pattern for a vector.matrix_multiply.
130 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
131 class VectorMatmulOpConversion : public ConvertToLLVMPattern {
132 public:
133   explicit VectorMatmulOpConversion(MLIRContext *context,
134                                     LLVMTypeConverter &typeConverter)
135       : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
136                              typeConverter) {}
137 
138   LogicalResult
139   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
140                   ConversionPatternRewriter &rewriter) const override {
141     auto matmulOp = cast<vector::MatmulOp>(op);
142     auto adaptor = vector::MatmulOpOperandAdaptor(operands);
143     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
144         op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
145         adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
146         matmulOp.rhs_columns());
147     return success();
148   }
149 };
150 
151 /// Conversion pattern for a vector.flat_transpose.
152 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
153 class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
154 public:
155   explicit VectorFlatTransposeOpConversion(MLIRContext *context,
156                                            LLVMTypeConverter &typeConverter)
157       : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
158                              context, typeConverter) {}
159 
160   LogicalResult
161   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
162                   ConversionPatternRewriter &rewriter) const override {
163     auto transOp = cast<vector::FlatTransposeOp>(op);
164     auto adaptor = vector::FlatTransposeOpOperandAdaptor(operands);
165     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
166         transOp, typeConverter.convertType(transOp.res().getType()),
167         adaptor.matrix(), transOp.rows(), transOp.columns());
168     return success();
169   }
170 };
171 
172 class VectorReductionOpConversion : public ConvertToLLVMPattern {
173 public:
174   explicit VectorReductionOpConversion(MLIRContext *context,
175                                        LLVMTypeConverter &typeConverter)
176       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
177                              typeConverter) {}
178 
179   LogicalResult
180   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
181                   ConversionPatternRewriter &rewriter) const override {
182     auto reductionOp = cast<vector::ReductionOp>(op);
183     auto kind = reductionOp.kind();
184     Type eltType = reductionOp.dest().getType();
185     Type llvmType = typeConverter.convertType(eltType);
186     if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64)) {
187       // Integer reductions: add/mul/min/max/and/or/xor.
188       if (kind == "add")
189         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
190             op, llvmType, operands[0]);
191       else if (kind == "mul")
192         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
193             op, llvmType, operands[0]);
194       else if (kind == "min")
195         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
196             op, llvmType, operands[0]);
197       else if (kind == "max")
198         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
199             op, llvmType, operands[0]);
200       else if (kind == "and")
201         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
202             op, llvmType, operands[0]);
203       else if (kind == "or")
204         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
205             op, llvmType, operands[0]);
206       else if (kind == "xor")
207         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
208             op, llvmType, operands[0]);
209       else
210         return failure();
211       return success();
212 
213     } else if (eltType.isF32() || eltType.isF64()) {
214       // Floating-point reductions: add/mul/min/max
215       if (kind == "add") {
216         // Optional accumulator (or zero).
217         Value acc = operands.size() > 1 ? operands[1]
218                                         : rewriter.create<LLVM::ConstantOp>(
219                                               op->getLoc(), llvmType,
220                                               rewriter.getZeroAttr(eltType));
221         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
222             op, llvmType, acc, operands[0]);
223       } else if (kind == "mul") {
224         // Optional accumulator (or one).
225         Value acc = operands.size() > 1
226                         ? operands[1]
227                         : rewriter.create<LLVM::ConstantOp>(
228                               op->getLoc(), llvmType,
229                               rewriter.getFloatAttr(eltType, 1.0));
230         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
231             op, llvmType, acc, operands[0]);
232       } else if (kind == "min")
233         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
234             op, llvmType, operands[0]);
235       else if (kind == "max")
236         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
237             op, llvmType, operands[0]);
238       else
239         return failure();
240       return success();
241     }
242     return failure();
243   }
244 };
245 
246 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
247 public:
248   explicit VectorShuffleOpConversion(MLIRContext *context,
249                                      LLVMTypeConverter &typeConverter)
250       : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
251                              typeConverter) {}
252 
253   LogicalResult
254   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
255                   ConversionPatternRewriter &rewriter) const override {
256     auto loc = op->getLoc();
257     auto adaptor = vector::ShuffleOpOperandAdaptor(operands);
258     auto shuffleOp = cast<vector::ShuffleOp>(op);
259     auto v1Type = shuffleOp.getV1VectorType();
260     auto v2Type = shuffleOp.getV2VectorType();
261     auto vectorType = shuffleOp.getVectorType();
262     Type llvmType = typeConverter.convertType(vectorType);
263     auto maskArrayAttr = shuffleOp.mask();
264 
265     // Bail if result type cannot be lowered.
266     if (!llvmType)
267       return failure();
268 
269     // Get rank and dimension sizes.
270     int64_t rank = vectorType.getRank();
271     assert(v1Type.getRank() == rank);
272     assert(v2Type.getRank() == rank);
273     int64_t v1Dim = v1Type.getDimSize(0);
274 
275     // For rank 1, where both operands have *exactly* the same vector type,
276     // there is direct shuffle support in LLVM. Use it!
277     if (rank == 1 && v1Type == v2Type) {
278       Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
279           loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
280       rewriter.replaceOp(op, shuffle);
281       return success();
282     }
283 
284     // For all other cases, insert the individual values individually.
285     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
286     int64_t insPos = 0;
287     for (auto en : llvm::enumerate(maskArrayAttr)) {
288       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
289       Value value = adaptor.v1();
290       if (extPos >= v1Dim) {
291         extPos -= v1Dim;
292         value = adaptor.v2();
293       }
294       Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
295                                  rank, extPos);
296       insert = insertOne(rewriter, typeConverter, loc, insert, extract,
297                          llvmType, rank, insPos++);
298     }
299     rewriter.replaceOp(op, insert);
300     return success();
301   }
302 };
303 
304 class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
305 public:
306   explicit VectorExtractElementOpConversion(MLIRContext *context,
307                                             LLVMTypeConverter &typeConverter)
308       : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
309                              context, typeConverter) {}
310 
311   LogicalResult
312   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
313                   ConversionPatternRewriter &rewriter) const override {
314     auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
315     auto extractEltOp = cast<vector::ExtractElementOp>(op);
316     auto vectorType = extractEltOp.getVectorType();
317     auto llvmType = typeConverter.convertType(vectorType.getElementType());
318 
319     // Bail if result type cannot be lowered.
320     if (!llvmType)
321       return failure();
322 
323     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
324         op, llvmType, adaptor.vector(), adaptor.position());
325     return success();
326   }
327 };
328 
329 class VectorExtractOpConversion : public ConvertToLLVMPattern {
330 public:
331   explicit VectorExtractOpConversion(MLIRContext *context,
332                                      LLVMTypeConverter &typeConverter)
333       : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
334                              typeConverter) {}
335 
336   LogicalResult
337   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
338                   ConversionPatternRewriter &rewriter) const override {
339     auto loc = op->getLoc();
340     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
341     auto extractOp = cast<vector::ExtractOp>(op);
342     auto vectorType = extractOp.getVectorType();
343     auto resultType = extractOp.getResult().getType();
344     auto llvmResultType = typeConverter.convertType(resultType);
345     auto positionArrayAttr = extractOp.position();
346 
347     // Bail if result type cannot be lowered.
348     if (!llvmResultType)
349       return failure();
350 
351     // One-shot extraction of vector from array (only requires extractvalue).
352     if (resultType.isa<VectorType>()) {
353       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
354           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
355       rewriter.replaceOp(op, extracted);
356       return success();
357     }
358 
359     // Potential extraction of 1-D vector from array.
360     auto *context = op->getContext();
361     Value extracted = adaptor.vector();
362     auto positionAttrs = positionArrayAttr.getValue();
363     if (positionAttrs.size() > 1) {
364       auto oneDVectorType = reducedVectorTypeBack(vectorType);
365       auto nMinusOnePositionAttrs =
366           ArrayAttr::get(positionAttrs.drop_back(), context);
367       extracted = rewriter.create<LLVM::ExtractValueOp>(
368           loc, typeConverter.convertType(oneDVectorType), extracted,
369           nMinusOnePositionAttrs);
370     }
371 
372     // Remaining extraction of element from 1-D LLVM vector
373     auto position = positionAttrs.back().cast<IntegerAttr>();
374     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
375     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
376     extracted =
377         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
378     rewriter.replaceOp(op, extracted);
379 
380     return success();
381   }
382 };
383 
384 /// Conversion pattern that turns a vector.fma on a 1-D vector
385 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
386 /// This does not match vectors of n >= 2 rank.
387 ///
388 /// Example:
389 /// ```
390 ///  vector.fma %a, %a, %a : vector<8xf32>
391 /// ```
392 /// is converted to:
393 /// ```
394 ///  llvm.intr.fma %va, %va, %va:
395 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
396 ///    -> !llvm<"<8 x float>">
397 /// ```
398 class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
399 public:
400   explicit VectorFMAOp1DConversion(MLIRContext *context,
401                                    LLVMTypeConverter &typeConverter)
402       : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
403                              typeConverter) {}
404 
405   LogicalResult
406   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
407                   ConversionPatternRewriter &rewriter) const override {
408     auto adaptor = vector::FMAOpOperandAdaptor(operands);
409     vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
410     VectorType vType = fmaOp.getVectorType();
411     if (vType.getRank() != 1)
412       return failure();
413     rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
414                                              adaptor.acc());
415     return success();
416   }
417 };
418 
419 class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
420 public:
421   explicit VectorInsertElementOpConversion(MLIRContext *context,
422                                            LLVMTypeConverter &typeConverter)
423       : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
424                              context, typeConverter) {}
425 
426   LogicalResult
427   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
428                   ConversionPatternRewriter &rewriter) const override {
429     auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
430     auto insertEltOp = cast<vector::InsertElementOp>(op);
431     auto vectorType = insertEltOp.getDestVectorType();
432     auto llvmType = typeConverter.convertType(vectorType);
433 
434     // Bail if result type cannot be lowered.
435     if (!llvmType)
436       return failure();
437 
438     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
439         op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
440     return success();
441   }
442 };
443 
444 class VectorInsertOpConversion : public ConvertToLLVMPattern {
445 public:
446   explicit VectorInsertOpConversion(MLIRContext *context,
447                                     LLVMTypeConverter &typeConverter)
448       : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
449                              typeConverter) {}
450 
451   LogicalResult
452   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
453                   ConversionPatternRewriter &rewriter) const override {
454     auto loc = op->getLoc();
455     auto adaptor = vector::InsertOpOperandAdaptor(operands);
456     auto insertOp = cast<vector::InsertOp>(op);
457     auto sourceType = insertOp.getSourceType();
458     auto destVectorType = insertOp.getDestVectorType();
459     auto llvmResultType = typeConverter.convertType(destVectorType);
460     auto positionArrayAttr = insertOp.position();
461 
462     // Bail if result type cannot be lowered.
463     if (!llvmResultType)
464       return failure();
465 
466     // One-shot insertion of a vector into an array (only requires insertvalue).
467     if (sourceType.isa<VectorType>()) {
468       Value inserted = rewriter.create<LLVM::InsertValueOp>(
469           loc, llvmResultType, adaptor.dest(), adaptor.source(),
470           positionArrayAttr);
471       rewriter.replaceOp(op, inserted);
472       return success();
473     }
474 
475     // Potential extraction of 1-D vector from array.
476     auto *context = op->getContext();
477     Value extracted = adaptor.dest();
478     auto positionAttrs = positionArrayAttr.getValue();
479     auto position = positionAttrs.back().cast<IntegerAttr>();
480     auto oneDVectorType = destVectorType;
481     if (positionAttrs.size() > 1) {
482       oneDVectorType = reducedVectorTypeBack(destVectorType);
483       auto nMinusOnePositionAttrs =
484           ArrayAttr::get(positionAttrs.drop_back(), context);
485       extracted = rewriter.create<LLVM::ExtractValueOp>(
486           loc, typeConverter.convertType(oneDVectorType), extracted,
487           nMinusOnePositionAttrs);
488     }
489 
490     // Insertion of an element into a 1-D LLVM vector.
491     auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
492     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
493     Value inserted = rewriter.create<LLVM::InsertElementOp>(
494         loc, typeConverter.convertType(oneDVectorType), extracted,
495         adaptor.source(), constant);
496 
497     // Potential insertion of resulting 1-D vector into array.
498     if (positionAttrs.size() > 1) {
499       auto nMinusOnePositionAttrs =
500           ArrayAttr::get(positionAttrs.drop_back(), context);
501       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
502                                                       adaptor.dest(), inserted,
503                                                       nMinusOnePositionAttrs);
504     }
505 
506     rewriter.replaceOp(op, inserted);
507     return success();
508   }
509 };
510 
511 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
512 ///
513 /// Example:
514 /// ```
515 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
516 /// ```
517 /// is rewritten into:
518 /// ```
519 ///  %r = splat %f0: vector<2x4xf32>
520 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
521 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
522 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
523 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
524 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
525 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
526 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
527 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
528 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
529 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
530 ///  // %r3 holds the final value.
531 /// ```
532 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
533 public:
534   using OpRewritePattern<FMAOp>::OpRewritePattern;
535 
536   LogicalResult matchAndRewrite(FMAOp op,
537                                 PatternRewriter &rewriter) const override {
538     auto vType = op.getVectorType();
539     if (vType.getRank() < 2)
540       return failure();
541 
542     auto loc = op.getLoc();
543     auto elemType = vType.getElementType();
544     Value zero = rewriter.create<ConstantOp>(loc, elemType,
545                                              rewriter.getZeroAttr(elemType));
546     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
547     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
548       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
549       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
550       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
551       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
552       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
553     }
554     rewriter.replaceOp(op, desc);
555     return success();
556   }
557 };
558 
559 // When ranks are different, InsertStridedSlice needs to extract a properly
560 // ranked vector from the destination vector into which to insert. This pattern
561 // only takes care of this part and forwards the rest of the conversion to
562 // another pattern that converts InsertStridedSlice for operands of the same
563 // rank.
564 //
565 // RewritePattern for InsertStridedSliceOp where source and destination vectors
566 // have different ranks. In this case:
567 //   1. the proper subvector is extracted from the destination vector
568 //   2. a new InsertStridedSlice op is created to insert the source in the
569 //   destination subvector
570 //   3. the destination subvector is inserted back in the proper place
571 //   4. the op is replaced by the result of step 3.
572 // The new InsertStridedSlice from step 2. will be picked up by a
573 // `VectorInsertStridedSliceOpSameRankRewritePattern`.
574 class VectorInsertStridedSliceOpDifferentRankRewritePattern
575     : public OpRewritePattern<InsertStridedSliceOp> {
576 public:
577   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
578 
579   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
580                                 PatternRewriter &rewriter) const override {
581     auto srcType = op.getSourceVectorType();
582     auto dstType = op.getDestVectorType();
583 
584     if (op.offsets().getValue().empty())
585       return failure();
586 
587     auto loc = op.getLoc();
588     int64_t rankDiff = dstType.getRank() - srcType.getRank();
589     assert(rankDiff >= 0);
590     if (rankDiff == 0)
591       return failure();
592 
593     int64_t rankRest = dstType.getRank() - rankDiff;
594     // Extract / insert the subvector of matching rank and InsertStridedSlice
595     // on it.
596     Value extracted =
597         rewriter.create<ExtractOp>(loc, op.dest(),
598                                    getI64SubArray(op.offsets(), /*dropFront=*/0,
599                                                   /*dropFront=*/rankRest));
600     // A different pattern will kick in for InsertStridedSlice with matching
601     // ranks.
602     auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
603         loc, op.source(), extracted,
604         getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
605         getI64SubArray(op.strides(), /*dropFront=*/0));
606     rewriter.replaceOpWithNewOp<InsertOp>(
607         op, stridedSliceInnerOp.getResult(), op.dest(),
608         getI64SubArray(op.offsets(), /*dropFront=*/0,
609                        /*dropFront=*/rankRest));
610     return success();
611   }
612 };
613 
614 // RewritePattern for InsertStridedSliceOp where source and destination vectors
615 // have the same rank. In this case, we reduce
616 //   1. the proper subvector is extracted from the destination vector
617 //   2. a new InsertStridedSlice op is created to insert the source in the
618 //   destination subvector
619 //   3. the destination subvector is inserted back in the proper place
620 //   4. the op is replaced by the result of step 3.
621 // The new InsertStridedSlice from step 2. will be picked up by a
622 // `VectorInsertStridedSliceOpSameRankRewritePattern`.
623 class VectorInsertStridedSliceOpSameRankRewritePattern
624     : public OpRewritePattern<InsertStridedSliceOp> {
625 public:
626   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
627 
628   LogicalResult matchAndRewrite(InsertStridedSliceOp op,
629                                 PatternRewriter &rewriter) const override {
630     auto srcType = op.getSourceVectorType();
631     auto dstType = op.getDestVectorType();
632 
633     if (op.offsets().getValue().empty())
634       return failure();
635 
636     int64_t rankDiff = dstType.getRank() - srcType.getRank();
637     assert(rankDiff >= 0);
638     if (rankDiff != 0)
639       return failure();
640 
641     if (srcType == dstType) {
642       rewriter.replaceOp(op, op.source());
643       return success();
644     }
645 
646     int64_t offset =
647         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
648     int64_t size = srcType.getShape().front();
649     int64_t stride =
650         op.strides().getValue().front().cast<IntegerAttr>().getInt();
651 
652     auto loc = op.getLoc();
653     Value res = op.dest();
654     // For each slice of the source vector along the most major dimension.
655     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
656          off += stride, ++idx) {
657       // 1. extract the proper subvector (or element) from source
658       Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
659       if (extractedSource.getType().isa<VectorType>()) {
660         // 2. If we have a vector, extract the proper subvector from destination
661         // Otherwise we are at the element level and no need to recurse.
662         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
663         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
664         // smaller rank.
665         extractedSource = rewriter.create<InsertStridedSliceOp>(
666             loc, extractedSource, extractedDest,
667             getI64SubArray(op.offsets(), /* dropFront=*/1),
668             getI64SubArray(op.strides(), /* dropFront=*/1));
669       }
670       // 4. Insert the extractedSource into the res vector.
671       res = insertOne(rewriter, loc, extractedSource, res, off);
672     }
673 
674     rewriter.replaceOp(op, res);
675     return success();
676   }
677   /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
678   /// bounded as the rank is strictly decreasing.
679   bool hasBoundedRewriteRecursion() const final { return true; }
680 };
681 
682 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
683 public:
684   explicit VectorTypeCastOpConversion(MLIRContext *context,
685                                       LLVMTypeConverter &typeConverter)
686       : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
687                              typeConverter) {}
688 
689   LogicalResult
690   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
691                   ConversionPatternRewriter &rewriter) const override {
692     auto loc = op->getLoc();
693     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
694     MemRefType sourceMemRefType =
695         castOp.getOperand().getType().cast<MemRefType>();
696     MemRefType targetMemRefType =
697         castOp.getResult().getType().cast<MemRefType>();
698 
699     // Only static shape casts supported atm.
700     if (!sourceMemRefType.hasStaticShape() ||
701         !targetMemRefType.hasStaticShape())
702       return failure();
703 
704     auto llvmSourceDescriptorTy =
705         operands[0].getType().dyn_cast<LLVM::LLVMType>();
706     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
707       return failure();
708     MemRefDescriptor sourceMemRef(operands[0]);
709 
710     auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
711                                       .dyn_cast_or_null<LLVM::LLVMType>();
712     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
713       return failure();
714 
715     int64_t offset;
716     SmallVector<int64_t, 4> strides;
717     auto successStrides =
718         getStridesAndOffset(sourceMemRefType, strides, offset);
719     bool isContiguous = (strides.back() == 1);
720     if (isContiguous) {
721       auto sizes = sourceMemRefType.getShape();
722       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
723         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
724           isContiguous = false;
725           break;
726         }
727       }
728     }
729     // Only contiguous source tensors supported atm.
730     if (failed(successStrides) || !isContiguous)
731       return failure();
732 
733     auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
734 
735     // Create descriptor.
736     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
737     Type llvmTargetElementTy = desc.getElementType();
738     // Set allocated ptr.
739     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
740     allocated =
741         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
742     desc.setAllocatedPtr(rewriter, loc, allocated);
743     // Set aligned ptr.
744     Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
745     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
746     desc.setAlignedPtr(rewriter, loc, ptr);
747     // Fill offset 0.
748     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
749     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
750     desc.setOffset(rewriter, loc, zero);
751 
752     // Fill size and stride descriptors in memref.
753     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
754       int64_t index = indexedSize.index();
755       auto sizeAttr =
756           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
757       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
758       desc.setSize(rewriter, loc, index, size);
759       auto strideAttr =
760           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
761       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
762       desc.setStride(rewriter, loc, index, stride);
763     }
764 
765     rewriter.replaceOp(op, {desc});
766     return success();
767   }
768 };
769 
770 LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
771                                       Type type, LLVM::LLVMType &llvmType,
772                                       unsigned &align) {
773   auto convertedType = typeConverter.convertType(type);
774   if (!convertedType)
775     return failure();
776 
777   llvmType = convertedType.template cast<LLVM::LLVMType>();
778   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
779   align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
780   return success();
781 }
782 
783 LogicalResult
784 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
785                                  LLVMTypeConverter &typeConverter, Location loc,
786                                  TransferReadOp xferOp,
787                                  ArrayRef<Value> operands, Value dataPtr) {
788   LLVM::LLVMType vecTy;
789   unsigned align;
790   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
791                                      vecTy, align)))
792     return failure();
793   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
794   return success();
795 }
796 
797 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
798                                           LLVMTypeConverter &typeConverter,
799                                           Location loc, TransferReadOp xferOp,
800                                           ArrayRef<Value> operands,
801                                           Value dataPtr, Value mask) {
802   auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
803   VectorType fillType = xferOp.getVectorType();
804   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
805   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
806 
807   LLVM::LLVMType vecTy;
808   unsigned align;
809   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
810                                      vecTy, align)))
811     return failure();
812 
813   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
814       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
815       rewriter.getI32IntegerAttr(align));
816   return success();
817 }
818 
819 LogicalResult
820 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
821                                  LLVMTypeConverter &typeConverter, Location loc,
822                                  TransferWriteOp xferOp,
823                                  ArrayRef<Value> operands, Value dataPtr) {
824   auto adaptor = TransferWriteOpOperandAdaptor(operands);
825   LLVM::LLVMType vecTy;
826   unsigned align;
827   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
828                                      vecTy, align)))
829     return failure();
830   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
831   return success();
832 }
833 
834 LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
835                                           LLVMTypeConverter &typeConverter,
836                                           Location loc, TransferWriteOp xferOp,
837                                           ArrayRef<Value> operands,
838                                           Value dataPtr, Value mask) {
839   auto adaptor = TransferWriteOpOperandAdaptor(operands);
840   LLVM::LLVMType vecTy;
841   unsigned align;
842   if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
843                                      vecTy, align)))
844     return failure();
845 
846   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
847       xferOp, adaptor.vector(), dataPtr, mask,
848       rewriter.getI32IntegerAttr(align));
849   return success();
850 }
851 
852 static TransferReadOpOperandAdaptor
853 getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
854   return TransferReadOpOperandAdaptor(operands);
855 }
856 
857 static TransferWriteOpOperandAdaptor
858 getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
859   return TransferWriteOpOperandAdaptor(operands);
860 }
861 
862 bool isMinorIdentity(AffineMap map, unsigned rank) {
863   if (map.getNumResults() < rank)
864     return false;
865   unsigned startDim = map.getNumDims() - rank;
866   for (unsigned i = 0; i < rank; ++i)
867     if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
868       return false;
869   return true;
870 }
871 
872 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
873 /// sequence of:
874 /// 1. Bitcast or addrspacecast to vector form.
875 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
876 /// 3. Create a mask where offsetVector is compared against memref upper bound.
877 /// 4. Rewrite op as a masked read or write.
878 template <typename ConcreteOp>
879 class VectorTransferConversion : public ConvertToLLVMPattern {
880 public:
881   explicit VectorTransferConversion(MLIRContext *context,
882                                     LLVMTypeConverter &typeConv)
883       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
884                              typeConv) {}
885 
886   LogicalResult
887   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
888                   ConversionPatternRewriter &rewriter) const override {
889     auto xferOp = cast<ConcreteOp>(op);
890     auto adaptor = getTransferOpAdapter(xferOp, operands);
891 
892     if (xferOp.getVectorType().getRank() > 1 ||
893         llvm::size(xferOp.indices()) == 0)
894       return failure();
895     if (!isMinorIdentity(xferOp.permutation_map(),
896                          xferOp.getVectorType().getRank()))
897       return failure();
898 
899     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
900 
901     Location loc = op->getLoc();
902     Type i64Type = rewriter.getIntegerType(64);
903     MemRefType memRefType = xferOp.getMemRefType();
904 
905     // 1. Get the source/dst address as an LLVM vector pointer.
906     //    The vector pointer would always be on address space 0, therefore
907     //    addrspacecast shall be used when source/dst memrefs are not on
908     //    address space 0.
909     // TODO: support alignment when possible.
910     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
911                                adaptor.indices(), rewriter, getModule());
912     auto vecTy =
913         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
914     Value vectorDataPtr;
915     if (memRefType.getMemorySpace() == 0)
916       vectorDataPtr =
917           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
918     else
919       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
920           loc, vecTy.getPointerTo(), dataPtr);
921 
922     if (!xferOp.isMaskedDim(0))
923       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
924                                               xferOp, operands, vectorDataPtr);
925 
926     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
927     unsigned vecWidth = vecTy.getVectorNumElements();
928     VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
929     SmallVector<int64_t, 8> indices;
930     indices.reserve(vecWidth);
931     for (unsigned i = 0; i < vecWidth; ++i)
932       indices.push_back(i);
933     Value linearIndices = rewriter.create<ConstantOp>(
934         loc, vectorCmpType,
935         DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
936     linearIndices = rewriter.create<LLVM::DialectCastOp>(
937         loc, toLLVMTy(vectorCmpType), linearIndices);
938 
939     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
940     // TODO(ntv, ajcbik): when the leaf transfer rank is k > 1 we need the last
941     // `k` dimensions here.
942     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
943     Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
944     offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
945     Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
946     Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
947 
948     // 4. Let dim the memref dimension, compute the vector comparison mask:
949     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
950     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
951     dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
952     dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
953     Value mask =
954         rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
955     mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
956                                                 mask);
957 
958     // 5. Rewrite as a masked read / write.
959     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
960                                        operands, vectorDataPtr, mask);
961   }
962 };
963 
964 class VectorPrintOpConversion : public ConvertToLLVMPattern {
965 public:
966   explicit VectorPrintOpConversion(MLIRContext *context,
967                                    LLVMTypeConverter &typeConverter)
968       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
969                              typeConverter) {}
970 
971   // Proof-of-concept lowering implementation that relies on a small
972   // runtime support library, which only needs to provide a few
973   // printing methods (single value for all data types, opening/closing
974   // bracket, comma, newline). The lowering fully unrolls a vector
975   // in terms of these elementary printing operations. The advantage
976   // of this approach is that the library can remain unaware of all
977   // low-level implementation details of vectors while still supporting
978   // output of any shaped and dimensioned vector. Due to full unrolling,
979   // this approach is less suited for very large vectors though.
980   //
981   // TODO(ajcbik): rely solely on libc in future? something else?
982   //
983   LogicalResult
984   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
985                   ConversionPatternRewriter &rewriter) const override {
986     auto printOp = cast<vector::PrintOp>(op);
987     auto adaptor = vector::PrintOpOperandAdaptor(operands);
988     Type printType = printOp.getPrintType();
989 
990     if (typeConverter.convertType(printType) == nullptr)
991       return failure();
992 
993     // Make sure element type has runtime support (currently just Float/Double).
994     VectorType vectorType = printType.dyn_cast<VectorType>();
995     Type eltType = vectorType ? vectorType.getElementType() : printType;
996     int64_t rank = vectorType ? vectorType.getRank() : 0;
997     Operation *printer;
998     if (eltType.isSignlessInteger(1))
999       printer = getPrintI1(op);
1000     else if (eltType.isSignlessInteger(32))
1001       printer = getPrintI32(op);
1002     else if (eltType.isSignlessInteger(64))
1003       printer = getPrintI64(op);
1004     else if (eltType.isF32())
1005       printer = getPrintFloat(op);
1006     else if (eltType.isF64())
1007       printer = getPrintDouble(op);
1008     else
1009       return failure();
1010 
1011     // Unroll vector into elementary print calls.
1012     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
1013     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1014     rewriter.eraseOp(op);
1015     return success();
1016   }
1017 
1018 private:
1019   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1020                  Value value, VectorType vectorType, Operation *printer,
1021                  int64_t rank) const {
1022     Location loc = op->getLoc();
1023     if (rank == 0) {
1024       emitCall(rewriter, loc, printer, value);
1025       return;
1026     }
1027 
1028     emitCall(rewriter, loc, getPrintOpen(op));
1029     Operation *printComma = getPrintComma(op);
1030     int64_t dim = vectorType.getDimSize(0);
1031     for (int64_t d = 0; d < dim; ++d) {
1032       auto reducedType =
1033           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1034       auto llvmType = typeConverter.convertType(
1035           rank > 1 ? reducedType : vectorType.getElementType());
1036       Value nestedVal =
1037           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1038       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
1039       if (d != dim - 1)
1040         emitCall(rewriter, loc, printComma);
1041     }
1042     emitCall(rewriter, loc, getPrintClose(op));
1043   }
1044 
1045   // Helper to emit a call.
1046   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1047                        Operation *ref, ValueRange params = ValueRange()) {
1048     rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
1049                                   rewriter.getSymbolRefAttr(ref), params);
1050   }
1051 
1052   // Helper for printer method declaration (first hit) and lookup.
1053   static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
1054                              StringRef name, ArrayRef<LLVM::LLVMType> params) {
1055     auto module = op->getParentOfType<ModuleOp>();
1056     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1057     if (func)
1058       return func;
1059     OpBuilder moduleBuilder(module.getBodyRegion());
1060     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1061         op->getLoc(), name,
1062         LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
1063                                       params, /*isVarArg=*/false));
1064   }
1065 
1066   // Helpers for method names.
1067   Operation *getPrintI1(Operation *op) const {
1068     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1069     return getPrint(op, dialect, "print_i1",
1070                     LLVM::LLVMType::getInt1Ty(dialect));
1071   }
1072   Operation *getPrintI32(Operation *op) const {
1073     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1074     return getPrint(op, dialect, "print_i32",
1075                     LLVM::LLVMType::getInt32Ty(dialect));
1076   }
1077   Operation *getPrintI64(Operation *op) const {
1078     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1079     return getPrint(op, dialect, "print_i64",
1080                     LLVM::LLVMType::getInt64Ty(dialect));
1081   }
1082   Operation *getPrintFloat(Operation *op) const {
1083     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1084     return getPrint(op, dialect, "print_f32",
1085                     LLVM::LLVMType::getFloatTy(dialect));
1086   }
1087   Operation *getPrintDouble(Operation *op) const {
1088     LLVM::LLVMDialect *dialect = typeConverter.getDialect();
1089     return getPrint(op, dialect, "print_f64",
1090                     LLVM::LLVMType::getDoubleTy(dialect));
1091   }
1092   Operation *getPrintOpen(Operation *op) const {
1093     return getPrint(op, typeConverter.getDialect(), "print_open", {});
1094   }
1095   Operation *getPrintClose(Operation *op) const {
1096     return getPrint(op, typeConverter.getDialect(), "print_close", {});
1097   }
1098   Operation *getPrintComma(Operation *op) const {
1099     return getPrint(op, typeConverter.getDialect(), "print_comma", {});
1100   }
1101   Operation *getPrintNewline(Operation *op) const {
1102     return getPrint(op, typeConverter.getDialect(), "print_newline", {});
1103   }
1104 };
1105 
1106 /// Progressive lowering of ExtractStridedSliceOp to either:
1107 ///   1. extractelement + insertelement for the 1-D case
1108 ///   2. extract + optional strided_slice + insert for the n-D case.
1109 class VectorStridedSliceOpConversion
1110     : public OpRewritePattern<ExtractStridedSliceOp> {
1111 public:
1112   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1113 
1114   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1115                                 PatternRewriter &rewriter) const override {
1116     auto dstType = op.getResult().getType().cast<VectorType>();
1117 
1118     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1119 
1120     int64_t offset =
1121         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1122     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1123     int64_t stride =
1124         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1125 
1126     auto loc = op.getLoc();
1127     auto elemType = dstType.getElementType();
1128     assert(elemType.isSignlessIntOrIndexOrFloat());
1129     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1130                                              rewriter.getZeroAttr(elemType));
1131     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1132     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1133          off += stride, ++idx) {
1134       Value extracted = extractOne(rewriter, loc, op.vector(), off);
1135       if (op.offsets().getValue().size() > 1) {
1136         extracted = rewriter.create<ExtractStridedSliceOp>(
1137             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
1138             getI64SubArray(op.sizes(), /* dropFront=*/1),
1139             getI64SubArray(op.strides(), /* dropFront=*/1));
1140       }
1141       res = insertOne(rewriter, loc, extracted, res, idx);
1142     }
1143     rewriter.replaceOp(op, {res});
1144     return success();
1145   }
1146   /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
1147   /// bounded as the rank is strictly decreasing.
1148   bool hasBoundedRewriteRecursion() const final { return true; }
1149 };
1150 
1151 } // namespace
1152 
1153 /// Populate the given list with patterns that convert from Vector to LLVM.
1154 void mlir::populateVectorToLLVMConversionPatterns(
1155     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1156   MLIRContext *ctx = converter.getDialect()->getContext();
1157   // clang-format off
1158   patterns.insert<VectorFMAOpNDRewritePattern,
1159                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1160                   VectorInsertStridedSliceOpSameRankRewritePattern,
1161                   VectorStridedSliceOpConversion>(ctx);
1162   patterns
1163       .insert<VectorReductionOpConversion,
1164               VectorShuffleOpConversion,
1165               VectorExtractElementOpConversion,
1166               VectorExtractOpConversion,
1167               VectorFMAOp1DConversion,
1168               VectorInsertElementOpConversion,
1169               VectorInsertOpConversion,
1170               VectorPrintOpConversion,
1171               VectorTransferConversion<TransferReadOp>,
1172               VectorTransferConversion<TransferWriteOp>,
1173               VectorTypeCastOpConversion>(ctx, converter);
1174   // clang-format on
1175 }
1176 
1177 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1178     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1179   MLIRContext *ctx = converter.getDialect()->getContext();
1180   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1181   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1182 }
1183 
1184 namespace {
1185 struct LowerVectorToLLVMPass
1186     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
1187   void runOnOperation() override;
1188 };
1189 } // namespace
1190 
1191 void LowerVectorToLLVMPass::runOnOperation() {
1192   // Perform progressive lowering of operations on slices and
1193   // all contraction operations. Also applies folding and DCE.
1194   {
1195     OwningRewritePatternList patterns;
1196     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1197     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1198     populateVectorContractLoweringPatterns(patterns, &getContext());
1199     applyPatternsAndFoldGreedily(getOperation(), patterns);
1200   }
1201 
1202   // Convert to the LLVM IR dialect.
1203   LLVMTypeConverter converter(&getContext());
1204   OwningRewritePatternList patterns;
1205   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1206   populateVectorToLLVMConversionPatterns(converter, patterns);
1207   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1208   populateStdToLLVMConversionPatterns(converter, patterns);
1209 
1210   LLVMConversionTarget target(getContext());
1211   target.addDynamicallyLegalOp<FuncOp>(
1212       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
1213   if (failed(applyPartialConversion(getOperation(), target, patterns,
1214                                     &converter))) {
1215     signalPassFailure();
1216   }
1217 }
1218 
1219 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertVectorToLLVMPass() {
1220   return std::make_unique<LowerVectorToLLVMPass>();
1221 }
1222