//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  return VectorType::get(tp.getShape().take_back(), tp.getElementType());
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
                       LLVMTypeConverter &typeConverter, Location loc,
                       Value val1, Value val2, Type llvmType, int64_t rank,
                       int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                  constant);
  }
  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                              rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto vectorType = into.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<InsertOp>(loc, from, into, offset);
  return rewriter.create<vector::InsertElementOp>(
      loc, vectorType, from, into,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
                        LLVMTypeConverter &typeConverter, Location loc,
                        Value val, Type llvmType, int64_t rank, int64_t pos) {
  if (rank == 1) {
    auto idxType = rewriter.getIndexType();
    auto constant = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(idxType),
        rewriter.getIntegerAttr(idxType, pos));
    return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
                                                   constant);
  }
  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
                                               rewriter.getI64ArrayAttr(pos));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto vectorType = vector.getType().cast<VectorType>();
  if (vectorType.getRank() > 1)
    return rewriter.create<ExtractOp>(loc, vector, offset);
  return rewriter.create<vector::ExtractElementOp>(
      loc, vectorType.getElementType(), vector,
      rewriter.create<ConstantIndexOp>(loc, offset));
}

// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
// TODO: Better support for attribute subtype forwarding + slicing.
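// For example, given arrayAttr = [0, 1, 2, 3], dropFront = 1 and dropBack = 1,
// this returns {1, 2}.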
static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                              unsigned dropFront = 0,
                                              unsigned dropBack = 0) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t, 4> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}

static Value createCastToIndexLike(ConversionPatternRewriter &rewriter,
                                   Location loc, Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<SignExtendIOp>(loc, targetIntegerType, value);
  return rewriter.create<TruncateIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
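//
// For example, with enableIndexOptimizations, dim = 4, bound %b and no offset,
// the emitted IR is roughly the following (illustrative; the exact assembly
// syntax depends on the revision):
//
//   %0 = constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %1 = index_cast %b : index to i32
//   %2 = splat %1 : vector<4xi32>
//   %3 = cmpi slt, %0, %2 : vector<4xi32>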
static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
                                   Operation *op, bool enableIndexOptimizations,
                                   int64_t dim, Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Value indices;
  Type idxType;
  if (enableIndexOptimizations) {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI32VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
    idxType = rewriter.getI32Type();
  } else {
    indices = rewriter.create<ConstantOp>(
        loc, rewriter.getI64VectorAttr(
                 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
    idxType = rewriter.getI64Type();
  }
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
}

// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
                                 MemRefType memrefType, unsigned &align) {
  Type elementTy = typeConverter.convertType(memrefType.getElementType());
  if (!elementTy)
    return failure();

  // TODO: this should use the MLIR data layout when it becomes available and
  // stop depending on translation.
  llvm::LLVMContext llvmContext;
  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
  return success();
}

// Helper that returns the base address of a memref.
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
                             Value memref, MemRefType memRefType, Value &base) {
  // Inspect stride and offset structure.
  //
  // TODO: flat memory only for now, generalize
  //
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
  if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
      offset != 0 || memRefType.getMemorySpace() != 0)
    return failure();
  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
  return success();
}
// Helper that returns a vector of pointers, given a memref base and a vector
// of indices.
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                    Location loc, Value memref, Value indices,
                                    MemRefType memRefType, VectorType vType,
                                    Type iType, Value &ptrs) {
  Value base;
  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
    return failure();
  auto pType = MemRefDescriptor(memref).getElementPtrType();
  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
  return success();
}

// Casts a strided element pointer to a vector pointer. The vector pointer is
// always in address space 0, so an addrspacecast is needed when the source or
// destination memref is in a different address space.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, MemRefType memRefType, Type vt) {
  auto pType = LLVM::LLVMPointerType::get(vt);
  if (memRefType.getMemorySpace() == 0)
    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferReadOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferReadOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  VectorType fillType = xferOp.getVectorType();
  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
  if (!vecTy)
    return failure();

  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
      rewriter.getI32IntegerAttr(align));
  return success();
}

static LogicalResult
replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                 LLVMTypeConverter &typeConverter, Location loc,
                                 TransferWriteOp xferOp,
                                 ArrayRef<Value> operands, Value dataPtr) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
                                             align);
  return success();
}

static LogicalResult
replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                            LLVMTypeConverter &typeConverter, Location loc,
                            TransferWriteOp xferOp, ArrayRef<Value> operands,
                            Value dataPtr, Value mask) {
  unsigned align;
  if (failed(getMemRefAlignment(
          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
    return failure();

  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      xferOp, adaptor.vector(), dataPtr, mask,
      rewriter.getI32IntegerAttr(align));
  return success();
}

static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp,
                                                  ArrayRef<Value> operands) {
  return TransferReadOpAdaptor(operands);
}

static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp,
                                                   ArrayRef<Value> operands) {
  return TransferWriteOpAdaptor(operands);
}

namespace {

/// Conversion pattern for a vector.bitcast.
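/// Only 1-D vectors lower directly; the op becomes an llvm.bitcast on the
/// converted vector type. For example (illustrative; same total bit width):
///
/// ```
///    vector.bitcast %a : vector<4xf32> to vector<8xf16>
/// ```
/// becomes, roughly, `llvm.bitcast %a : vector<4xf32> to vector<8xf16>`.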
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getType();
    if (resultTy.getRank() != 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 operands[0]);
    return success();
  }
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
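/// For example (illustrative; flat 1-D vectors with shape attributes):
///
/// ```
///    %c = vector.matrix_multiply %a, %b
///        {lhs_rows = 2 : i32, lhs_columns = 2 : i32, rhs_columns = 2 : i32}
///        : (vector<4xf64>, vector<4xf64>) -> vector<4xf64>
/// ```
/// maps 1-1 onto llvm.intr.matrix.multiply with the same operands and
/// attributes.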
class VectorMatmulOpConversion
    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::MatmulOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
        matmulOp, typeConverter->convertType(matmulOp.res().getType()),
        adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(),
        matmulOp.lhs_columns(), matmulOp.rhs_columns());
    return success();
  }
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
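/// For example (illustrative; a row-major 4x4 matrix stored as a flat vector):
///
/// ```
///    %t = vector.flat_transpose %m {rows = 4 : i32, columns = 4 : i32}
///        : vector<16xf32> -> vector<16xf32>
/// ```
/// maps 1-1 onto llvm.intr.matrix.transpose with the same operand and
/// attributes.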
class VectorFlatTransposeOpConversion
    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FlatTransposeOpAdaptor(operands);
    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
        transOp, typeConverter->convertType(transOp.res().getType()),
        adaptor.matrix(), transOp.rows(), transOp.columns());
    return success();
  }
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align);
}

static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.valueToStore(),
                                             ptr, align);
}

static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align);
}

/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
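/// The address is resolved to a strided element pointer which is then bitcast
/// to a vector pointer. For example (illustrative; exact syntax varies by
/// revision):
///
/// ```
///    %0 = vector.maskedload %base[%i], %mask, %pass
///        : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// ```
/// becomes, roughly, an llvm.intr.masked.load of the computed vector pointer
/// with the mask and pass-through operands and a data-layout alignment
/// attribute.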
template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    auto adaptor = LoadOrStoreOpAdaptor(operands);
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
      return failure();

    // Resolve address.
    auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
                     .template cast<VectorType>();
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(),
                                               adaptor.indices(), rewriter);
    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);

    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
    return success();
  }
};

/// Conversion pattern for a vector.gather.
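/// The gather is lowered to an llvm.intr.masked.gather over a vector of
/// pointers, obtained by a single GEP of the flat base pointer with the index
/// vector. For example (illustrative):
///
/// ```
///    %g = vector.gather %base[%idxs], %mask, %pass
///        : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
///          into vector<16xf32>
/// ```
/// becomes, roughly, a GEP producing a vector of 16 f32 pointers, followed by
/// llvm.intr.masked.gather with the mask and pass-through operands.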
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = gather->getLoc();
    auto adaptor = vector::GatherOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = gather.getResultVectorType();
    Type iType = gather.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              gather.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.mask(),
        adaptor.pass_thru(), rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    auto adaptor = vector::ScatterOpAdaptor(operands);

    // Resolve alignment.
    unsigned align;
    if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
                                  align)))
      return failure();

    // Get index ptrs.
    VectorType vType = scatter.getValueVectorType();
    Type iType = scatter.getIndicesVectorType().getElementType();
    Value ptrs;
    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
                              scatter.getMemRefType(), vType, iType, ptrs)))
      return failure();

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.value(), ptrs, adaptor.mask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }
};

/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getResultVectorType());
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
    return success();
  }
};

/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:
  using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    auto adaptor = vector::CompressStoreOpAdaptor(operands);
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
                                           adaptor.indices(), rewriter);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.value(), ptr, adaptor.mask());
    return success();
  }
};

/// Conversion pattern for all vector reductions.
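/// The string `kind` attribute selects the llvm.intr.vector.reduce.*
/// intrinsic. For example:
///
/// ```
///    %0 = vector.reduction "add", %v : vector<16xf32> into f32
/// ```
/// becomes llvm.intr.vector.reduce.fadd with a zero accumulator when none is
/// supplied; integer and index kinds map to the matching integer reduction
/// intrinsics (umin/umax for unsigned or index element types).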
class VectorReductionOpConversion
    : public ConvertOpToLLVMPattern<vector::ReductionOp> {
public:
  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
                                       bool reassociateFPRed)
      : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
        reassociateFPReductions(reassociateFPRed) {}

  LogicalResult
  matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto kind = reductionOp.kind();
    Type eltType = reductionOp.dest().getType();
    Type llvmType = typeConverter->convertType(eltType);
    if (eltType.isIntOrIndex()) {
      // Integer reductions: add/mul/min/max/and/or/xor.
      if (kind == "add")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "mul")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "min")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max" &&
               (eltType.isIndex() || eltType.isUnsignedInteger()))
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "max")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "and")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "or")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(
            reductionOp, llvmType, operands[0]);
      else if (kind == "xor")
        rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(
            reductionOp, llvmType, operands[0]);
      else
        return failure();
      return success();
    }

    if (!eltType.isa<FloatType>())
      return failure();

    // Floating-point reductions: add/mul/min/max.
    if (kind == "add") {
      // Optional accumulator (or zero).
      Value acc = operands.size() > 1 ? operands[1]
                                      : rewriter.create<LLVM::ConstantOp>(
                                            reductionOp->getLoc(), llvmType,
                                            rewriter.getZeroAttr(eltType));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "mul") {
      // Optional accumulator (or one).
      Value acc = operands.size() > 1
                      ? operands[1]
                      : rewriter.create<LLVM::ConstantOp>(
                            reductionOp->getLoc(), llvmType,
                            rewriter.getFloatAttr(eltType, 1.0));
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
          reductionOp, llvmType, acc, operands[0],
          rewriter.getBoolAttr(reassociateFPReductions));
    } else if (kind == "min")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(
          reductionOp, llvmType, operands[0]);
    else if (kind == "max")
      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(
          reductionOp, llvmType, operands[0]);
    else
      return failure();
    return success();
  }

private:
  const bool reassociateFPReductions;
};

/// Conversion pattern for a vector.create_mask (1-D only).
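/// For example, `vector.create_mask %b : vector<4xi1>` materializes the
/// constant vector [0, 1, 2, 3], splats the bound %b, and compares the two
/// with `slt`, yielding a mask that is all-true in the first %b lanes.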
class VectorCreateMaskOpConversion
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv,
                                        bool enableIndexOpt)
      : ConvertOpToLLVMPattern<vector::CreateMaskOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank == 1) {
      rewriter.replaceOp(
          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
                                    dstType.getDimSize(0), operands[0]));
      return success();
    }
    return failure();
  }

private:
  const bool enableIndexOptimizations;
};

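/// Conversion pattern for a vector.shuffle. Rank-1 shuffles of operands with
/// identical vector types map directly onto llvm.shufflevector; all other
/// cases are assembled with extract/insert pairs. For example (illustrative):
///
/// ```
///    %0 = vector.shuffle %a, %b [0, 2, 3] : vector<2xf32>, vector<2xf32>
/// ```
/// becomes a single llvm.shufflevector with mask [0, 2, 3].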
class VectorShuffleOpConversion
    : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
public:
  using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = shuffleOp->getLoc();
    auto adaptor = vector::ShuffleOpAdaptor(operands);
    auto v1Type = shuffleOp.getV1VectorType();
    auto v2Type = shuffleOp.getV2VectorType();
    auto vectorType = shuffleOp.getVectorType();
    Type llvmType = typeConverter->convertType(vectorType);
    auto maskArrayAttr = shuffleOp.mask();

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    // Get rank and dimension sizes.
    int64_t rank = vectorType.getRank();
    assert(v1Type.getRank() == rank);
    assert(v2Type.getRank() == rank);
    int64_t v1Dim = v1Type.getDimSize(0);

    // For rank 1, where both operands have *exactly* the same vector type,
    // there is direct shuffle support in LLVM. Use it!
    if (rank == 1 && v1Type == v2Type) {
      Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
          loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
      rewriter.replaceOp(shuffleOp, llvmShuffleOp);
      return success();
    }

    // For all other cases, insert the extracted values one by one.
    Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    int64_t insPos = 0;
    for (auto en : llvm::enumerate(maskArrayAttr)) {
      int64_t extPos = en.value().cast<IntegerAttr>().getInt();
      Value value = adaptor.v1();
      if (extPos >= v1Dim) {
        extPos -= v1Dim;
        value = adaptor.v2();
      }
      Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, rank, extPos);
      insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
                         llvmType, rank, insPos++);
    }
    rewriter.replaceOp(shuffleOp, insert);
    return success();
  }
};

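/// Conversion pattern for a vector.extractelement; a 1-1 mapping onto
/// llvm.extractelement with the converted element type.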
class VectorExtractElementOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
public:
  using ConvertOpToLLVMPattern<
      vector::ExtractElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractElementOp extractEltOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::ExtractElementOpAdaptor(operands);
    auto vectorType = extractEltOp.getVectorType();
    auto llvmType = typeConverter->convertType(vectorType.getElementType());

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
        extractEltOp, llvmType, adaptor.vector(), adaptor.position());
    return success();
  }
};

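/// Conversion pattern for a vector.extract. An n-D vector lowers to an LLVM
/// array of 1-D vectors, so a static position is split into an
/// llvm.extractvalue into that array followed, when a scalar is produced, by
/// an llvm.extractelement. For example (illustrative):
///
/// ```
///    %0 = vector.extract %v[3, 7] : vector<4x8xf32>
/// ```
/// becomes, roughly, llvm.extractvalue at position [3] followed by
/// llvm.extractelement at index 7.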
class VectorExtractOpConversion
    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = extractOp->getLoc();
    auto adaptor = vector::ExtractOpAdaptor(operands);
    auto vectorType = extractOp.getVectorType();
    auto resultType = extractOp.getResult().getType();
    auto llvmResultType = typeConverter->convertType(resultType);
    auto positionArrayAttr = extractOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot extraction of vector from array (only requires extractvalue).
    if (resultType.isa<VectorType>()) {
      Value extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmResultType, adaptor.vector(), positionArrayAttr);
      rewriter.replaceOp(extractOp, extracted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = extractOp->getContext();
    Value extracted = adaptor.vector();
    auto positionAttrs = positionArrayAttr.getValue();
    if (positionAttrs.size() > 1) {
      auto oneDVectorType = reducedVectorTypeBack(vectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Remaining extraction of element from 1-D LLVM vector
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    extracted =
        rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
    rewriter.replaceOp(extractOp, extracted);

    return success();
  }
};

/// Conversion pattern that turns a vector.fma on a 1-D vector
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
/// This does not match vectors of n >= 2 rank.
///
/// Example:
/// ```
///  vector.fma %a, %a, %a : vector<8xf32>
/// ```
/// is converted to:
/// ```
///  llvm.intr.fmuladd %va, %va, %va :
///    (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
/// ```
class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
public:
  using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::FMAOpAdaptor(operands);
    VectorType vType = fmaOp.getVectorType();
    if (vType.getRank() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(fmaOp, adaptor.lhs(),
                                                 adaptor.rhs(), adaptor.acc());
    return success();
  }
};

class VectorInsertElementOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::InsertElementOpAdaptor(operands);
    auto vectorType = insertEltOp.getDestVectorType();
    auto llvmType = typeConverter->convertType(vectorType);

    // Bail if result type cannot be lowered.
    if (!llvmType)
      return failure();

    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
        insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
        adaptor.position());
    return success();
  }
};

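/// Conversion pattern for a vector.insert, mirroring the extract lowering:
/// when needed, an llvm.extractvalue reaches the innermost 1-D vector, an
/// llvm.insertelement updates it, and an llvm.insertvalue puts it back. For
/// example (illustrative), `vector.insert %s, %v[3, 7] : f32 into
/// vector<4x8xf32>` becomes, roughly, llvm.extractvalue at [3],
/// llvm.insertelement at index 7, and llvm.insertvalue at [3].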
class VectorInsertOpConversion
    : public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = insertOp->getLoc();
    auto adaptor = vector::InsertOpAdaptor(operands);
    auto sourceType = insertOp.getSourceType();
    auto destVectorType = insertOp.getDestVectorType();
    auto llvmResultType = typeConverter->convertType(destVectorType);
    auto positionArrayAttr = insertOp.position();

    // Bail if result type cannot be lowered.
    if (!llvmResultType)
      return failure();

    // One-shot insertion of a vector into an array (only requires insertvalue).
    if (sourceType.isa<VectorType>()) {
      Value inserted = rewriter.create<LLVM::InsertValueOp>(
          loc, llvmResultType, adaptor.dest(), adaptor.source(),
          positionArrayAttr);
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    // Potential extraction of 1-D vector from array.
    auto *context = insertOp->getContext();
    Value extracted = adaptor.dest();
    auto positionAttrs = positionArrayAttr.getValue();
    auto position = positionAttrs.back().cast<IntegerAttr>();
    auto oneDVectorType = destVectorType;
    if (positionAttrs.size() > 1) {
      oneDVectorType = reducedVectorTypeBack(destVectorType);
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      extracted = rewriter.create<LLVM::ExtractValueOp>(
          loc, typeConverter->convertType(oneDVectorType), extracted,
          nMinusOnePositionAttrs);
    }

    // Insertion of an element into a 1-D LLVM vector.
    auto i64Type = IntegerType::get(rewriter.getContext(), 64);
    auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
    Value inserted = rewriter.create<LLVM::InsertElementOp>(
        loc, typeConverter->convertType(oneDVectorType), extracted,
        adaptor.source(), constant);

    // Potential insertion of resulting 1-D vector into array.
    if (positionAttrs.size() > 1) {
      auto nMinusOnePositionAttrs =
          ArrayAttr::get(context, positionAttrs.drop_back());
      inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                      adaptor.dest(), inserted,
                                                      nMinusOnePositionAttrs);
    }

    rewriter.replaceOp(insertOp, inserted);
    return success();
  }
};

/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
/// ```
///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
/// ```
/// is rewritten into:
/// ```
///  %r = splat %f0: vector<2x4xf32>
///  %va = vector.extract %a[0] : vector<2x4xf32>
///  %vb = vector.extract %b[0] : vector<2x4xf32>
///  %vc = vector.extract %c[0] : vector<2x4xf32>
///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
///  %va2 = vector.extract %a[1] : vector<2x4xf32>
///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
///  // %r3 holds the final value.
/// ```
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
  using OpRewritePattern<FMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(FMAOp op,
                                PatternRewriter &rewriter) const override {
    auto vType = op.getVectorType();
    if (vType.getRank() < 2)
      return failure();

    auto loc = op.getLoc();
    auto elemType = vType.getElementType();
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
// another pattern that converts InsertStridedSlice for operands of the same
// rank.
//
// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have different ranks. In this case:
//   1. the proper subvector is extracted from the destination vector
//   2. a new InsertStridedSlice op is created to insert the source in the
//   destination subvector
//   3. the destination subvector is inserted back in the proper place
//   4. the op is replaced by the result of step 3.
// The new InsertStridedSlice from step 2. will be picked up by a
// `VectorInsertStridedSliceOpSameRankRewritePattern`.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted =
        rewriter.create<ExtractOp>(loc, op.dest(),
                                   getI64SubArray(op.offsets(), /*dropFront=*/0,
                                                  /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. The problem is reduced along the most major dimension:
//   1. each (n-1)-D slice is extracted from the source vector
//   2. if the slice is itself a vector, the matching subvector is extracted
//   from the destination and a new, lower-rank InsertStridedSlice inserts
//   the source slice into it
//   3. the result is inserted back into the destination at the proper offset.
// The lower-rank InsertStridedSlice ops created in step 2. are picked up again
// by this same pattern; the recursion is bounded since the rank strictly
// decreases.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from the
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /* dropFront=*/1),
            getI64SubArray(op.strides(), /* dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from shapes. This can only ever
  // work in static cases because MemRefType is underspecified to represent
  // contiguous dynamic shapes in other ways than with just empty/identity
  // layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

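/// Conversion pattern for a vector.type_cast between statically shaped,
/// contiguous memrefs, e.g. from memref<8x8xf32> to memref<vector<8x8xf32>>.
/// The result descriptor is rebuilt: the allocated and aligned pointers are
/// bitcast to the target element type, the offset is zeroed, and the static
/// sizes and strides are materialized as constants.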
class VectorTypeCastOpConversion
    : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
public:
  using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::TypeCastOp castOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = castOp->getLoc();
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType = castOp.getType();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMStructType>();
    if (!llvmSourceDescriptorTy)
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
    if (!llvmTargetDescriptorTy)
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr =
          rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(castOp, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into
/// a sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
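///
/// For example, a masked 1-D vector.transfer_read such as
///
/// ```
///    %f = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes, roughly, a comparison mask of the in-bounds lanes followed by an
/// llvm.intr.masked.load of the bitcast vector pointer, with the padding
/// value splatted into the masked-off lanes (illustrative; the unmasked case
/// lowers to a plain load/store).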
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
public:
  explicit VectorTransferConversion(LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertOpToLLVMPattern<ConcreteOp>(typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       xferOp->getContext()))
      return failure();
    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();
    // Only contiguous source buffers supported atm.
    auto strides = computeContiguousStrides(memRefType);
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) {
      return this->getTypeConverter()->convertType(t);
    };

    Location loc = xferOp->getLoc();

    if (auto memrefVectorElementType =
            memRefType.getElementType().template dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that memref vector type is a suffix of `vectorType`.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
1255     }
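    // (Example of the suffix rule above, with illustrative types: reading a
    // memref<?xvector<4xf32>> into a vector<2x4xf32> is accepted, since the
    // element shape [4] is a suffix of the result shape [2, 4].)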

    // 1. Get the source/dst address as an LLVM vector pointer.
    VectorType vtp = xferOp.getVectorType();
    Value dataPtr = this->getStridedElementPtr(
        loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
    Value vectorDataPtr =
        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));

    if (!xferOp.isMaskedDim(0))
      return replaceTransferOpWithLoadOrStore(rewriter,
                                              *this->getTypeConverter(), loc,
                                              xferOp, operands, vectorDataPtr);

    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
    // 4. Let dim be the memref dimension and compute the vector comparison
    //    mask: [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
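    //
    // For instance (a sketch with vecWidth = 4, offset %off, and memref
    // dimension %d):
    //   linear indices = [ 0, 1, 2, 3 ]
    //   offsetVector   = [ %off + 0, %off + 1, %off + 2, %off + 3 ]
    //   mask           = offsetVector < [ %d, %d, %d, %d ]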
    unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
    Value mask = buildVectorComparison(
        rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off);

    // 5. Rewrite as a masked read / write.
    return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc,
                                       xferOp, operands, vectorDataPtr, mask);
  }

private:
  const bool enableIndexOptimizations;
};

class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
  using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
  // Proof-of-concept lowering implementation that relies on a small
  // runtime support library, which only needs to provide a few
  // printing methods (single value for all data types, opening/closing
  // bracket, comma, newline). The lowering fully unrolls a vector
  // in terms of these elementary printing operations. The advantage
  // of this approach is that the library can remain unaware of all
  // low-level implementation details of vectors while still supporting
  // output of vectors of any shape and rank. Due to the full unrolling,
  // however, this approach is less suitable for very large vectors.
  //
  // TODO: rely solely on libc in the future? something else?
  //
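  // For example (a sketch of the runtime calls this lowering emits; the
  // callee names are the ones looked up via FunctionCallUtils below),
  // printing a vector<2xi32> value [1, -1] unrolls roughly into:
  //
  //   printOpen(); printI64(1); printComma(); printI64(-1); printClose();
  //   printNewline();
  //
  // rendering roughly as "( 1, -1 )" plus a newline.
  //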
  LogicalResult
  matchAndRewrite(vector::PrintOp printOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto adaptor = vector::PrintOpAdaptor(operands);
    Type printType = printOp.getPrintType();

    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    VectorType vectorType = printType.dyn_cast<VectorType>();
    Type eltType = vectorType ? vectorType.getElementType() : printType;
    Operation *printer;
    if (eltType.isF32()) {
      printer =
          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isF64()) {
      printer =
          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (eltType.isIndex()) {
      printer =
          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
    } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer = LLVM::lookupOrCreatePrintU64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer = LLVM::lookupOrCreatePrintI64Fn(
              printOp->getParentOfType<ModuleOp>());
        } else {
          return failure();
        }
      }
    } else {
      return failure();
    }

    // Unroll vector into elementary print calls.
    int64_t rank = vectorType ? vectorType.getRank() : 0;
    emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
              conversion);
    emitCall(rewriter, printOp->getLoc(),
             LLVM::lookupOrCreatePrintNewlineFn(
                 printOp->getParentOfType<ModuleOp>()));
    rewriter.eraseOp(printOp);
    return success();
  }

private:
  enum class PrintConversion {
    // clang-format off
    None,
    ZeroExt64,
    SignExt64
    // clang-format on
  };

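  // Recursively unrolls 'value' into elementary print calls: at rank 0 the
  // scalar is (optionally zero- or sign-extended to 64 bits and) printed
  // directly; at higher ranks it prints an opening bracket, recurses on
  // each element with comma separators, and prints a closing bracket.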
  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
                 Value value, VectorType vectorType, Operation *printer,
                 int64_t rank, PrintConversion conversion) const {
    Location loc = op->getLoc();
    if (rank == 0) {
      switch (conversion) {
      case PrintConversion::ZeroExt64:
        value = rewriter.create<ZeroExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::SignExt64:
        value = rewriter.create<SignExtendIOp>(
            loc, value, IntegerType::get(rewriter.getContext(), 64));
        break;
      case PrintConversion::None:
        break;
      }
      emitCall(rewriter, loc, printer, value);
      return;
    }

    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
    Operation *printComma =
        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
    int64_t dim = vectorType.getDimSize(0);
    for (int64_t d = 0; d < dim; ++d) {
      auto reducedType =
          rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
      auto llvmType = typeConverter->convertType(
          rank > 1 ? reducedType : vectorType.getElementType());
      Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmType, rank, d);
      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
                conversion);
      if (d != dim - 1)
        emitCall(rewriter, loc, printComma);
    }
    emitCall(rewriter, loc,
             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
  }

  // Helper to emit a call.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                  rewriter.getSymbolRefAttr(ref), params);
  }
};

/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. express a single-offset extract as a direct shuffle.
///   2. extract + lower-rank strided_slice + insert for the n-D case.
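///
/// For example (a sketch with illustrative values), the 1-D slice
///
///   vector.extract_strided_slice %v
///     {offsets = [2], sizes = [4], strides = [2]}
///     : vector<16xf32> to vector<4xf32>
///
/// lowers to a single shuffle of %v with indices [2, 4, 6, 8].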
class VectorExtractStridedSliceOpConversion
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
      : OpRewritePattern<ExtractStridedSliceOp>(ctx) {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();

    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> offsets;
      offsets.reserve(size);
      for (int64_t off = offset, e = offset + size * stride; off < e;
           off += stride)
        offsets.push_back(off);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(offsets));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    Value zero = rewriter.create<ConstantOp>(loc, elemType,
                                             rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      Value one = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
          getI64SubArray(op.sizes(), /*dropFront=*/1),
          getI64SubArray(op.strides(), /*dropFront=*/1));
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};

} // namespace

/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      converter, enableIndexOptimizations);
  patterns
      .insert<VectorBitCastOpConversion,
              VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorLoadStoreConversion<vector::LoadOp,
                                        vector::LoadOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedLoadOp,
                                        vector::MaskedLoadOpAdaptor>,
              VectorLoadStoreConversion<vector::StoreOp,
                                        vector::StoreOpAdaptor>,
              VectorLoadStoreConversion<vector::MaskedStoreOp,
                                        vector::MaskedStoreOpAdaptor>,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  patterns.insert<VectorMatmulOpConversion>(converter);
  patterns.insert<VectorFlatTransposeOpConversion>(converter);
}
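
// Typical usage of the populate functions above (a minimal sketch; the
// enclosing pass and its error handling are illustrative, not part of this
// file):
//
//   OwningRewritePatternList patterns;
//   LLVMTypeConverter converter(&getContext());
//   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
//   populateVectorToLLVMConversionPatterns(
//       converter, patterns, /*reassociateFPReductions=*/false,
//       /*enableIndexOptimizations=*/true);
//   LLVMConversionTarget target(getContext());
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();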