xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (revision 9826fe5c9fb65da8f1d53b21348f013c58c09791)
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/VectorOps/VectorOps.h"
23 #include "mlir/IR/Attributes.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/Module.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/StandardTypes.h"
30 #include "mlir/IR/Types.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Pass/PassManager.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "mlir/Transforms/Passes.h"
35 
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/Module.h"
38 #include "llvm/IR/Type.h"
39 #include "llvm/Support/Allocator.h"
40 #include "llvm/Support/ErrorHandling.h"
41 
42 using namespace mlir;
43 
44 template <typename T>
45 static LLVM::LLVMType getPtrToElementType(T containerType,
46                                           LLVMTypeConverter &lowering) {
47   return lowering.convertType(containerType.getElementType())
48       .template cast<LLVM::LLVMType>()
49       .getPointerTo();
50 }
51 
52 // Helper to reduce vector type by one rank at front.
53 static VectorType reducedVectorTypeFront(VectorType tp) {
54   assert((tp.getRank() > 1) && "unlowerable vector type");
55   return VectorType::get(tp.getShape().drop_front(), tp.getElementType());
56 }
57 
58 // Helper to reduce vector type by *all* but one rank at back.
59 static VectorType reducedVectorTypeBack(VectorType tp) {
60   assert((tp.getRank() > 1) && "unlowerable vector type");
61   return VectorType::get(tp.getShape().take_back(), tp.getElementType());
62 }
63 
/// Lowers `vector.broadcast` to LLVM IR by expanding the source value one
/// rank at a time, outermost rank first, until it has the destination vector
/// shape. 1-D results are built with insertelement plus an all-zero
/// shufflevector mask. Higher ranks are assembled element by element with
/// insertvalue/extractvalue on the lowered aggregate type.
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
  explicit VectorBroadcastOpConversion(MLIRContext *context,
                                       LLVMTypeConverter &typeConverter)
      : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context,
                       typeConverter) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto broadcastOp = cast<vector::BroadcastOp>(op);
    VectorType dstVectorType = broadcastOp.getVectorType();
    if (lowering.convertType(dstVectorType) == nullptr)
      return matchFailure();
    // Rewrite when the full vector type can be lowered (which
    // implies all 'reduced' types can be lowered too).
    // srcVectorType is null when the source is a scalar (dyn_cast fails).
    VectorType srcVectorType =
        broadcastOp.getSourceType().dyn_cast<VectorType>();
    rewriter.replaceOp(
        op, expandRanks(operands[0],  // source value to be expanded
                        op->getLoc(), // location of original broadcast
                        srcVectorType, dstVectorType, rewriter));
    return matchSuccess();
  }

private:
  // Expands the given source value over all the ranks, as defined
  // by the source and destination type (a null source type denotes
  // expansion from a scalar value into a vector).
  //
  // If the source has fewer ranks than the destination, the outermost
  // destination rank is duplicated. Otherwise (equal ranks) the value is
  // returned unchanged when all dimensions match, and stretched along the
  // outermost rank when any dimension differs.
  //
  // TODO(ajcbik): consider replacing this one-pattern lowering
  //               with a two-pattern lowering using other vector
  //               ops once all insert/extract/shuffle operations
  //               are available with lowering implementation.
  //
  Value *expandRanks(Value *value, Location loc, VectorType srcVectorType,
                     VectorType dstVectorType,
                     ConversionPatternRewriter &rewriter) const {
    assert((dstVectorType != nullptr) && "invalid result type in broadcast");
    // Determine rank of source and destination.
    int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
    int64_t dstRank = dstVectorType.getRank();
    // Size of the outermost destination dimension (the one handled now).
    int64_t curDim = dstVectorType.getDimSize(0);
    if (srcRank < dstRank)
      // Duplicate this rank.
      return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
                              curDim, rewriter);
    // If all trailing dimensions are the same, the broadcast consists of
    // simply passing through the source value and we are done. Otherwise,
    // any non-matching dimension forces a stretch along this rank.
    assert((srcVectorType != nullptr) && (srcRank > 0) &&
           (srcRank == dstRank) && "invalid rank in broadcast");
    for (int64_t r = 0; r < dstRank; r++) {
      if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
        return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
                              curDim, rewriter);
      }
    }
    return value;
  }

  // Picks the best way to duplicate a single rank. For the 1-D case, a
  // single insert-elt/shuffle is the most efficient expansion. For higher
  // dimensions, however, we need dim x insert-values on a new broadcast
  // with one less leading dimension, which will be lowered "recursively"
  // to matching LLVM IR.
  // For example:
  //   v = broadcast s : f32 to vector<4x2xf32>
  // becomes:
  //   x = broadcast s : f32 to vector<2xf32>
  //   v = [x,x,x,x]
  // becomes:
  //   x = [s,s]
  //   v = [x,x,x,x]
  Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType,
                          VectorType dstVectorType, int64_t rank, int64_t dim,
                          ConversionPatternRewriter &rewriter) const {
    Type llvmType = lowering.convertType(dstVectorType);
    assert((llvmType != nullptr) && "unlowerable vector type");
    if (rank == 1) {
      // Put `value` in lane 0 of an undef vector, then splat lane 0 to all
      // `dim` lanes with an all-zero shuffle mask.
      Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
      Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter);
      SmallVector<int32_t, 4> zeroValues(dim, 0);
      return rewriter.create<LLVM::ShuffleVectorOp>(
          loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
    }
    // Build the sub-vector once, then insert that same value `dim` times.
    Value *expand =
        expandRanks(value, loc, srcVectorType,
                    reducedVectorTypeFront(dstVectorType), rewriter);
    Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    for (int64_t d = 0; d < dim; ++d) {
      result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
    }
    return result;
  }

  // Picks the best way to stretch a single rank. For the 1-D case, a
  // single insert-elt/shuffle is the most efficient expansion when at
  // a stretch. Otherwise, every dimension needs to be expanded
  // individually and individually inserted in the resulting vector.
  // For example:
  //   v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
  // becomes:
  //   a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
  //   b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
  //   c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
  //   d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
  //   v = [a,b,c,d]
  // becomes (the size-1 rank of w[0] is at a stretch, so position 0 is
  // reused for every d, and the vector<2xf32> rows already match):
  //   x = w[0][0] : vector<2xf32>
  //   a = [x, x]
  //   etc.
  Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType,
                        VectorType dstVectorType, int64_t rank, int64_t dim,
                        ConversionPatternRewriter &rewriter) const {
    Type llvmType = lowering.convertType(dstVectorType);
    assert((llvmType != nullptr) && "unlowerable vector type");
    Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
    // True when the source's outermost dimension differs from the
    // destination's, i.e. this rank must be stretched.
    bool atStretch = dim != srcVectorType.getDimSize(0);
    if (rank == 1) {
      Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
      if (atStretch) {
        // Splat element 0 of the source across all `dim` lanes.
        Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter);
        Value *expand =
            insertOne(result, one, loc, llvmType, rank, 0, rewriter);
        SmallVector<int32_t, 4> zeroValues(dim, 0);
        return rewriter.create<LLVM::ShuffleVectorOp>(
            loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
      }
      // Element-wise copy. NOTE(review): expandRanks only calls this
      // function when a dimension differs, so for rank 1 this path appears
      // unreachable from expandRanks.
      for (int64_t d = 0; d < dim; ++d) {
        Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter);
        result = insertOne(result, one, loc, llvmType, rank, d, rewriter);
      }
    } else {
      VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
      VectorType redDstType = reducedVectorTypeFront(dstVectorType);
      Type redLlvmType = lowering.convertType(redSrcType);
      for (int64_t d = 0; d < dim; ++d) {
        // At a stretch, the single source entry (position 0) feeds every d.
        int64_t pos = atStretch ? 0 : d;
        Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter);
        Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
        result = insertOne(result, expand, loc, llvmType, rank, d, rewriter);
      }
    }
    return result;
  }

  // Picks the proper sequence for inserting: insertelement with an
  // index-typed constant position for rank 1, insertvalue otherwise.
  Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType,
                   int64_t rank, int64_t pos,
                   ConversionPatternRewriter &rewriter) const {
    if (rank == 1) {
      auto idxType = rewriter.getIndexType();
      auto constant = rewriter.create<LLVM::ConstantOp>(
          loc, lowering.convertType(idxType),
          rewriter.getIntegerAttr(idxType, pos));
      return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
                                                    constant);
    }
    return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
                                                rewriter.getI64ArrayAttr(pos));
  }

  // Picks the proper sequence for extracting: extractelement with an
  // index-typed constant position for rank 1, extractvalue otherwise.
  Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank,
                    int64_t pos, ConversionPatternRewriter &rewriter) const {
    if (rank == 1) {
      auto idxType = rewriter.getIndexType();
      auto constant = rewriter.create<LLVM::ConstantOp>(
          loc, lowering.convertType(idxType),
          rewriter.getIntegerAttr(idxType, pos));
      return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value,
                                                     constant);
    }
    return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value,
                                                 rewriter.getI64ArrayAttr(pos));
  }
};
243 
244 class VectorExtractOpConversion : public LLVMOpLowering {
245 public:
246   explicit VectorExtractOpConversion(MLIRContext *context,
247                                      LLVMTypeConverter &typeConverter)
248       : LLVMOpLowering(vector::ExtractOp::getOperationName(), context,
249                        typeConverter) {}
250 
251   PatternMatchResult
252   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
253                   ConversionPatternRewriter &rewriter) const override {
254     auto loc = op->getLoc();
255     auto adaptor = vector::ExtractOpOperandAdaptor(operands);
256     auto extractOp = cast<vector::ExtractOp>(op);
257     auto vectorType = extractOp.getVectorType();
258     auto resultType = extractOp.getResult()->getType();
259     auto llvmResultType = lowering.convertType(resultType);
260     auto positionArrayAttr = extractOp.position();
261 
262     // Bail if result type cannot be lowered.
263     if (!llvmResultType)
264       return matchFailure();
265 
266     // One-shot extraction of vector from array (only requires extractvalue).
267     if (resultType.isa<VectorType>()) {
268       Value *extracted = rewriter.create<LLVM::ExtractValueOp>(
269           loc, llvmResultType, adaptor.vector(), positionArrayAttr);
270       rewriter.replaceOp(op, extracted);
271       return matchSuccess();
272     }
273 
274     // Potential extraction of 1-D vector from array.
275     auto *context = op->getContext();
276     Value *extracted = adaptor.vector();
277     auto positionAttrs = positionArrayAttr.getValue();
278     if (positionAttrs.size() > 1) {
279       auto oneDVectorType = reducedVectorTypeBack(vectorType);
280       auto nMinusOnePositionAttrs =
281           ArrayAttr::get(positionAttrs.drop_back(), context);
282       extracted = rewriter.create<LLVM::ExtractValueOp>(
283           loc, lowering.convertType(oneDVectorType), extracted,
284           nMinusOnePositionAttrs);
285     }
286 
287     // Remaining extraction of element from 1-D LLVM vector
288     auto position = positionAttrs.back().cast<IntegerAttr>();
289     auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
290     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
291     extracted =
292         rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
293     rewriter.replaceOp(op, extracted);
294 
295     return matchSuccess();
296   }
297 };
298 
299 class VectorInsertOpConversion : public LLVMOpLowering {
300 public:
301   explicit VectorInsertOpConversion(MLIRContext *context,
302                                     LLVMTypeConverter &typeConverter)
303       : LLVMOpLowering(vector::InsertOp::getOperationName(), context,
304                        typeConverter) {}
305 
306   PatternMatchResult
307   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto loc = op->getLoc();
310     auto adaptor = vector::InsertOpOperandAdaptor(operands);
311     auto insertOp = cast<vector::InsertOp>(op);
312     auto sourceType = insertOp.getSourceType();
313     auto destVectorType = insertOp.getDestVectorType();
314     auto llvmResultType = lowering.convertType(destVectorType);
315     auto positionArrayAttr = insertOp.position();
316 
317     // Bail if result type cannot be lowered.
318     if (!llvmResultType)
319       return matchFailure();
320 
321     // One-shot insertion of a vector into an array (only requires insertvalue).
322     if (sourceType.isa<VectorType>()) {
323       Value *inserted = rewriter.create<LLVM::InsertValueOp>(
324           loc, llvmResultType, adaptor.dest(), adaptor.source(),
325           positionArrayAttr);
326       rewriter.replaceOp(op, inserted);
327       return matchSuccess();
328     }
329 
330     // Potential extraction of 1-D vector from array.
331     auto *context = op->getContext();
332     Value *extracted = adaptor.dest();
333     auto positionAttrs = positionArrayAttr.getValue();
334     auto position = positionAttrs.back().cast<IntegerAttr>();
335     auto oneDVectorType = destVectorType;
336     if (positionAttrs.size() > 1) {
337       oneDVectorType = reducedVectorTypeBack(destVectorType);
338       auto nMinusOnePositionAttrs =
339           ArrayAttr::get(positionAttrs.drop_back(), context);
340       extracted = rewriter.create<LLVM::ExtractValueOp>(
341           loc, lowering.convertType(oneDVectorType), extracted,
342           nMinusOnePositionAttrs);
343     }
344 
345     // Insertion of an element into a 1-D LLVM vector.
346     auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());
347     auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position);
348     Value *inserted = rewriter.create<LLVM::InsertElementOp>(
349         loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
350         constant);
351 
352     // Potential insertion of resulting 1-D vector into array.
353     if (positionAttrs.size() > 1) {
354       auto nMinusOnePositionAttrs =
355           ArrayAttr::get(positionAttrs.drop_back(), context);
356       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
357                                                       adaptor.dest(), inserted,
358                                                       nMinusOnePositionAttrs);
359     }
360 
361     rewriter.replaceOp(op, inserted);
362     return matchSuccess();
363   }
364 };
365 
366 class VectorOuterProductOpConversion : public LLVMOpLowering {
367 public:
368   explicit VectorOuterProductOpConversion(MLIRContext *context,
369                                           LLVMTypeConverter &typeConverter)
370       : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context,
371                        typeConverter) {}
372 
373   PatternMatchResult
374   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
375                   ConversionPatternRewriter &rewriter) const override {
376     auto loc = op->getLoc();
377     auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
378     auto *ctx = op->getContext();
379     auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
380     auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
381     auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
382     auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
383     auto llvmArrayOfVectType = lowering.convertType(
384         cast<vector::OuterProductOp>(op).getResult()->getType());
385     Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
386     Value *a = adaptor.lhs(), *b = adaptor.rhs();
387     Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
388     SmallVector<Value *, 8> lhs, accs;
389     lhs.reserve(rankLHS);
390     accs.reserve(rankLHS);
391     for (unsigned d = 0, e = rankLHS; d < e; ++d) {
392       // shufflevector explicitly requires i32.
393       auto attr = rewriter.getI32IntegerAttr(d);
394       SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
395       auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
396       Value *aD = nullptr, *accD = nullptr;
397       // 1. Broadcast the element a[d] into vector aD.
398       aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
399       // 2. If acc is present, extract 1-d vector acc[d] into accD.
400       if (acc)
401         accD = rewriter.create<LLVM::ExtractValueOp>(
402             loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
403       // 3. Compute aD outer b (plus accD, if relevant).
404       Value *aOuterbD =
405           accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
406                      .getResult()
407                : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
408       // 4. Insert as value `d` in the descriptor.
409       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
410                                                   desc, aOuterbD,
411                                                   rewriter.getI64ArrayAttr(d));
412     }
413     rewriter.replaceOp(op, desc);
414     return matchSuccess();
415   }
416 };
417 
418 class VectorTypeCastOpConversion : public LLVMOpLowering {
419 public:
420   explicit VectorTypeCastOpConversion(MLIRContext *context,
421                                       LLVMTypeConverter &typeConverter)
422       : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context,
423                        typeConverter) {}
424 
425   PatternMatchResult
426   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
427                   ConversionPatternRewriter &rewriter) const override {
428     auto loc = op->getLoc();
429     vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
430     MemRefType sourceMemRefType =
431         castOp.getOperand()->getType().cast<MemRefType>();
432     MemRefType targetMemRefType =
433         castOp.getResult()->getType().cast<MemRefType>();
434 
435     // Only static shape casts supported atm.
436     if (!sourceMemRefType.hasStaticShape() ||
437         !targetMemRefType.hasStaticShape())
438       return matchFailure();
439 
440     auto llvmSourceDescriptorTy =
441         operands[0]->getType().dyn_cast<LLVM::LLVMType>();
442     if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
443       return matchFailure();
444     MemRefDescriptor sourceMemRef(operands[0]);
445 
446     auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
447                                       .dyn_cast_or_null<LLVM::LLVMType>();
448     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
449       return matchFailure();
450 
451     int64_t offset;
452     SmallVector<int64_t, 4> strides;
453     auto successStrides =
454         getStridesAndOffset(sourceMemRefType, strides, offset);
455     bool isContiguous = (strides.back() == 1);
456     if (isContiguous) {
457       auto sizes = sourceMemRefType.getShape();
458       for (int index = 0, e = strides.size() - 2; index < e; ++index) {
459         if (strides[index] != strides[index + 1] * sizes[index + 1]) {
460           isContiguous = false;
461           break;
462         }
463       }
464     }
465     // Only contiguous source tensors supported atm.
466     if (failed(successStrides) || !isContiguous)
467       return matchFailure();
468 
469     auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
470 
471     // Create descriptor.
472     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
473     Type llvmTargetElementTy = desc.getElementType();
474     // Set allocated ptr.
475     Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc);
476     allocated =
477         rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
478     desc.setAllocatedPtr(rewriter, loc, allocated);
479     // Set aligned ptr.
480     Value *ptr = sourceMemRef.alignedPtr(rewriter, loc);
481     ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
482     desc.setAlignedPtr(rewriter, loc, ptr);
483     // Fill offset 0.
484     auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
485     auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
486     desc.setOffset(rewriter, loc, zero);
487 
488     // Fill size and stride descriptors in memref.
489     for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
490       int64_t index = indexedSize.index();
491       auto sizeAttr =
492           rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
493       auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
494       desc.setSize(rewriter, loc, index, size);
495       auto strideAttr =
496           rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]);
497       auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
498       desc.setStride(rewriter, loc, index, stride);
499     }
500 
501     rewriter.replaceOp(op, {desc});
502     return matchSuccess();
503   }
504 };
505 
506 /// Populate the given list with patterns that convert from Vector to LLVM.
507 void mlir::populateVectorToLLVMConversionPatterns(
508     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
509   patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion,
510                   VectorInsertOpConversion, VectorOuterProductOpConversion,
511                   VectorTypeCastOpConversion>(
512       converter.getDialect()->getContext(), converter);
513 }
514 
515 namespace {
516 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
517   void runOnModule() override;
518 };
519 } // namespace
520 
521 void LowerVectorToLLVMPass::runOnModule() {
522   // Convert to the LLVM IR dialect using the converter defined above.
523   OwningRewritePatternList patterns;
524   LLVMTypeConverter converter(&getContext());
525   populateVectorToLLVMConversionPatterns(converter, patterns);
526   populateStdToLLVMConversionPatterns(converter, patterns);
527 
528   ConversionTarget target(getContext());
529   target.addLegalDialect<LLVM::LLVMDialect>();
530   target.addDynamicallyLegalOp<FuncOp>(
531       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
532   if (failed(
533           applyPartialConversion(getModule(), target, patterns, &converter))) {
534     signalPassFailure();
535   }
536 }
537 
538 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() {
539   return new LowerVectorToLLVMPass();
540 }
541 
542 static PassRegistration<LowerVectorToLLVMPass>
543     pass("convert-vector-to-llvm",
544          "Lower the operations from the vector dialect into the LLVM dialect");
545