//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
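// For example (illustrative): vector<2x3x4xf32> reduces to vector<3x4xf32>.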
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
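// For example (illustrative): vector<2x3x4xf32> reduces to vector<4xf32>.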
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
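// For example (illustrative): with arrayAttr = [0, 1, 2, 3], dropFront = 1,
// and dropBack = 1, the result is [1, 2].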
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
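// For example (illustrative), with n = 4, offset o = 2, and bound b = 5:
//     mask = [0,1,2,3] + [2,2,2,2] < [5,5,5,5] = [1,1,1,0]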
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns the data layout alignment of the element type of the
// memref accessed by an operation.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns vector of pointers given a memref base and an index
// vector.
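// For example (an illustrative sketch, types abbreviated): with base %base of
// type !llvm.ptr<float> and indices %idxs of type !llvm.vec<8 x i32>, a
// single GEP materializes the vector of pointers:
//   %ptrs = llvm.getelementptr %base[%idxs] : !llvm.vec<8 x ptr<float>>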
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
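///
/// Example (an illustrative sketch; a 2x3 times 3x4 multiply on flattened
/// 1-D vectors):
/// ```
///  %c = vector.matrix_multiply %a, %b
///      { lhs_rows = 2: i32, lhs_columns = 3: i32, rhs_columns = 4: i32 }
///      : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// is converted to llvm.intr.matrix.multiply carrying the same shape
/// attributes.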
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
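///
/// Example (an illustrative sketch):
/// ```
///  %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///      : vector<16xf32> -> vector<16xf32>
/// ```
/// is converted to llvm.intr.matrix.transpose with the same row/column
/// attributes.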
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
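///
/// Example (an illustrative sketch; at this revision the op addresses the
/// base of the memref directly, without indices):
/// ```
///  %ld = vector.maskedload %base, %mask, %pass_thru
///      : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes an llvm.intr.masked.load through a pointer bitcast to the vector
/// type.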
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
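///
/// A sketch of the emitted LLVM dialect IR (illustrative, types abbreviated):
/// a single vector GEP materializes the vector of pointers, which then feeds
/// the gather intrinsic:
/// ```
///   %ptrs = llvm.getelementptr %base[%indices]
///   %g = llvm.intr.masked.gather %ptrs, %mask, %pass_thru { alignment = 4 }
/// ```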
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
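///
/// Example (an illustrative sketch; `kind` is a string attribute at this
/// revision):
/// ```
///  %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// is converted to LLVM::vector_reduce_fadd with a zero (or explicitly
/// provided) accumulator.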
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFPRed)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isa<FloatType>()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
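///
/// Example (an illustrative sketch): for
/// ```
///  %m = vector.create_mask %c3 : vector<4xi1>
/// ```
/// the lowering materializes [0,1,2,3] < [3,3,3,3], i.e. the mask [1,1,1,0].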
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
                                        bool enableIndexOpt)
      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
                             typeConverter),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op->getResult(0).getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of rank n >= 2.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. When the ranks differ, InsertStridedSlice needs to
// extract a properly ranked vector from the destination vector into which to
// insert; this pattern only takes care of that part and forwards the rest of
// the conversion to another pattern that handles InsertStridedSlice for
// operands of the same rank. Concretely:
//   1. the proper subvector is extracted from the destination vector;
//   2. a new InsertStridedSlice op is created to insert the source into the
//      destination subvector;
//   3. the destination subvector is inserted back in the proper place;
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
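//
// Example (an illustrative sketch):
//   %2 = vector.insert_strided_slice %0, %1
//       {offsets = [0, 0, 2], strides = [1, 1]}
//       : vector<2x4xf32> into vector<16x4x8xf32>
// first extracts the vector<4x8xf32> at %1[0], re-inserts %0 into it with
// offsets [0, 2], and finally inserts the result back at position [0] of %1.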
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each subvector of the source along the leading
// (most major) dimension:
//   1. the subvector is extracted from the source;
//   2. if it is itself a vector (i.e. we are not yet at the element level), a
//      new InsertStridedSlice op of strictly smaller rank is created to insert
//      it into the corresponding subvector of the destination;
//   3. the result is inserted into the destination at the proper offset.
// The recursive InsertStridedSlice from step 2. will be picked up again by
// this same `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
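///
/// For example (illustrative): memref<4x8xf32> has strides [8, 1] and is
/// contiguous, whereas a strided layout with strides [16, 1] over the same
/// 4x8 shape is not, since 16 != 8 * 1.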
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shape. This can only
  // ever work in static cases because MemRefType has no way to represent
  // contiguous dynamic shapes other than with an empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
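///
/// For example (an illustrative sketch), a masked 1-D read
/// ```
///   %f = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an llvm.intr.masked.load through a pointer to the vector type
/// obtained by bitcast from the element address, using the comparison mask
/// from step 4 and a splat of %pad as pass-through.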
1220 template <typename ConcreteOp>
1221 class VectorTransferConversion : public ConvertToLLVMPattern {
1222 public:
1223   explicit VectorTransferConversion(MLIRContext *context,
1224                                     LLVMTypeConverter &typeConv,
1225                                     bool enableIndexOpt)
1226       : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
1227         enableIndexOptimizations(enableIndexOpt) {}
1228 
1229   LogicalResult
1230   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1231                   ConversionPatternRewriter &rewriter) const override {
1232     auto xferOp = cast<ConcreteOp>(op);
1233     auto adaptor = getTransferOpAdapter(xferOp, operands);
1234 
1235     if (xferOp.getVectorType().getRank() > 1 ||
1236         llvm::size(xferOp.indices()) == 0)
1237       return failure();
1238     if (xferOp.permutation_map() !=
1239         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
1240                                        xferOp.getVectorType().getRank(),
1241                                        op->getContext()))
1242       return failure();
1243     // Only contiguous source tensors supported atm.
1244     auto strides = computeContiguousStrides(xferOp.getMemRefType());
1245     if (!strides)
1246       return failure();
1247 
1248     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
1249 
1250     Location loc = op->getLoc();
1251     MemRefType memRefType = xferOp.getMemRefType();
1252 
1253     if (auto memrefVectorElementType =
1254             memRefType.getElementType().dyn_cast<VectorType>()) {
1255       // Memref has vector element type.
1256       if (memrefVectorElementType.getElementType() !=
1257           xferOp.getVectorType().getElementType())
1258         return failure();
1259 #ifndef NDEBUG
1260       // Check that memref vector type is a suffix of 'vectorType'.
1261       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1262       unsigned resultVecRank = xferOp.getVectorType().getRank();
1263       assert(memrefVecEltRank <= resultVecRank);
1264       // TODO: Move this to isSuffix in Vector/Utils.h.
1265       unsigned rankOffset = resultVecRank - memrefVecEltRank;
1266       auto memrefVecEltShape = memrefVectorElementType.getShape();
1267       auto resultVecShape = xferOp.getVectorType().getShape();
1268       for (unsigned i = 0; i < memrefVecEltRank; ++i)
1269         assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
1270                "memref vector element shape should match suffix of vector "
1271                "result shape.");
1272 #endif // ifndef NDEBUG
1273     }
1274 
1275     // 1. Get the source/dst address as an LLVM vector pointer.
1276     //    The vector pointer is always in address space 0, so an
1277     //    addrspacecast is needed when the source/dst memrefs are not in
1278     //    address space 0.
1279     // TODO: support alignment when possible.
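    //    E.g., for a memref<?xf32, 3> source, the element pointer lives in
    //    address space 3, so it is addrspacecast (rather than bitcast) to a
    //    vector pointer in address space 0.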
1280     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
1281                                adaptor.indices(), rewriter);
1282     auto vecTy =
1283         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1284     Value vectorDataPtr;
1285     if (memRefType.getMemorySpace() == 0)
1286       vectorDataPtr =
1287           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1288     else
1289       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1290           loc, vecTy.getPointerTo(), dataPtr);
1291 
1292     if (!xferOp.isMaskedDim(0))
1293       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
1294                                               xferOp, operands, vectorDataPtr);
1295 
1296     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1297     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1298     // 4. Let dim be the memref dimension; compute the vector comparison mask:
1299     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1300     //
1301     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1302     //       dimensions here.
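    // For example (with illustrative values), for vecWidth = 4, an offset
    // %off = 2, and a dimension %dim = 5, the mask computes as
    //   [ 2, 3, 4, 5 ] < [ 5, 5, 5, 5 ] = [ 1, 1, 1, 0 ],
    // so only the first three lanes are actually loaded or stored.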
1303     unsigned vecWidth = vecTy.getVectorNumElements();
1304     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
1305     Value off = xferOp.indices()[lastIndex];
1306     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1307     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1308                                        vecWidth, dim, &off);
1309 
1310     // 5. Rewrite as a masked read / write.
1311     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1312                                        operands, vectorDataPtr, mask);
1313   }
1314 
1315 private:
1316   const bool enableIndexOptimizations;
1317 };
1318 
1319 class VectorPrintOpConversion : public ConvertToLLVMPattern {
1320 public:
1321   explicit VectorPrintOpConversion(MLIRContext *context,
1322                                    LLVMTypeConverter &typeConverter)
1323       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1324                              typeConverter) {}
1325 
1326   // Proof-of-concept lowering implementation that relies on a small
1327   // runtime support library, which only needs to provide a few
1328   // printing methods (single value for all data types, opening/closing
1329   // bracket, comma, newline). The lowering fully unrolls a vector
1330   // into calls to these elementary printing operations. The advantage
1331   // of this approach is that the library can remain unaware of all
1332   // low-level implementation details of vectors while still supporting
1333   // output of vectors of any shape and rank. Due to the full unrolling,
1334   // however, this approach is less suited for very large vectors.
1335   //
1336   // TODO: rely solely on libc in future? something else?
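  //
  // As an illustration, printing %v : vector<2xf32> holding [1.0, 2.0]
  // unrolls into the call sequence printOpen, printF32(1.0), printComma,
  // printF32(2.0), printClose, printNewline, which the runtime library
  // renders as "( 1, 2 )" followed by a newline.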
1337   //
1338   LogicalResult
1339   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1340                   ConversionPatternRewriter &rewriter) const override {
1341     auto printOp = cast<vector::PrintOp>(op);
1342     auto adaptor = vector::PrintOpAdaptor(operands);
1343     Type printType = printOp.getPrintType();
1344 
1345     if (typeConverter.convertType(printType) == nullptr)
1346       return failure();
1347 
1348     // Make sure element type has runtime support.
1349     PrintConversion conversion = PrintConversion::None;
1350     VectorType vectorType = printType.dyn_cast<VectorType>();
1351     Type eltType = vectorType ? vectorType.getElementType() : printType;
1352     Operation *printer;
1353     if (eltType.isF32()) {
1354       printer = getPrintFloat(op);
1355     } else if (eltType.isF64()) {
1356       printer = getPrintDouble(op);
1357     } else if (eltType.isIndex()) {
1358       printer = getPrintU64(op);
1359     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1360       // Integers need a zero or sign extension on the operand
1361       // (depending on the source type) as well as a signed or
1362       // unsigned print method. Up to 64-bit is supported.
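      // E.g., an i8 value -1 is sign-extended and printed as -1 via printI64,
      // whereas a ui8 value 255 is zero-extended and printed via printU64.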
1363       unsigned width = intTy.getWidth();
1364       if (intTy.isUnsigned()) {
1365         if (width <= 64) {
1366           if (width < 64)
1367             conversion = PrintConversion::ZeroExt64;
1368           printer = getPrintU64(op);
1369         } else {
1370           return failure();
1371         }
1372       } else {
1373         assert(intTy.isSignless() || intTy.isSigned());
1374         if (width <= 64) {
1375           // Note that we *always* zero extend booleans (1-bit integers),
1376           // so that true/false is printed as 1/0 rather than -1/0.
1377           if (width == 1)
1378             conversion = PrintConversion::ZeroExt64;
1379           else if (width < 64)
1380             conversion = PrintConversion::SignExt64;
1381           printer = getPrintI64(op);
1382         } else {
1383           return failure();
1384         }
1385       }
1386     } else {
1387       return failure();
1388     }
1389 
1390     // Unroll vector into elementary print calls.
1391     int64_t rank = vectorType ? vectorType.getRank() : 0;
1392     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1393               conversion);
1394     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1395     rewriter.eraseOp(op);
1396     return success();
1397   }
1398 
1399 private:
1400   enum class PrintConversion {
1401     // clang-format off
1402     None,
1403     ZeroExt64,
1404     SignExt64
1405     // clang-format on
1406   };
1407 
1408   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1409                  Value value, VectorType vectorType, Operation *printer,
1410                  int64_t rank, PrintConversion conversion) const {
1411     Location loc = op->getLoc();
1412     if (rank == 0) {
1413       switch (conversion) {
1414       case PrintConversion::ZeroExt64:
1415         value = rewriter.create<ZeroExtendIOp>(
1416             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1417         break;
1418       case PrintConversion::SignExt64:
1419         value = rewriter.create<SignExtendIOp>(
1420             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1421         break;
1422       case PrintConversion::None:
1423         break;
1424       }
1425       emitCall(rewriter, loc, printer, value);
1426       return;
1427     }
1428 
1429     emitCall(rewriter, loc, getPrintOpen(op));
1430     Operation *printComma = getPrintComma(op);
1431     int64_t dim = vectorType.getDimSize(0);
1432     for (int64_t d = 0; d < dim; ++d) {
1433       auto reducedType =
1434           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1435       auto llvmType = typeConverter.convertType(
1436           rank > 1 ? reducedType : vectorType.getElementType());
1437       Value nestedVal =
1438           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1439       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1440                 conversion);
1441       if (d != dim - 1)
1442         emitCall(rewriter, loc, printComma);
1443     }
1444     emitCall(rewriter, loc, getPrintClose(op));
1445   }
1446 
1447   // Helper to emit a call.
1448   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1449                        Operation *ref, ValueRange params = ValueRange()) {
1450     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1451                                   rewriter.getSymbolRefAttr(ref), params);
1452   }
1453 
1454   // Helper for printer method declaration (first hit) and lookup.
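  // E.g., the first call to getPrintFloat below materializes, roughly,
  //   llvm.func @printF32(!llvm.float)
  // at module scope; later calls simply look the symbol up.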
1455   static Operation *getPrint(Operation *op, StringRef name,
1456                              ArrayRef<LLVM::LLVMType> params) {
1457     auto module = op->getParentOfType<ModuleOp>();
1458     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1459     if (func)
1460       return func;
1461     OpBuilder moduleBuilder(module.getBodyRegion());
1462     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1463         op->getLoc(), name,
1464         LLVM::LLVMType::getFunctionTy(
1465             LLVM::LLVMType::getVoidTy(op->getContext()), params,
1466             /*isVarArg=*/false));
1467   }
1468 
1469   // Helpers for method names.
1470   Operation *getPrintI64(Operation *op) const {
1471     return getPrint(op, "printI64",
1472                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1473   }
1474   Operation *getPrintU64(Operation *op) const {
1475     return getPrint(op, "printU64",
1476                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1477   }
1478   Operation *getPrintFloat(Operation *op) const {
1479     return getPrint(op, "printF32",
1480                     LLVM::LLVMType::getFloatTy(op->getContext()));
1481   }
1482   Operation *getPrintDouble(Operation *op) const {
1483     return getPrint(op, "printF64",
1484                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1485   }
1486   Operation *getPrintOpen(Operation *op) const {
1487     return getPrint(op, "printOpen", {});
1488   }
1489   Operation *getPrintClose(Operation *op) const {
1490     return getPrint(op, "printClose", {});
1491   }
1492   Operation *getPrintComma(Operation *op) const {
1493     return getPrint(op, "printComma", {});
1494   }
1495   Operation *getPrintNewline(Operation *op) const {
1496     return getPrint(op, "printNewline", {});
1497   }
1498 };
1499 
1500 /// Progressive lowering of ExtractStridedSliceOp to either:
1501 ///   1. a direct shuffle for the single-offset case, or
1502 ///   2. extract + lower-rank strided_slice + insert for the n-D case.
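///
/// For example (illustrative), extracting from a vector<8xf32> with
/// offsets = [2], sizes = [3], strides = [1] becomes a single
///   vector.shuffle %v, %v [2, 3, 4] : vector<8xf32>, vector<8xf32>
/// while higher-rank cases recurse one leading dimension at a time.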
1503 class VectorExtractStridedSliceOpConversion
1504     : public OpRewritePattern<ExtractStridedSliceOp> {
1505 public:
1506   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1507       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1508     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1509     // is bounded as the rank is strictly decreasing.
1510     setHasBoundedRewriteRecursion();
1511   }
1512 
1513   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1514                                 PatternRewriter &rewriter) const override {
1515     auto dstType = op.getResult().getType().cast<VectorType>();
1516 
1517     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1518 
1519     int64_t offset =
1520         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1521     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1522     int64_t stride =
1523         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1524 
1525     auto loc = op.getLoc();
1526     auto elemType = dstType.getElementType();
1527     assert(elemType.isSignlessIntOrIndexOrFloat());
1528 
1529     // Single offset can be more efficiently shuffled.
1530     if (op.offsets().getValue().size() == 1) {
1531       SmallVector<int64_t, 4> offsets;
1532       offsets.reserve(size);
1533       for (int64_t off = offset, e = offset + size * stride; off < e;
1534            off += stride)
1535         offsets.push_back(off);
1536       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1537                                              op.vector(),
1538                                              rewriter.getI64ArrayAttr(offsets));
1539       return success();
1540     }
1541 
1542     // Extract/insert on a lower ranked extract strided slice op.
1543     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1544                                              rewriter.getZeroAttr(elemType));
1545     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1546     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1547          off += stride, ++idx) {
1548       Value one = extractOne(rewriter, loc, op.vector(), off);
1549       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1550           loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
1551           getI64SubArray(op.sizes(), /*dropFront=*/1),
1552           getI64SubArray(op.strides(), /*dropFront=*/1));
1553       res = insertOne(rewriter, loc, extracted, res, idx);
1554     }
1555     rewriter.replaceOp(op, res);
1556     return success();
1557   }
1558 };
1559 
1560 } // namespace
1561 
1562 /// Populate the given list with patterns that convert from Vector to LLVM.
1563 void mlir::populateVectorToLLVMConversionPatterns(
1564     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1565     bool reassociateFPReductions, bool enableIndexOptimizations) {
1566   MLIRContext *ctx = converter.getDialect()->getContext();
1567   // clang-format off
1568   patterns.insert<VectorFMAOpNDRewritePattern,
1569                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1570                   VectorInsertStridedSliceOpSameRankRewritePattern,
1571                   VectorExtractStridedSliceOpConversion>(ctx);
1572   patterns.insert<VectorReductionOpConversion>(
1573       ctx, converter, reassociateFPReductions);
1574   patterns.insert<VectorCreateMaskOpConversion,
1575                   VectorTransferConversion<TransferReadOp>,
1576                   VectorTransferConversion<TransferWriteOp>>(
1577       ctx, converter, enableIndexOptimizations);
1578   patterns
1579       .insert<VectorShuffleOpConversion,
1580               VectorExtractElementOpConversion,
1581               VectorExtractOpConversion,
1582               VectorFMAOp1DConversion,
1583               VectorInsertElementOpConversion,
1584               VectorInsertOpConversion,
1585               VectorPrintOpConversion,
1586               VectorTypeCastOpConversion,
1587               VectorMaskedLoadOpConversion,
1588               VectorMaskedStoreOpConversion,
1589               VectorGatherOpConversion,
1590               VectorScatterOpConversion,
1591               VectorExpandLoadOpConversion,
1592               VectorCompressStoreOpConversion>(ctx, converter);
1593   // clang-format on
1594 }
1595 
1596 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1597     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1598   MLIRContext *ctx = converter.getDialect()->getContext();
1599   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1600   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1601 }
1602 
1603 namespace {
1604 struct LowerVectorToLLVMPass
1605     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
1606   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
1607     this->reassociateFPReductions = options.reassociateFPReductions;
1608     this->enableIndexOptimizations = options.enableIndexOptimizations;
1609   }
1610   void runOnOperation() override;
1611 };
1612 } // namespace
1613 
1614 void LowerVectorToLLVMPass::runOnOperation() {
1615   // Perform progressive lowering of operations on slices and
1616   // all contraction operations. Also applies folding and DCE.
1617   {
1618     OwningRewritePatternList patterns;
1619     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
1620     populateVectorSlicesLoweringPatterns(patterns, &getContext());
1621     populateVectorContractLoweringPatterns(patterns, &getContext());
1622     applyPatternsAndFoldGreedily(getOperation(), patterns);
1623   }
1624 
1625   // Convert to the LLVM IR dialect.
1626   LLVMTypeConverter converter(&getContext());
1627   OwningRewritePatternList patterns;
1628   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
1629   populateVectorToLLVMConversionPatterns(
1630       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
1632   populateStdToLLVMConversionPatterns(converter, patterns);
1633 
1634   LLVMConversionTarget target(getContext());
1635   if (failed(applyPartialConversion(getOperation(), target, patterns)))
1636     signalPassFailure();
1637 }
1638 
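/// Typically exercised through the -convert-vector-to-llvm pass pipeline
/// flag, e.g. (illustrative invocation): mlir-opt -convert-vector-to-llvm.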
1639 std::unique_ptr<OperationPass<ModuleOp>>
1640 mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
1641   return std::make_unique<LowerVectorToLLVMPass>(options);
1642 }
1643