//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
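//
// For illustration, with dim = 4, a 64-bit index type, and an offset %o, the
// generated IR is roughly of the form (sketch; exact syntax may vary):
//
//   %linear = constant dense<[0, 1, 2, 3]> : vector<4xi64>
//   %ov     = splat %o : vector<4xi64>
//   %idx    = addi %ov, %linear : vector<4xi64>
//   %bv     = splat %b : vector<4xi64>
//   %mask   = cmpi slt, %idx, %bv : vector<4xi64>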
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of a memref.
static LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                        MemRefType memrefType,
                                        unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}

// Helper that returns a vector of pointers, given a memref base and an index
// vector.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer is
// always in address space 0, so an addrspacecast is needed when the source or
// destination memref is not in address space 0.
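//
// For example (sketch; assuming an f32 memref in address space 3 and a vector
// type vector<8xf32>, with illustrative LLVM dialect type syntax):
//   llvm.addrspacecast %p : !llvm.ptr<f32, 3> to !llvm.ptr<vector<8xf32>>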
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt);
  if (memRefType.getMemorySpace() == 0)
    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.bitcast.
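///
/// For example (illustrative):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```
/// lowers to a single llvm.bitcast on the converted 1-D vector type.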
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
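///
/// For example, a 2x3 times 3x1 product over row-major flattened operands
/// (illustrative):
/// ```
///   %c = vector.matrix_multiply %a, %b
///          { lhs_rows = 2 : i32, lhs_columns = 3 : i32, rhs_columns = 1 : i32 }
///        : (vector<6xf32>, vector<3xf32>) -> vector<2xf32>
/// ```
/// maps 1-1 onto llvm.intr.matrix.multiply with the same attributes.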
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
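///
/// For example, transposing a flattened 2x4 matrix (illustrative):
/// ```
///   %t = vector.flat_transpose %m { rows = 2 : i32, columns = 4 : i32 }
///        : vector<8xf32> -> vector<8xf32>
/// ```
/// maps 1-1 onto llvm.intr.matrix.transpose.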
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Conversion pattern for a vector.maskedload.
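///
/// The op is rewritten to the llvm.intr.masked.load intrinsic, e.g. roughly
/// (sketch; exact assembly syntax may vary):
/// ```
///   %l = vector.maskedload %base[%i], %mask, %pass_thru
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes a bitcast of the strided element pointer to a vector pointer,
/// followed by llvm.intr.masked.load at the memref's preferred alignment.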
class VectorMaskedLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = load->getLoc();
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
    MemRefType memRefType = load.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = store->getLoc();
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
    MemRefType memRefType = store.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.gather.
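///
/// The op is rewritten to the llvm.intr.masked.gather intrinsic over a vector
/// of pointers, e.g. roughly (sketch; exact assembly syntax may vary):
/// ```
///   %g = vector.gather %base[%idxs], %mask, %pass_thru
///        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///          into vector<16xf32>
/// ```
/// first materializes the pointer vector with a single llvm.getelementptr on
/// the aligned base, then issues the gather at the resolved alignment.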
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getResultVectorType());
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
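///
/// Reductions dispatch on the "kind" attribute to the corresponding
/// llvm.intr.vector.reduce.* intrinsic, e.g. (illustrative):
/// ```
///   %s = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// becomes llvm.intr.vector.reduce.fadd seeded with a zero (or the optional
/// explicit) accumulator.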
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            reductionOp->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
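///
/// For example (illustrative):
/// ```
///   %m = vector.create_mask %b : vector<4xi1>
/// ```
/// lowers, via buildVectorComparison, to a comparison of the constant vector
/// [0, 1, 2, 3] against the splatted bound %b.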
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

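/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of two operands
/// of identical type maps directly onto llvm.shufflevector; all other cases
/// are unrolled into per-position extract/insert pairs, e.g. (illustrative):
/// ```
///   %s = vector.shuffle %a, %b [0, 3] : vector<2xf32>, vector<2xf32>
/// ```
/// maps onto a single llvm.shufflevector with mask [0, 3].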
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

688     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
689     int64_t insPos = 0;
690     for (auto en : llvm::enumerate(maskArrayAttr)) {
691       int64_t extPos = en.value().cast<IntegerAttr>().getInt();
692       Value value = adaptor.v1();
693       if (extPos >= v1Dim) {
694         extPos -= v1Dim;
695         value = adaptor.v2();
696       }
697       Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
698                                  llvmType, rank, extPos);
699       insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
700                          llvmType, rank, insPos++);
701     }
702     rewriter.replaceOp(shuffleOp, insert);
703     return success();
704   }
705 };
706 
707 class VectorExtractElementOpConversion
708     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
709 public:
710   using ConvertOpToLLVMPattern<
711       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
712 
713   LogicalResult
714   matchAndRewrite(vector::ExtractElementOp extractEltOp,
715                   ArrayRef<Value> operands,
716                   ConversionPatternRewriter &rewriter) const override {
717     auto adaptor = vector::ExtractElementOpAdaptor(operands);
718     auto vectorType = extractEltOp.getVectorType();
719     auto llvmType = typeConverter->convertType(vectorType.getElementType());
720 
721     // Bail if result type cannot be lowered.
722     if (!llvmType)
723       return failure();
724 
725     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
726         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
727     return success();
728   }
729 };
730 
731 class VectorExtractOpConversion
732     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
733 public:
734   using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
735 
736   LogicalResult
737   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
738                   ConversionPatternRewriter &rewriter) const override {
739     auto loc = extractOp->getLoc();
740     auto adaptor = vector::ExtractOpAdaptor(operands);
741     auto vectorType = extractOp.getVectorType();
742     auto resultType = extractOp.getResult().getType();
743     auto llvmResultType = typeConverter->convertType(resultType);
744     auto positionArrayAttr = extractOp.position();
745 
746     // Bail if result type cannot be lowered.
747     if (!llvmResultType)
748       return failure();
749 
750     // One-shot extraction of vector from array (only requires extractvalue).
751     if (resultType.isa<VectorType>()) {
752       Value extracted = rewriter.create<LLVM::ExtractValueOp>(
753           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
754       rewriter.replaceOp(extractOp, extracted);
755       return success();
756     }
757 
758     // Potential extraction of 1-D vector from array.
759     auto *context = extractOp->getContext();
760     Value extracted = adaptor.vector();
761     auto positionAttrs = positionArrayAttr.getValue();
762     if (positionAttrs.size() > 1) {
763       auto oneDVectorType = reducedVectorTypeBack(vectorType);
764       auto nMinusOnePositionAttrs =
765           ArrayAttr::get(positionAttrs.drop_back(), context);
766       extracted = rewriter.create<LLVM::ExtractValueOp>(
767           loc, typeConverter->convertType(oneDVectorType), extracted,
768           nMinusOnePositionAttrs);
769     }
770 
771     // Remaining extraction of element from 1-D LLVM vector
772     auto position = positionAttrs.back().cast<IntegerAttr>();
773     auto i64Type = IntegerType::get(rewriter.getContext(), 64);
774     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
775     extracted =
776         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
777     rewriter.replaceOp(extractOp, extracted);
778 
779     return success();
780   }
781 };
782 
783 /// Conversion pattern that turns a vector.fma on a 1-D vector
784 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
785 /// This does not match vectors of n >= 2 rank.
786 ///
787 /// Example:
788 /// ```
789 ///  vector.fma %a, %a, %a : vector<8xf32>
790 /// ```
791 /// is converted to:
792 /// ```
793 ///  llvm.intr.fmuladd %va, %va, %va:
794 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
795 ///    -> !llvm."<8 x f32>">
796 /// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(positionAttrs.drop_back(), context);
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0 : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
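//
// For example (illustrative):
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [2, 0], strides = [1]} : vector<4xf32> into vector<4x4xf32>
// is rewritten to extract %dst[2], insert %src into that subvector with
// offsets = [0], and insert the updated subvector back at position [2].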
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source vector along the leading
// dimension:
//   1. the slice is extracted from the source vector
//   2. if the slice is itself a vector, the matching subvector is extracted
//      from the destination and a rank-reduced InsertStridedSlice op is
//      created to insert the slice into it
//   3. the resulting slice is inserted back into the destination at the
//      proper strided offset.
// The rank-reduced InsertStridedSlice from step 2. is picked up again by this
// same pattern; the recursion is bounded because the rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
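///
/// For example (illustrative): strides [32, 8, 1] with sizes [2, 4, 8] are
/// contiguous (32 == 8 * 4 and 8 == 1 * 8), whereas strides [16, 1] with
/// sizes [2, 4] leave a gap after each row and are not contiguous.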
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // work in static cases, because MemRefType has no way to represent a
  // contiguous dynamic shape other than with an empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
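///
/// For example, a masked 1-D read (illustrative):
/// ```
///   %f = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes a pointer cast to the vector type, a mask comparing
/// [ %i + 0 .. %i + 7 ] against dim %A, and an llvm.intr.masked.load whose
/// disabled lanes are filled with the splatted padding value.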
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source buffers supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of `vectorType`.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
1243                "memref vector element shape should match suffix of vector "
1244                "result shape.");
1245 #endif // ifndef NDEBUG
1246     }
1247 
1248     // 1. Get the source/dst address as an LLVM vector pointer.
1249     VectorType vtp = xferOp.getVectorType();
1250     Value dataPtr = this->getStridedElementPtr(
1251         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1252     Value vectorDataPtr =
1253         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
1254 
1255     if (!xferOp.isMaskedDim(0))
1256       return replaceTransferOpWithLoadOrStore(rewriter,
1257                                               *this->getTypeConverter(), loc,
1258                                               xferOp, operands, vectorDataPtr);
1259 
1260     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1261     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1262     // 4. Let dim the memref dimension, compute the vector comparison mask:
1263     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1264     //
1265     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1266     //       dimensions here.
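    //
    // Schematically (the constants and op names below are illustrative, not
    // the exact ops emitted), for an 8-lane transfer at index %off in a
    // dimension of size %dim:
    //   %linear = [0, 1, 2, 3, 4, 5, 6, 7]
    //   %offs   = broadcast(%off) + %linear
    //   %mask   = cmpi slt, %offs, broadcast(%dim)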
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suited for very large vectors.
  //
  // TODO: rely solely on libc in the future? something else?
  //
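  // As an illustration (schematic, not the verbatim rewriter output),
  // printing a value %v of type vector<2xf32> unrolls into roughly:
  //   call @printOpen()
  //   call @printF32(%v[0])
  //   call @printComma()
  //   call @printF32(%v[1])
  //   call @printClose()
  //   call @printNewline()
  //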
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer = getPrintFloat(printOp);
    } else if (eltType.isF64()) {
      printer = getPrintDouble(printOp);
    } else if (eltType.isIndex()) {
      printer = getPrintU64(printOp);
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
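      // E.g., an i32 operand is sign-extended to i64 and printed via
      // printI64, while a ui16 operand is zero-extended and printed via
      // printU64.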
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = getPrintU64(printOp);
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = getPrintI64(printOp);
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc, getPrintOpen(op));
    Operation *printComma = getPrintComma(op);
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc, getPrintClose(op));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }

  // Helper for printer method declaration (first hit) and lookup.
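  // On first use this inserts a module-level declaration, schematically:
  //   llvm.func @printF32(f32)
  // Subsequent uses simply look up the existing symbol.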
  static Operation *getPrint(Operation *op, StringRef name,
                             ArrayRef<Type> params) {
    auto module = op->getParentOfType<ModuleOp>();
    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
    if (func)
      return func;
    OpBuilder moduleBuilder(module.getBodyRegion());
    return moduleBuilder.create<LLVM::LLVMFuncOp>(
        op->getLoc(), name,
        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
                                    params));
  }

  // Helpers for method names.
  Operation *getPrintI64(Operation *op) const {
    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintU64(Operation *op) const {
    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
  }
  Operation *getPrintFloat(Operation *op) const {
    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
  }
  Operation *getPrintDouble(Operation *op) const {
    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
  }
  Operation *getPrintOpen(Operation *op) const {
    return getPrint(op, "printOpen", {});
  }
  Operation *getPrintClose(Operation *op) const {
    return getPrint(op, "printClose", {});
  }
  Operation *getPrintComma(Operation *op) const {
    return getPrint(op, "printComma", {});
  }
  Operation *getPrintNewline(Operation *op) const {
    return getPrint(op, "printNewline", {});
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express single offset extract as a direct shuffle.
///   2. extract + lower rank strided_slice + insert for the n-D case.
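///
/// For case 1., e.g., a size-2, stride-1 slice at offset 1 of a vector<4xf32>
/// lowers, schematically, to:
///   %r = vector.shuffle %v, %v [1, 2] : vector<4xf32>, vector<4xf32>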
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // A single offset can be lowered directly to a more efficient shuffle.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
          getI64SubArray(op.sizes(), /* dropFront=*/1),
          getI64SubArray(op.strides(), /* dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
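/// A typical conversion pass wires these up as follows (a sketch; the pass
/// boilerplate and option values are illustrative):
///   LLVMTypeConverter converter(&getContext());
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMConversionPatterns(
///       converter, patterns,
///       /*reassociateFPReductions=*/false, /*enableIndexOptimizations=*/true);
///   // ... then apply with applyPartialConversion and an LLVMConversionTarget.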
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorBitCastOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}
1557