//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}
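
// For example (illustrative only): given tp = vector<2x3x4xf32>,
// reducedVectorTypeFront(tp) yields vector<3x4xf32> and
// reducedVectorTypeBack(tp) yields vector<4xf32>.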

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}
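
// For instance (illustrative only): with arrayAttr = [0, 1, 2, 3],
// getI64SubArray(arrayAttr, /*dropFront=*/1, /*dropBack=*/1) returns {1, 2}.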

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
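
// As an illustrative sketch, for dim = 4 with index optimizations enabled and
// a requested offset `off`, the emitted IR is roughly:
//   %cst = constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %o   = index_cast %off : index to i32
//   %ov  = splat %o : vector<4xi32>
//   %idx = addi %ov, %cst : vector<4xi32>
//   %b32 = index_cast %b : index to i32
//   %bv  = splat %b32 : vector<4xi32>
//   %m   = cmpi "slt", %idx, %bv : vector<4xi32>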

// Helper that returns the data layout alignment of an operation with a memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
                                 unsigned &align) {
  Type elementTy =
      typeConverter.convertType(op.getMemRefType().getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
                                     typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a bit-casted pointer given a memref base.
static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
                                Location loc, Value memref,
                                MemRefType memRefType, Type type, Value &ptr) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = type.template cast<LLVM::LLVMType>().getPointerTo();
  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
  return success();
}

// Helper that returns a vector of pointers given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}
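
// For example (an illustrative sketch): given a flat memref<?xf32> whose
// aligned base pointer lowers to %p and an index vector %i : <4 x i64>, this
// emits the LLVM equivalent of
//   %ptrs = getelementptr float, float* %p, <4 x i64> %i
// producing a <4 x float*> vector with one address per lane.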

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(typeConverter, xferOp, align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
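///
/// For example (an illustrative sketch):
/// ```
///   %c = vector.matrix_multiply %a, %b
///     { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///     : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```
/// maps 1-1 to the llvm.intr.matrix.multiply intrinsic on the converted
/// operand types.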
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMatmulOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto matmulOp = cast<vector::MatmulOp>(op);
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
        adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
        matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
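///
/// For example (an illustrative sketch):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///      : vector<16xf32> -> vector<16xf32>
/// ```
/// maps 1-1 to the llvm.intr.matrix.transpose intrinsic with the same
/// rows/columns attributes.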
class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFlatTransposeOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto transOp = cast<vector::FlatTransposeOp>(op);
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter.convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto load = cast<vector::MaskedLoadOp>(op);
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, load, align)))
      return failure();

    auto vtype = typeConverter.convertType(load.getResultVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorMaskedStoreOpConversion(MLIRContext *context,
                                         LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto store = cast<vector::MaskedStoreOp>(op);
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, store, align)))
      return failure();

    auto vtype = typeConverter.convertType(store.getValueVectorType());
    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
                          vtype, ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorGatherOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto gather = cast<vector::GatherOp>(op);
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, gather, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter.convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorScatterOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto scatter = cast<vector::ScatterOp>(op);
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(typeConverter, scatter, align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExpandLoadOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto expand = cast<vector::ExpandLoadOp>(op);
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
                          ptr)))
      return failure();

    auto vType = expand.getResultVectorType();
    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
        adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCompressStoreOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto compress = cast<vector::CompressStoreOp>(op);
    auto adaptor = vector::CompressStoreOpAdaptor(operands);

    Value ptr;
    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
                          compress.getMemRefType(), ptr)))
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        op, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
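///
/// For example (an illustrative sketch):
/// ```
///   %0 = vector.reduction "mul", %v : vector<8xf32> into f32
/// ```
/// lowers to the matching vector reduction intrinsic op (here
/// LLVM::vector_reduce_fmul) with an implicit accumulator of 1.0.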
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorReductionOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter,
                                       bool reassociateFPRed)
      : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                             typeConverter),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionOp = cast<vector::ReductionOp>(op);
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter.convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            op, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            op, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            op, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            op, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            op, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            op, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            op, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();

    } else if (eltType.isa<FloatType>()) {
      // Floating-point reductions: add/mul/min/max.
      if (kind == "add") {
        // Optional accumulator (or zero).
        Value acc = operands.size() > 1 ? operands[1]
                                        : rewriter.create<LLVM::ConstantOp>(
                                              op->getLoc(), llvmType,
                                              rewriter.getZeroAttr(eltType));
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "mul") {
        // Optional accumulator (or one).
        Value acc = operands.size() > 1
                        ? operands[1]
                        : rewriter.create<LLVM::ConstantOp>(
                              op->getLoc(), llvmType,
                              rewriter.getFloatAttr(eltType, 1.0));
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
            op, llvmType, acc, operands[0],
            rewriter.getBoolAttr(reassociateFPReductions));
      } else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
            op, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
            op, llvmType, operands[0]);
      else
        return failure();
      return success();
    }
    return failure();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
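///
/// For example (an illustrative sketch, assuming %c3 holds index 3):
/// ```
///   %m = vector.create_mask %c3 : vector<4xi1>
/// ```
/// yields the mask [1, 1, 1, 0] and is lowered via buildVectorComparison as
/// [0, 1, 2, 3] < [3, 3, 3, 3].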
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
                                        bool enableIndexOpt)
      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
                             typeConverter),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op->getResult(0).getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorShuffleOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto shuffleOp = cast<vector::ShuffleOp>(op);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter.convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(op, shuffle);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
                                 rank, extPos);
      insert = insertOne(rewriter, typeConverter, loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(op, insert);
    return success();
  }
};

class VectorExtractElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractElementOpConversion(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto extractEltOp = cast<vector::ExtractElementOp>(op);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter.convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        op, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

class VectorExtractOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorExtractOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto extractOp = cast<vector::ExtractOp>(op);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter.convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(op, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(op, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
///    -> !llvm<"<8 x float>">
/// ```
class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
public:
  explicit VectorFMAOp1DConversion(MLIRContext *context,
                                   LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertElementOpConversion(MLIRContext *context,
                                           LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
                             context, typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto insertEltOp = cast<vector::InsertElementOp>(op);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter.convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorInsertOpConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto insertOp = cast<vector::InsertOp>(op);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter.convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(op, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = op->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter.convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter.convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(op, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
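//
// For example (an illustrative sketch, shapes chosen arbitrarily):
// ```
//   vector.insert_strided_slice %src, %dst
//       {offsets = [2, 0, 0], strides = [1, 1]}
//       : vector<2x4xf32> into vector<16x4x8xf32>
// ```
// first extracts %dst[2] : vector<4x8xf32>, re-creates a same-rank
// insert_strided_slice of %src into it at offsets [0, 0], and inserts the
// result back at position 2 of %dst.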
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. In this case, we unroll along the most major dimension:
//   1. each subvector (or element) is extracted from the source vector
//   2. if the extracted value is itself a vector, the matching subvector is
//      extracted from the destination and a new, lower-rank
//      InsertStridedSlice op inserts the source subvector into it
//   3. the result is inserted back into the destination at the proper offset
// The lower-rank InsertStridedSlice from step 2. is picked up again by this
// same pattern; the recursion is bounded since the rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from
        // destination. Otherwise we are at the element level and no need to
        // recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

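/// Conversion pattern for a vector.type_cast.
/// For example (an illustrative sketch):
/// ```
///   %A = vector.type_cast %arg0 : memref<8x8xf32> to memref<vector<8x8xf32>>
/// ```
/// is lowered by rebuilding the memref descriptor: the allocated and aligned
/// pointers are bitcast to a pointer-to-vector type, and the offset, sizes,
/// and strides are filled in as static constants.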
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
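///
/// As an illustrative sketch: a read of a dimension that is not masked lowers
/// to a plain llvm.load of the bitcast vector pointer, while the masked path
/// becomes an llvm.intr.masked.load whose mask is the comparison above and
/// whose pass-through is a splat of the transfer padding value.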
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();
    // Only contiguous source memrefs supported atm.
    auto strides = computeContiguousStrides(xferOp.getMemRefType());
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
1248 #ifndef NDEBUG
1249       // Check that memref vector type is a suffix of 'vectorType.
1250       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1251       unsigned resultVecRank = xferOp.getVectorType().getRank();
1252       assert(memrefVecEltRank <= resultVecRank);
1253       // TODO: Move this to isSuffix in Vector/Utils.h.
1254       unsigned rankOffset = resultVecRank - memrefVecEltRank;
1255       auto memrefVecEltShape = memrefVectorElementType.getShape();
1256       auto resultVecShape = xferOp.getVectorType().getShape();
1257       for (unsigned i = 0; i < memrefVecEltRank; ++i)
1258         assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
1259                "memref vector element shape should match suffix of vector "
1260                "result shape.");
1261 #endif // ifndef NDEBUG
1262     }
1263 
1264     // 1. Get the source/dst address as an LLVM vector pointer.
1265     //    The vector pointer is always in address space 0, so an
1266     //    addrspacecast is needed when the source/dst memref is not in
1267     //    address space 0.
1268     // TODO: support alignment when possible.
1269     Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
1270                                          adaptor.indices(), rewriter);
1271     auto vecTy =
1272         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
1273     Value vectorDataPtr;
1274     if (memRefType.getMemorySpace() == 0)
1275       vectorDataPtr =
1276           rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
1277     else
1278       vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
1279           loc, vecTy.getPointerTo(), dataPtr);
1280 
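    // If the single transfer dimension is unmasked, lower directly to a
    // regular vector load/store instead of a masked one.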
1281     if (!xferOp.isMaskedDim(0))
1282       return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc,
1283                                               xferOp, operands, vectorDataPtr);
1284 
1285     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1286     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1287     // 4. Let dim be the memref dimension: compute the vector comparison mask
1288     //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ].
1289     //
1290     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1291     //       dimensions here.
1292     unsigned vecWidth = vecTy.getVectorNumElements();
1293     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
1294     Value off = xferOp.indices()[lastIndex];
1295     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
1296     Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
1297                                        vecWidth, dim, &off);
1298 
1299     // 5. Rewrite as a masked read / write.
1300     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
1301                                        operands, vectorDataPtr, mask);
1302   }
1303 
1304 private:
1305   const bool enableIndexOptimizations;
1306 };
1307 
1308 class VectorPrintOpConversion : public ConvertToLLVMPattern {
1309 public:
1310   explicit VectorPrintOpConversion(MLIRContext *context,
1311                                    LLVMTypeConverter &typeConverter)
1312       : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context,
1313                              typeConverter) {}
1314 
1315   // Proof-of-concept lowering implementation that relies on a small
1316   // runtime support library, which only needs to provide a few
1317   // printing methods (single value for all data types, opening/closing
1318   // bracket, comma, newline). The lowering fully unrolls a vector
1319   // in terms of these elementary printing operations. The advantage
1320   // of this approach is that the library can remain unaware of all
1321   // low-level implementation details of vectors while still supporting
1322   // output of vectors of any shape and rank. Due to the full unrolling,
1323   // however, this approach is less suited for very large vectors.
1324   //
1325   // TODO: rely solely on libc in future? something else?
1326   //
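  // For example (illustrative call sequence, not verbatim IR), printing a
  // vector<2xf32> unrolls into:
  //   printOpen(); printF32(..); printComma(); printF32(..); printClose();
  // followed by a trailing printNewline().
  //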
1327   LogicalResult
1328   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
1329                   ConversionPatternRewriter &rewriter) const override {
1330     auto printOp = cast<vector::PrintOp>(op);
1331     auto adaptor = vector::PrintOpAdaptor(operands);
1332     Type printType = printOp.getPrintType();
1333 
1334     if (typeConverter.convertType(printType) == nullptr)
1335       return failure();
1336 
1337     // Make sure element type has runtime support.
1338     PrintConversion conversion = PrintConversion::None;
1339     VectorType vectorType = printType.dyn_cast<VectorType>();
1340     Type eltType = vectorType ? vectorType.getElementType() : printType;
1341     Operation *printer;
1342     if (eltType.isF32()) {
1343       printer = getPrintFloat(op);
1344     } else if (eltType.isF64()) {
1345       printer = getPrintDouble(op);
1346     } else if (eltType.isIndex()) {
1347       printer = getPrintU64(op);
1348     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1349       // Integers need a zero or sign extension on the operand
1350       // (depending on the source type) as well as a signed or
1351       // unsigned print method. Widths up to 64 bits are supported.
1352       unsigned width = intTy.getWidth();
1353       if (intTy.isUnsigned()) {
1354         if (width <= 64) {
1355           if (width < 64)
1356             conversion = PrintConversion::ZeroExt64;
1357           printer = getPrintU64(op);
1358         } else {
1359           return failure();
1360         }
1361       } else {
1362         assert(intTy.isSignless() || intTy.isSigned());
1363         if (width <= 64) {
1364           // Note that we *always* zero extend booleans (1-bit integers),
1365           // so that true/false is printed as 1/0 rather than -1/0.
1366           if (width == 1)
1367             conversion = PrintConversion::ZeroExt64;
1368           else if (width < 64)
1369             conversion = PrintConversion::SignExt64;
1370           printer = getPrintI64(op);
1371         } else {
1372           return failure();
1373         }
1374       }
1375     } else {
1376       return failure();
1377     }
1378 
1379     // Unroll vector into elementary print calls.
1380     int64_t rank = vectorType ? vectorType.getRank() : 0;
1381     emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
1382               conversion);
1383     emitCall(rewriter, op->getLoc(), getPrintNewline(op));
1384     rewriter.eraseOp(op);
1385     return success();
1386   }
1387 
1388 private:
1389   enum class PrintConversion {
1390     // clang-format off
1391     None,
1392     ZeroExt64,
1393     SignExt64
1394     // clang-format on
1395   };
1396 
1397   void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1398                  Value value, VectorType vectorType, Operation *printer,
1399                  int64_t rank, PrintConversion conversion) const {
1400     Location loc = op->getLoc();
1401     if (rank == 0) {
1402       switch (conversion) {
1403       case PrintConversion::ZeroExt64:
1404         value = rewriter.create<ZeroExtendIOp>(
1405             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1406         break;
1407       case PrintConversion::SignExt64:
1408         value = rewriter.create<SignExtendIOp>(
1409             loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
1410         break;
1411       case PrintConversion::None:
1412         break;
1413       }
1414       emitCall(rewriter, loc, printer, value);
1415       return;
1416     }
1417 
1418     emitCall(rewriter, loc, getPrintOpen(op));
1419     Operation *printComma = getPrintComma(op);
1420     int64_t dim = vectorType.getDimSize(0);
1421     for (int64_t d = 0; d < dim; ++d) {
1422       auto reducedType =
1423           rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
1424       auto llvmType = typeConverter.convertType(
1425           rank > 1 ? reducedType : vectorType.getElementType());
1426       Value nestedVal =
1427           extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
1428       emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1429                 conversion);
1430       if (d != dim - 1)
1431         emitCall(rewriter, loc, printComma);
1432     }
1433     emitCall(rewriter, loc, getPrintClose(op));
1434   }
1435 
1436   // Helper to emit a call.
1437   static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1438                        Operation *ref, ValueRange params = ValueRange()) {
1439     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
1440                                   rewriter.getSymbolRefAttr(ref), params);
1441   }
1442 
1443   // Helper for printer method declaration (on first use) and lookup.
1444   static Operation *getPrint(Operation *op, StringRef name,
1445                              ArrayRef<LLVM::LLVMType> params) {
1446     auto module = op->getParentOfType<ModuleOp>();
1447     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
1448     if (func)
1449       return func;
1450     OpBuilder moduleBuilder(module.getBodyRegion());
1451     return moduleBuilder.create<LLVM::LLVMFuncOp>(
1452         op->getLoc(), name,
1453         LLVM::LLVMType::getFunctionTy(
1454             LLVM::LLVMType::getVoidTy(op->getContext()), params,
1455             /*isVarArg=*/false));
1456   }
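  // (On first use this declares a void printer function at module scope,
  // e.g. roughly "llvm.func @printF32(!llvm.float)"; later calls reuse it.)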
1457 
1458   // Helpers for declaring and looking up the runtime printing methods.
1459   Operation *getPrintI64(Operation *op) const {
1460     return getPrint(op, "printI64",
1461                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1462   }
1463   Operation *getPrintU64(Operation *op) const {
1464     return getPrint(op, "printU64",
1465                     LLVM::LLVMType::getInt64Ty(op->getContext()));
1466   }
1467   Operation *getPrintFloat(Operation *op) const {
1468     return getPrint(op, "printF32",
1469                     LLVM::LLVMType::getFloatTy(op->getContext()));
1470   }
1471   Operation *getPrintDouble(Operation *op) const {
1472     return getPrint(op, "printF64",
1473                     LLVM::LLVMType::getDoubleTy(op->getContext()));
1474   }
1475   Operation *getPrintOpen(Operation *op) const {
1476     return getPrint(op, "printOpen", {});
1477   }
1478   Operation *getPrintClose(Operation *op) const {
1479     return getPrint(op, "printClose", {});
1480   }
1481   Operation *getPrintComma(Operation *op) const {
1482     return getPrint(op, "printComma", {});
1483   }
1484   Operation *getPrintNewline(Operation *op) const {
1485     return getPrint(op, "printNewline", {});
1486   }
1487 };
1488 
1489 /// Progressive lowering of ExtractStridedSliceOp to either:
1490 ///   1. a direct shuffle when there is only a single offset, or
1491 ///   2. extract + lower-rank strided_slice + insert for the n-D case.
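///
/// E.g. (illustrative), the single-offset slice
///   %1 = vector.extract_strided_slice %0
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8xf32> to vector<4xf32>
/// lowers to a vector.shuffle of %0 with itself and mask [2, 3, 4, 5].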
1492 class VectorExtractStridedSliceOpConversion
1493     : public OpRewritePattern<ExtractStridedSliceOp> {
1494 public:
1495   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1496       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1497     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1498     // is bounded as the rank is strictly decreasing.
1499     setHasBoundedRewriteRecursion();
1500   }
1501 
1502   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1503                                 PatternRewriter &rewriter) const override {
1504     auto dstType = op.getResult().getType().cast<VectorType>();
1505 
1506     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1507 
1508     int64_t offset =
1509         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1510     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1511     int64_t stride =
1512         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1513 
1514     auto loc = op.getLoc();
1515     auto elemType = dstType.getElementType();
1516     assert(elemType.isSignlessIntOrIndexOrFloat());
1517 
1518     // A single-offset slice can be expressed more efficiently as a direct shuffle.
1519     if (op.offsets().getValue().size() == 1) {
1520       SmallVector<int64_t, 4> offsets;
1521       offsets.reserve(size);
1522       for (int64_t off = offset, e = offset + size * stride; off < e;
1523            off += stride)
1524         offsets.push_back(off);
1525       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1526                                              op.vector(),
1527                                              rewriter.getI64ArrayAttr(offsets));
1528       return success();
1529     }
1530 
1531     // Otherwise, extract/insert around a lower-rank extract_strided_slice op.
1532     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1533                                              rewriter.getZeroAttr(elemType));
1534     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1535     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1536          off += stride, ++idx) {
1537       Value one = extractOne(rewriter, loc, op.vector(), off);
1538       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1539           loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
1540           getI64SubArray(op.sizes(), /*dropFront=*/1),
1541           getI64SubArray(op.strides(), /*dropFront=*/1));
1542       res = insertOne(rewriter, loc, extracted, res, idx);
1543     }
1544     rewriter.replaceOp(op, res);
1545     return success();
1546   }
1547 };
1548 
1549 } // namespace
1550 
1551 /// Populate the given list with patterns that convert from Vector to LLVM.
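///
/// Typical usage sketch (assumes the standard dialect-conversion driver; the
/// pass boilerplate below is illustrative, not part of this file):
///   LLVMTypeConverter converter(&getContext());
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(getContext());
///   if (failed(applyPartialConversion(module, target, patterns)))
///     signalPassFailure();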
1552 void mlir::populateVectorToLLVMConversionPatterns(
1553     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1554     bool reassociateFPReductions, bool enableIndexOptimizations) {
1555   MLIRContext *ctx = converter.getDialect()->getContext();
1556   // clang-format off
1557   patterns.insert<VectorFMAOpNDRewritePattern,
1558                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1559                   VectorInsertStridedSliceOpSameRankRewritePattern,
1560                   VectorExtractStridedSliceOpConversion>(ctx);
1561   patterns.insert<VectorReductionOpConversion>(
1562       ctx, converter, reassociateFPReductions);
1563   patterns.insert<VectorCreateMaskOpConversion,
1564                   VectorTransferConversion<TransferReadOp>,
1565                   VectorTransferConversion<TransferWriteOp>>(
1566       ctx, converter, enableIndexOptimizations);
1567   patterns
1568       .insert<VectorShuffleOpConversion,
1569               VectorExtractElementOpConversion,
1570               VectorExtractOpConversion,
1571               VectorFMAOp1DConversion,
1572               VectorInsertElementOpConversion,
1573               VectorInsertOpConversion,
1574               VectorPrintOpConversion,
1575               VectorTypeCastOpConversion,
1576               VectorMaskedLoadOpConversion,
1577               VectorMaskedStoreOpConversion,
1578               VectorGatherOpConversion,
1579               VectorScatterOpConversion,
1580               VectorExpandLoadOpConversion,
1581               VectorCompressStoreOpConversion>(ctx, converter);
1582   // clang-format on
1583 }
1584 
1585 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1586     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1587   MLIRContext *ctx = converter.getDialect()->getContext();
1588   patterns.insert<VectorMatmulOpConversion>(ctx, converter);
1589   patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
1590 }
1591