//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}
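// For example (illustrative): given vector<2x3x4xf32>, the front reduction
// above yields vector<3x4xf32> and the back reduction yields vector<4xf32>.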

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
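//
// A minimal sketch of the generated IR (illustrative, assuming
// enableIndexOptimizations, dim = 4, an offset %off, and a bound %b, with %o
// and %bc the offset and bound cast to i32):
//     %cst  = constant dense<[0, 1, 2, 3]> : vector<4xi32>
//     %ov   = splat %o : vector<4xi32>
//     %idx  = addi %ov, %cst : vector<4xi32>
//     %bnd  = splat %bc : vector<4xi32>
//     %mask = cmpi slt, %idx, %bnd : vector<4xi32>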
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns the data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Add an index vector component to a base pointer. This fails when the last
// stride is non-unit or the memory space is not zero.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value base,
                                    Value index, MemRefType memRefType,
                                    VectorType vType, Value &ptrs) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.back() != 1 ||
      memRefType.getMemorySpace() != 0)
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
  return success();
}

// Casts a strided element pointer to a vector pointer.  The vector pointer
// will be in the same address space as the incoming memref type.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpace());
  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.bitcast.
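/// A minimal sketch (illustrative):
/// ```
///   %1 = vector.bitcast %0 : vector<4xf32> to vector<2xf64>
/// ```
/// lowers to a single llvm.bitcast on the converted 1-D vector type.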
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
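/// A minimal sketch (illustrative): a 2x3 times 3x4 multiply on flattened
/// operands
/// ```
///   %c = vector.matrix_multiply %a, %b
///          { lhs_rows = 2: i32, lhs_columns = 3: i32, rhs_columns = 4: i32 }
///        : (vector<6xf32>, vector<12xf32>) -> vector<8xf32>
/// ```
/// maps 1-1 onto the intrinsic.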
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
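/// A minimal sketch (illustrative):
/// ```
///   %1 = vector.flat_transpose %0 { rows = 4: i32, columns = 4: i32 }
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// maps 1-1 onto the intrinsic.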
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload, or vector.maskedstore with its respective LLVM
/// counterpart.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
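/// A minimal sketch (illustrative):
/// ```
///   %0 = vector.load %base[%i] : memref<100xf32>, vector<8xf32>
/// ```
/// becomes an llvm.load through the strided element pointer, bitcast to a
/// vector pointer and annotated with the memref's preferred alignment.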
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    auto adaptor = LoadOrStoreOpAdaptor(operands);
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
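/// The gather lowers to llvm.intr.masked.gather over a vector of pointers,
/// obtained by a GEP of the strided base pointer with the index vector (see
/// getIndexedPtrs above); vector.scatter below follows the same scheme.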
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);
    MemRefType memRefType = gather.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = gather.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);
    MemRefType memRefType = scatter.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
      return failure();

    // Resolve address.
    Value ptrs;
    VectorType vType = scatter.getVectorType();
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
                              adaptor.index_vec(), memRefType, vType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
                                     adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.valueToStore(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
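/// A minimal sketch (illustrative):
/// ```
///   %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// lowers to llvm.intr.vector.reduce.fadd, with reassociation controlled by
/// the `reassociateFPReductions` flag below.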
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            reductionOp->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
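/// A minimal sketch (illustrative):
/// ```
///   %m = vector.create_mask %size : vector<4xi1>
/// ```
/// becomes a comparison of the constant vector [0, 1, 2, 3] against a splat
/// of %size (see buildVectorComparison above).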
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

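/// Conversion pattern for a vector.shuffle. A rank-1 shuffle of identically
/// typed operands maps directly onto llvm.shufflevector; all other cases are
/// lowered to a sequence of individual extracts and inserts.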
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, extract and insert the values one at a time.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

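/// Conversion pattern for a vector.extract. A minimal sketch (illustrative):
/// ```
///   %1 = vector.extract %0[0, 3] : vector<2x4xf32>
/// ```
/// becomes an llvm.extractvalue of the inner vector<4xf32> at position 0,
/// followed by an llvm.extractelement at position 3.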
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector.
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va:
///    (!llvm<"<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
///    -> !llvm<"<8 x f32>">
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

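/// Conversion pattern for a vector.insert. A minimal sketch (illustrative):
/// ```
///   %1 = vector.insert %f, %0[0, 3] : f32 into vector<2x4xf32>
/// ```
/// extracts the inner vector<4xf32> at position 0, performs an
/// llvm.insertelement at position 3, and inserts the result back at
/// position 0.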
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %z = constant 0.0 : f32
///  %r = splat %z : vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each subvector of the source along the leading
// dimension:
//   1. the corresponding subvector is extracted from the source
//   2. unless we are already at the element level, a new, rank-reduced
//   InsertStridedSliceOp inserts it into the matching destination subvector
//   3. the result is inserted back into the destination at the strided offset.
// The recursion on the rank-reduced InsertStridedSliceOp is bounded since the
// rank strictly decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
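/// A minimal sketch (illustrative): memref<4x8xf32> yields strides [8, 1],
/// while memref<4x8xf32, offset: 0, strides: [16, 1]> has a non-contiguous
/// row stride and yields None.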
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shape. This can only
  // work in static cases because MemRefType is underspecified: it has no way
  // to represent contiguous dynamic shapes other than the empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

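/// Conversion pattern for a vector.type_cast, which reinterprets a statically
/// shaped memref of scalars as a memref of vectors, e.g. (illustrative)
/// memref<8x8xf32> to memref<vector<8x8xf32>>. Only the memref descriptor is
/// rebuilt; no data is moved.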
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
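///
/// A minimal sketch of the masked case (illustrative):
/// ```
///   %f = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes an llvm.intr.masked.load of the vector pointer at %A[%i], with a
/// mask comparing [%i .. %i+7] against dim %A and %pad as pass-through.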
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source memrefs supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
1257     // 4. Let dim the memref dimension, compute the vector comparison mask:
1258     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1259     //
1260     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1261     //       dimensions here.
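    // E.g. (illustrative), for a vector<8xf32> transfer at offset %off on a
    // memref with trailing dimension %dim, the mask computed below is
    //   [ %off + 0 < %dim, %off + 1 < %dim, ..., %off + 7 < %dim ].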
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion
    : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of any shaped and dimensioned vector. Due to full unrolling,
  // this approach is less suited for very large vectors though.
  //
  // TODO: rely solely on libc in future? something else?
  //
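  // For example (illustrative output), `vector.print %v : vector<2x2xf32>`
  // with %v holding [[1, 2], [3, 4]] unrolls into open/value/comma/close
  // calls that print:
  //   ( ( 1, 2 ), ( 3, 4 ) )
  //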
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
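      // E.g., an i32 operand is sign-extended to i64 and printed with the
      // signed printer, while a ui16 operand is zero-extended to i64 and
      // printed with the unsigned printer.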
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express single offset extract as a direct shuffle.
///   2. extract + lower rank strided_slice + insert for the n-D case.
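///
/// For instance (illustrative), extracting offsets = [2], sizes = [4],
/// strides = [1] from a vector<8xf32> lowers to a single shuffle with mask
/// [2, 3, 4, 5], whereas the n-D case peels off the leading dimension and
/// recurses on rank-reduced slices.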
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
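///
/// A typical driver looks as follows (sketch only; assumes the standard
/// dialect conversion setup and pass infrastructure):
///   LLVMTypeConverter converter(ctx);
///   OwningRewritePatternList patterns;
///   populateVectorToLLVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(*ctx);
///   if (failed(applyPartialConversion(op, target, std::move(patterns))))
///     signalPassFailure();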
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorBitCastOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorLoadStoreConversion<vector::LoadOp,
                                        vector::LoadOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedLoadOp,
                                        vector::MaskedLoadOpAdaptor>,
              VectorLoadStoreConversion<vector::StoreOp,
                                        vector::StoreOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedStoreOp,
                                        vector::MaskedStoreOpAdaptor>,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}