xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision e332c22cdf54d16974f445ba4345cd2a76d7fc6a)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
13 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/Target/LLVMIR/TypeTranslation.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 // Helper to reduce vector type by one rank at front.
25 static VectorType reducedVectorTypeFront(VectorType tp) {
26   assert((tp.getRank() > 1) && "unlowerable vector type");
27   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
28 }
29 
30 // Helper to reduce vector type by *all* but one rank at back.
31 static VectorType reducedVectorTypeBack(VectorType tp) {
32   assert((tp.getRank() > 1) && "unlowerable vector type");
33   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
34 }
35 
36 // Helper that picks the proper sequence for inserting.
37 static Value insertOne(ConversionPatternRewriter &rewriter,
38                        LLVMTypeConverter &typeConverter, Location loc,
39                        Value val1, Value val2, Type llvmType, int64_t rank,
40                        int64_t pos) {
41   if (rank == 1) {
42     auto idxType = rewriter.getIndexType();
43     auto constant = rewriter.create<LLVM::ConstantOp>(
44         loc, typeConverter.convertType(idxType),
45         rewriter.getIntegerAttr(idxType, pos));
46     return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
47                                                   constant);
48   }
49   return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
50                                               rewriter.getI64ArrayAttr(pos));
51 }
52 
53 // Helper that picks the proper sequence for inserting.
54 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
55                        Value into, int64_t offset) {
56   auto vectorType = into.getType().cast<VectorType>();
57   if (vectorType.getRank() > 1)
58     return rewriter.create<InsertOp>(loc, from, into, offset);
59   return rewriter.create<vector::InsertElementOp>(
60       loc, vectorType, from, into,
61       rewriter.create<ConstantIndexOp>(loc, offset));
62 }
63 
64 // Helper that picks the proper sequence for extracting.
65 static Value extractOne(ConversionPatternRewriter &rewriter,
66                         LLVMTypeConverter &typeConverter, Location loc,
67                         Value val, Type llvmType, int64_t rank, int64_t pos) {
68   if (rank == 1) {
69     auto idxType = rewriter.getIndexType();
70     auto constant = rewriter.create<LLVM::ConstantOp>(
71         loc, typeConverter.convertType(idxType),
72         rewriter.getIntegerAttr(idxType, pos));
73     return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
74                                                    constant);
75   }
76   return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
77                                                rewriter.getI64ArrayAttr(pos));
78 }
79 
80 // Helper that picks the proper sequence for extracting.
81 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
82                         int64_t offset) {
83   auto vectorType = vector.getType().cast<VectorType>();
84   if (vectorType.getRank() > 1)
85     return rewriter.create<ExtractOp>(loc, vector, offset);
86   return rewriter.create<vector::ExtractElementOp>(
87       loc, vectorType.getElementType(), vector,
88       rewriter.create<ConstantIndexOp>(loc, offset));
89 }
90 
91 // Helper that returns a subset of `arrayAttr` as a vector of int64_t.
92 // TODO: Better support for attribute subtype forwarding + slicing.
93 static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
94                                               unsigned dropFront = 0,
95                                               unsigned dropBack = 0) {
96   assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
97   auto range = arrayAttr.getAsRange<IntegerAttr>();
98   SmallVector<int64_t, 4> res;
99   res.reserve(arrayAttr.size() - dropFront - dropBack);
100   for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
101        it != eit; ++it)
102     res.push_back((*it).getValue().getSExtValue());
103   return res;
104 }
105 
106 static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
107                                    Location loc, Type targetType, Value value) {
108   if (targetType == value.getType())
109     return value;
110 
111   bool targetIsIndex = targetType.isIndex();
112   bool valueIsIndex = value.getType().isIndex();
113   if (targetIsIndex ^ valueIsIndex)
114     return rewriter.create<IndexCastOp>(loc, targetType, value);
115 
116   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
117   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
118   assert(targetIntegerType && valueIntegerType &&
119          "unexpected cast between types other than integers and index");
120   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
121 
122   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
123     return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
124   return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
125 }
126 
// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    // Step vector [0, 1, ..., dim-1] as an i32 vector constant.
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    // Step vector [0, 1, ..., dim-1] as an i64 vector constant.
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    // Cast the scalar offset to the chosen index width, splat it across
    // all lanes, and add it to the step vector.
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison: lane i is set iff indices[i] < b
  // (signed comparison).
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}
165 
166 // Helper that returns data layout alignment of a memref.
167 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
168                                  MemRefType memrefType, unsigned &align) {
169   Type elementTy = typeConverter.convertType(memrefType.getElementType());
170   if (!elementTy)
171     return failure();
172 
173   // TODO: this should use the MLIR data layout when it becomes available and
174   // stop depending on translation.
175   llvm::LLVMContext llvmContext;
176   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
177               .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
178   return success();
179 }
180 
// Helper that returns the base address of a memref.
// Succeeds only for "flat" memrefs: rank-1, unit stride, zero offset, and
// default memory space; `base` is then set to the aligned pointer of the
// memref descriptor.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  // Reject anything that is not a contiguous 1-D buffer in memory space 0.
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}
197 
// Helper that returns vector of pointers given a memref base with index
// vector: emits a vector GEP of `base` by `indices`, with one pointer lane
// per element of the leading vector dimension. Fails when `getBase` rejects
// the memref (non-flat layout).
// NOTE(review): `iType` is not used in this body; it appears to be kept only
// to match the call sites — confirm before removing.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}
211 
212 // Casts a strided element pointer to a vector pointer. The vector pointer
213 // would always be on address space 0, therefore addrspacecast shall be
214 // used when source/dst memrefs are not on address space 0.
215 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
216                          Value ptr, MemRefType memRefType, Type vt) {
217   auto pType = LLVM::LLVMPointerType::get(vt);
218   if (memRefType.getMemorySpace() == 0)
219     return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
220   return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
221 }
222 
223 static LogicalResult
224 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
225                                  LLVMTypeConverter &typeConverter, Location loc,
226                                  TransferReadOp xferOp,
227                                  ArrayRef<Value> operands, Value dataPtr) {
228   unsigned align;
229   if (failed(getMemRefAlignment(
230           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
231     return failure();
232   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
233   return success();
234 }
235 
236 static LogicalResult
237 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
238                             LLVMTypeConverter &typeConverter, Location loc,
239                             TransferReadOp xferOp, ArrayRef<Value> operands,
240                             Value dataPtr, Value mask) {
241   VectorType fillType = xferOp.getVectorType();
242   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
243 
244   Type vecTy = typeConverter.convertType(xferOp.getVectorType());
245   if (!vecTy)
246     return failure();
247 
248   unsigned align;
249   if (failed(getMemRefAlignment(
250           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
251     return failure();
252 
253   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
254       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
255       rewriter.getI32IntegerAttr(align));
256   return success();
257 }
258 
259 static LogicalResult
260 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
261                                  LLVMTypeConverter &typeConverter, Location loc,
262                                  TransferWriteOp xferOp,
263                                  ArrayRef<Value> operands, Value dataPtr) {
264   unsigned align;
265   if (failed(getMemRefAlignment(
266           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
267     return failure();
268   auto adaptor = TransferWriteOpAdaptor(operands);
269   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
270                                              align);
271   return success();
272 }
273 
274 static LogicalResult
275 replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
276                             LLVMTypeConverter &typeConverter, Location loc,
277                             TransferWriteOp xferOp, ArrayRef<Value> operands,
278                             Value dataPtr, Value mask) {
279   unsigned align;
280   if (failed(getMemRefAlignment(
281           typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
282     return failure();
283 
284   auto adaptor = TransferWriteOpAdaptor(operands);
285   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
286       xferOp, adaptor.vector(), dataPtr, mask,
287       rewriter.getI32IntegerAttr(align));
288   return success();
289 }
290 
// Overload-resolution helper: builds the read-op adaptor so templated code
// can handle TransferReadOp and TransferWriteOp uniformly.
static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}
295 
// Overload-resolution helper: builds the write-op adaptor so templated code
// can handle TransferReadOp and TransferWriteOp uniformly.
static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}
300 
301 namespace {
302 
/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    // A 1-D vector.bitcast maps directly onto an LLVM bitcast of the
    // converted vector type.
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};
322 
323 /// Conversion pattern for a vector.matrix_multiply.
324 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
325 class VectorMatmulOpConversion
326     : public ConvertOpToLLVMPattern<vector::MatmulOp> {
327 public:
328   using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
329 
330   LogicalResult
331   matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
332                   ConversionPatternRewriter &rewriter) const override {
333     auto adaptor = vector::MatmulOpAdaptor(operands);
334     rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
335         matmulOp, typeConverter->convertType(matmulOp.res().getType()),
336         adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
337         matmulOp.lhs_columns(), matmulOp.rhs_columns());
338     return success();
339   }
340 };
341 
342 /// Conversion pattern for a vector.flat_transpose.
343 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
344 class VectorFlatTransposeOpConversion
345     : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
346 public:
347   using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
348 
349   LogicalResult
350   matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
351                   ConversionPatternRewriter &rewriter) const override {
352     auto adaptor = vector::FlatTransposeOpAdaptor(operands);
353     rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
354         transOp, typeConverter->convertType(transOp.res().getType()),
355         adaptor.matrix(), transOp.rows(), transOp.columns());
356     return success();
357   }
358 };
359 
/// Conversion pattern for a vector.maskedload.
class VectorMaskedLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp load, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = load->getLoc();
    auto adaptor = vector::MaskedLoadOpAdaptor(operands);
    MemRefType memRefType = load.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: strided element pointer, then cast (bitcast or
    // addrspacecast) to a pointer-to-vector for the LLVM intrinsic.
    auto vtype = typeConverter->convertType(load.getResultVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    // Masked-off lanes take their value from the pass-through vector.
    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
        load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
390 
/// Conversion pattern for a vector.maskedstore.
class VectorMaskedStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::MaskedStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::MaskedStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp store, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = store->getLoc();
    auto adaptor = vector::MaskedStoreOpAdaptor(operands);
    MemRefType memRefType = store.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address: strided element pointer, then cast (bitcast or
    // addrspacecast) to a pointer-to-vector for the LLVM intrinsic.
    auto vtype = typeConverter->convertType(store.getValueVectorType());
    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);

    // Masked-off lanes leave memory untouched.
    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
        store, adaptor.value(), ptr, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
421 
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs: one pointer per lane, base + index-vector GEP.
    // Fails (pattern bails out) for memrefs that are not flat rank-1
    // buffers in the default memory space.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};
455 
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs: one pointer per lane, base + index-vector GEP.
    // Fails (pattern bails out) for memrefs that are not flat rank-1
    // buffers in the default memory space.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};
489 
490 /// Conversion pattern for a vector.expandload.
491 class VectorExpandLoadOpConversion
492     : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
493 public:
494   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
495 
496   LogicalResult
497   matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
498                   ConversionPatternRewriter &rewriter) const override {
499     auto loc = expand->getLoc();
500     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
501     MemRefType memRefType = expand.getMemRefType();
502 
503     // Resolve address.
504     auto vtype = typeConverter->convertType(expand.getResultVectorType());
505     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
506                                            adaptor.indices(), rewriter);
507 
508     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
509         expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
510     return success();
511   }
512 };
513 
514 /// Conversion pattern for a vector.compressstore.
515 class VectorCompressStoreOpConversion
516     : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
517 public:
518   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
519 
520   LogicalResult
521   matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
522                   ConversionPatternRewriter &rewriter) const override {
523     auto loc = compress->getLoc();
524     auto adaptor = vector::CompressStoreOpAdaptor(operands);
525     MemRefType memRefType = compress.getMemRefType();
526 
527     // Resolve address.
528     Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
529                                            adaptor.indices(), rewriter);
530 
531     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
532         compress, adaptor.value(), ptr, adaptor.mask());
533     return success();
534   }
535 };
536 
/// Conversion pattern for all vector reductions.
/// Integer/index reductions map to the llvm.vector.reduce.* intrinsics;
/// floating-point add/mul additionally thread an accumulator operand and a
/// reassociation flag; unsupported kinds make the pattern fail.
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      // NOTE: the unsigned min/max cases must be tested before the generic
      // ones below — index and unsigned-integer elements use the unsigned
      // intrinsics, every other integer the signed variants.
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            reductionOp->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  // When true, fadd/fmul reductions may be reassociated (faster, but not
  // bit-exact with sequential order).
  const bool reassociateFPReductions;
};
625 
/// Conversion pattern for a vector.create_mask (1-D only).
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      // Lower to: mask = [0, 1, ..., dim-1] < splat(bound), where the
      // bound is the op's single operand.
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    // Higher-rank masks are not lowered by this pattern.
    return failure();
  }

private:
  // When true, the lane comparison is performed on i32 instead of i64.
  const bool enableIndexOptimizations;
};
652 
/// Conversion pattern for a vector.shuffle. Uses LLVM's native
/// shufflevector when both 1-D operands have identical types; otherwise
/// falls back to element-by-element extract/insert.
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the individual values individually.
    // Per vector.shuffle semantics, mask entries in [0, v1Dim) select
    // from v1 and entries >= v1Dim select from v2 after rebasing by v1Dim.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};
707 
708 class VectorExtractElementOpConversion
709     : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
710 public:
711   using ConvertOpToLLVMPattern<
712       vector::ExtractElementOp>::ConvertOpToLLVMPattern;
713 
714   LogicalResult
715   matchAndRewrite(vector::ExtractElementOp extractEltOp,
716                   ArrayRef<Value> operands,
717                   ConversionPatternRewriter &rewriter) const override {
718     auto adaptor = vector::ExtractElementOpAdaptor(operands);
719     auto vectorType = extractEltOp.getVectorType();
720     auto llvmType = typeConverter->convertType(vectorType.getElementType());
721 
722     // Bail if result type cannot be lowered.
723     if (!llvmType)
724       return failure();
725 
726     rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
727         extractEltOp, llvmType, adaptor.vector(), adaptor.position());
728     return success();
729   }
730 };
731 
/// Lowers vector.extract to the LLVM dialect.
///
/// An n-D MLIR vector lowers to an LLVM array-of-(arrays-of-...)-vectors,
/// with only the innermost dimension being a true LLVM vector. Hence:
///  - extracting a (sub-)vector result maps to a single llvm.extractvalue
///    (pure aggregate indexing);
///  - extracting a scalar maps to an optional llvm.extractvalue (to reach
///    the innermost 1-D vector) followed by llvm.extractelement.
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Scalar result. If more than one position index is given, first peel
    // off all but the last index with an extractvalue that yields the
    // innermost 1-D vector.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector; the last
    // position index becomes an i64 constant operand of extractelement.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};
783 
784 /// Conversion pattern that turns a vector.fma on a 1-D vector
785 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
786 /// This does not match vectors of n >= 2 rank.
787 ///
788 /// Example:
789 /// ```
790 ///  vector.fma %a, %a, %a : vector<8xf32>
791 /// ```
792 /// is converted to:
793 /// ```
794 ///  llvm.intr.fmuladd %va, %va, %va:
795 ///    (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
796 ///    -> !llvm."<8 x f32>">
797 /// ```
798 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
799 public:
800   using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
801 
802   LogicalResult
803   matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
804                   ConversionPatternRewriter &rewriter) const override {
805     auto adaptor = vector::FMAOpAdaptor(operands);
806     VectorType vType = fmaOp.getVectorType();
807     if (vType.getRank() != 1)
808       return failure();
809     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
810                                                  adaptor.rhs(), adaptor.acc());
811     return success();
812   }
813 };
814 
815 class VectorInsertElementOpConversion
816     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
817 public:
818   using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
819 
820   LogicalResult
821   matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
822                   ConversionPatternRewriter &rewriter) const override {
823     auto adaptor = vector::InsertElementOpAdaptor(operands);
824     auto vectorType = insertEltOp.getDestVectorType();
825     auto llvmType = typeConverter->convertType(vectorType);
826 
827     // Bail if result type cannot be lowered.
828     if (!llvmType)
829       return failure();
830 
831     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
832         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
833         adaptor.position());
834     return success();
835   }
836 };
837 
/// Lowers vector.insert to the LLVM dialect.
///
/// Mirrors VectorExtractOpConversion: an n-D MLIR vector lowers to an LLVM
/// array-of-vectors, so:
///  - inserting a (sub-)vector source maps to a single llvm.insertvalue;
///  - inserting a scalar maps to an optional llvm.extractvalue (to reach
///    the innermost 1-D vector), an llvm.insertelement, and an optional
///    llvm.insertvalue to put the updated 1-D vector back in place.
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Scalar source. If more than one position index is given, first peel
    // off all but the last index by extracting the innermost 1-D vector
    // from the destination aggregate.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector; the last position
    // index becomes an i64 constant operand of insertelement.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // If the 1-D vector was peeled out of an aggregate above, insert the
    // updated vector back into the original destination at the same
    // leading position indices.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};
901 
902 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
903 ///
904 /// Example:
905 /// ```
906 ///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
907 /// ```
908 /// is rewritten into:
909 /// ```
910 ///  %r = splat %f0: vector<2x4xf32>
911 ///  %va = vector.extractvalue %a[0] : vector<2x4xf32>
912 ///  %vb = vector.extractvalue %b[0] : vector<2x4xf32>
913 ///  %vc = vector.extractvalue %c[0] : vector<2x4xf32>
914 ///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
915 ///  %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
916 ///  %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
917 ///  %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
918 ///  %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
919 ///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
920 ///  %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
921 ///  // %r3 holds the final value.
922 /// ```
923 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
924 public:
925   using OpRewritePattern<FMAOp>::OpRewritePattern;
926 
927   LogicalResult matchAndRewrite(FMAOp op,
928                                 PatternRewriter &rewriter) const override {
929     auto vType = op.getVectorType();
930     if (vType.getRank() < 2)
931       return failure();
932 
933     auto loc = op.getLoc();
934     auto elemType = vType.getElementType();
935     Value zero = rewriter.create<ConstantOp>(loc, elemType,
936                                              rewriter.getZeroAttr(elemType));
937     Value desc = rewriter.create<SplatOp>(loc, vType, zero);
938     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
939       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
940       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
941       Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
942       Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
943       desc = rewriter.create<InsertOp>(loc, fma, desc, i);
944     }
945     rewriter.replaceOp(op, desc);
946     return success();
947   }
948 };
949 
// RewritePattern for InsertStridedSliceOp where source and destination
// vectors have different ranks. When ranks differ, InsertStridedSlice needs
// to extract a properly ranked subvector from the destination vector into
// which to insert. This pattern only takes care of that part and forwards
// the rest of the conversion to another pattern that handles operands of
// the same rank:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to do without offsets; leave such ops to other patterns.
    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Same-rank case is handled by the companion pattern.
    if (rankDiff == 0)
      return failure();

    // Number of trailing dimensions shared by source and destination.
    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};
1004 
// RewritePattern for InsertStridedSliceOp where source and destination
// vectors have the same rank. The op is progressively lowered by peeling
// the leading dimension: for each slice of the source along that dimension,
//   1. the slice is extracted from the source
//   2. if the slice is itself a vector, the corresponding subvector is
//   extracted from the destination and a new, lower-rank
//   InsertStridedSlice combines the two (picked up again by this pattern)
//   3. the (combined) slice is inserted into the result at the strided
//   offset
// Once the slices are scalars, no recursion is needed and plain
// element inserts are emitted.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    // Nothing to do without offsets; leave such ops to other patterns.
    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    // Mismatched ranks are handled by the companion different-rank pattern.
    if (rankDiff != 0)
      return failure();

    // Inserting a full-size vector simply replaces the destination.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Offset / size / stride along the leading (most major) dimension.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from destination
        // Otherwise we are at the element level and no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};
1074 
1075 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1076 /// static layout.
1077 static llvm::Optional<SmallVector<int64_t, 4>>
1078 computeContiguousStrides(MemRefType memRefType) {
1079   int64_t offset;
1080   SmallVector<int64_t, 4> strides;
1081   if (failed(getStridesAndOffset(memRefType, strides, offset)))
1082     return None;
1083   if (!strides.empty() && strides.back() != 1)
1084     return None;
1085   // If no layout or identity layout, this is contiguous by definition.
1086   if (memRefType.getAffineMaps().empty() ||
1087       memRefType.getAffineMaps().front().isIdentity())
1088     return strides;
1089 
1090   // Otherwise, we must determine contiguity form shapes. This can only ever
1091   // work in static cases because MemRefType is underspecified to represent
1092   // contiguous dynamic shapes in other ways than with just empty/identity
1093   // layout.
1094   auto sizes = memRefType.getShape();
1095   for (int index = 0, e = strides.size() - 2; index < e; ++index) {
1096     if (ShapedType::isDynamic(sizes[index + 1]) ||
1097         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
1098         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
1099       return None;
1100     if (strides[index] != strides[index + 1] * sizes[index + 1])
1101       return None;
1102   }
1103   return strides;
1104 }
1105 
/// Lowers vector.type_cast by rebuilding the memref descriptor for the
/// target type: the allocated and aligned pointers of the source descriptor
/// are bitcast to the target element pointer type, the offset is reset to
/// zero, and the (static) sizes and strides of the target type are
/// materialized as constants. Only statically shaped, contiguous memrefs
/// with static strides are supported.
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    // The source operand must already have been lowered to an LLVM memref
    // descriptor (struct).
    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref: one constant per target
    // dimension, taken from the static target shape and strides.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};
1184 
1185 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
1186 /// sequence of:
1187 /// 1. Get the source/dst address as an LLVM vector pointer.
1188 /// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1189 /// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1190 /// 4. Create a mask where offsetVector is compared against memref upper bound.
1191 /// 5. Rewrite op as a masked read or write.
1192 template <typename ConcreteOp>
1193 class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
1194 public:
1195   explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
1196                                     bool enableIndexOpt)
1197       : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
1198         enableIndexOptimizations(enableIndexOpt) {}
1199 
1200   LogicalResult
1201   matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
1202                   ConversionPatternRewriter &rewriter) const override {
1203     auto adaptor = getTransferOpAdapter(xferOp, operands);
1204 
1205     if (xferOp.getVectorType().getRank() > 1 ||
1206         llvm::size(xferOp.indices()) == 0)
1207       return failure();
1208     if (xferOp.permutation_map() !=
1209         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
1210                                        xferOp.getVectorType().getRank(),
1211                                        xferOp->getContext()))
1212       return failure();
1213     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1214     if (!memRefType)
1215       return failure();
1216     // Only contiguous source tensors supported atm.
1217     auto strides = computeContiguousStrides(memRefType);
1218     if (!strides)
1219       return failure();
1220 
1221     auto toLLVMTy = [&](Type t) {
1222       return this->getTypeConverter()->convertType(t);
1223     };
1224 
1225     Location loc = xferOp->getLoc();
1226 
1227     if (auto memrefVectorElementType =
1228             memRefType.getElementType().template dyn_cast<VectorType>()) {
1229       // Memref has vector element type.
1230       if (memrefVectorElementType.getElementType() !=
1231           xferOp.getVectorType().getElementType())
1232         return failure();
1233 #ifndef NDEBUG
1234       // Check that memref vector type is a suffix of 'vectorType.
1235       unsigned memrefVecEltRank = memrefVectorElementType.getRank();
1236       unsigned resultVecRank = xferOp.getVectorType().getRank();
1237       assert(memrefVecEltRank <= resultVecRank);
1238       // TODO: Move this to isSuffix in Vector/Utils.h.
1239       unsigned rankOffset = resultVecRank - memrefVecEltRank;
1240       auto memrefVecEltShape = memrefVectorElementType.getShape();
1241       auto resultVecShape = xferOp.getVectorType().getShape();
1242       for (unsigned i = 0; i < memrefVecEltRank; ++i)
1243         assert(memrefVecEltShape[i] != resultVecShape[rankOffset + i] &&
1244                "memref vector element shape should match suffix of vector "
1245                "result shape.");
1246 #endif // ifndef NDEBUG
1247     }
1248 
1249     // 1. Get the source/dst address as an LLVM vector pointer.
1250     VectorType vtp = xferOp.getVectorType();
1251     Value dataPtr = this->getStridedElementPtr(
1252         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
1253     Value vectorDataPtr =
1254         castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
1255 
1256     if (!xferOp.isMaskedDim(0))
1257       return replaceTransferOpWithLoadOrStore(rewriter,
1258                                               *this->getTypeConverter(), loc,
1259                                               xferOp, operands, vectorDataPtr);
1260 
1261     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
1262     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1263     // 4. Let dim the memref dimension, compute the vector comparison mask:
1264     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1265     //
1266     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1267     //       dimensions here.
1268     unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
1269     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
1270     Value off = xferOp.indices()[lastIndex];
1271     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
1272     Value mask = buildVectorComparison(
1273         rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);
1274 
1275     // 5. Rewrite as a masked read / write.
1276     return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
1277                                        xferOp, operands, vectorDataPtr, mask);
1278   }
1279 
1280 private:
1281   const bool enableIndexOptimizations;
1282 };
1283 
/// Lowers vector.print to a sequence of calls into a small runtime print
/// library, fully unrolling the vector into elementary print calls.
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    // Select the runtime print function for the element type. Every branch
    // below either assigns `printer` or returns failure.
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      // Index values are printed as unsigned 64-bit integers.
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  // Extension applied to an integer element before it is passed to the
  // 64-bit runtime print function.
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  // Recursively emits print calls for `value` of the given `rank`:
  // rank 0 prints a single (possibly extended) scalar; rank > 0 prints an
  // opening bracket, each sub-vector/element separated by commas, and a
  // closing bracket.
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }
};
1425 
1426 /// Progressive lowering of ExtractStridedSliceOp to either:
1427 ///   1. express single offset extract as a direct shuffle.
1428 ///   2. extract + lower rank strided_slice + insert for the n-D case.
1429 class VectorExtractStridedSliceOpConversion
1430     : public OpRewritePattern<ExtractStridedSliceOp> {
1431 public:
1432   VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
1433       : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
1434     // This pattern creates recursive ExtractStridedSliceOp, but the recursion
1435     // is bounded as the rank is strictly decreasing.
1436     setHasBoundedRewriteRecursion();
1437   }
1438 
1439   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
1440                                 PatternRewriter &rewriter) const override {
1441     auto dstType = op.getType();
1442 
1443     assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
1444 
1445     int64_t offset =
1446         op.offsets().getValue().front().cast<IntegerAttr>().getInt();
1447     int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
1448     int64_t stride =
1449         op.strides().getValue().front().cast<IntegerAttr>().getInt();
1450 
1451     auto loc = op.getLoc();
1452     auto elemType = dstType.getElementType();
1453     assert(elemType.isSignlessIntOrIndexOrFloat());
1454 
1455     // Single offset can be more efficiently shuffled.
1456     if (op.offsets().getValue().size() == 1) {
1457       SmallVector<int64_t, 4> offsets;
1458       offsets.reserve(size);
1459       for (int64_t off = offset, e = offset + size * stride; off < e;
1460            off += stride)
1461         offsets.push_back(off);
1462       rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
1463                                              op.vector(),
1464                                              rewriter.getI64ArrayAttr(offsets));
1465       return success();
1466     }
1467 
1468     // Extract/insert on a lower ranked extract strided slice op.
1469     Value zero = rewriter.create<ConstantOp>(loc, elemType,
1470                                              rewriter.getZeroAttr(elemType));
1471     Value res = rewriter.create<SplatOp>(loc, dstType, zero);
1472     for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
1473          off += stride, ++idx) {
1474       Value one = extractOne(rewriter, loc, op.vector(), off);
1475       Value extracted = rewriter.create<ExtractStridedSliceOp>(
1476           loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
1477           getI64SubArray(op.sizes(), /* dropFront=*/1),
1478           getI64SubArray(op.strides(), /* dropFront=*/1));
1479       res = insertOne(rewriter, loc, extracted, res, idx);
1480     }
1481     rewriter.replaceOp(op, res);
1482     return success();
1483   }
1484 };
1485 
1486 } // namespace
1487 
1488 /// Populate the given list with patterns that convert from Vector to LLVM.
1489 void mlir::populateVectorToLLVMConversionPatterns(
1490     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
1491     bool reassociateFPReductions, bool enableIndexOptimizations) {
1492   MLIRContext *ctx = converter.getDialect()->getContext();
1493   // clang-format off
1494   patterns.insert<VectorFMAOpNDRewritePattern,
1495                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
1496                   VectorInsertStridedSliceOpSameRankRewritePattern,
1497                   VectorExtractStridedSliceOpConversion>(ctx);
1498   patterns.insert<VectorReductionOpConversion>(
1499       converter, reassociateFPReductions);
1500   patterns.insert<VectorCreateMaskOpConversion,
1501                   VectorTransferConversion<TransferReadOp>,
1502                   VectorTransferConversion<TransferWriteOp>>(
1503       converter, enableIndexOptimizations);
1504   patterns
1505       .insert<VectorBitCastOpConversion,
1506               VectorShuffleOpConversion,
1507               VectorExtractElementOpConversion,
1508               VectorExtractOpConversion,
1509               VectorFMAOp1DConversion,
1510               VectorInsertElementOpConversion,
1511               VectorInsertOpConversion,
1512               VectorPrintOpConversion,
1513               VectorTypeCastOpConversion,
1514               VectorMaskedLoadOpConversion,
1515               VectorMaskedStoreOpConversion,
1516               VectorGatherOpConversion,
1517               VectorScatterOpConversion,
1518               VectorExpandLoadOpConversion,
1519               VectorCompressStoreOpConversion>(converter);
1520   // clang-format on
1521 }
1522 
1523 void mlir::populateVectorToLLVMMatrixConversionPatterns(
1524     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
1525   patterns.insert<VectorMatmulOpConversion>(converter);
1526   patterns.insert<VectorFlatTransposeOpConversion>(converter);
1527 }
1528