1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 20 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/VectorOps/VectorOps.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "mlir/Transforms/Passes.h" 35 36 #include "llvm/IR/DerivedTypes.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/Support/Allocator.h" 40 #include "llvm/Support/ErrorHandling.h" 41 42 using namespace mlir; 43 44 template <typename T> 45 static LLVM::LLVMType getPtrToElementType(T containerType, 46 LLVMTypeConverter &lowering) { 47 return lowering.convertType(containerType.getElementType()) 48 .template cast<LLVM::LLVMType>() 49 .getPointerTo(); 50 } 51 52 // Helper to reduce vector type by one rank at front. 53 static VectorType reducedVectorTypeFront(VectorType tp) { 54 assert((tp.getRank() > 1) && "unlowerable vector type"); 55 return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); 56 } 57 58 // Helper to reduce vector type by *all* but one rank at back. 59 static VectorType reducedVectorTypeBack(VectorType tp) { 60 assert((tp.getRank() > 1) && "unlowerable vector type"); 61 return VectorType::get(tp.getShape().take_back(), tp.getElementType()); 62 } 63 64 class VectorBroadcastOpConversion : public LLVMOpLowering { 65 public: 66 explicit VectorBroadcastOpConversion(MLIRContext *context, 67 LLVMTypeConverter &typeConverter) 68 : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, 69 typeConverter) {} 70 71 PatternMatchResult 72 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 73 ConversionPatternRewriter &rewriter) const override { 74 auto broadcastOp = cast<vector::BroadcastOp>(op); 75 VectorType dstVectorType = broadcastOp.getVectorType(); 76 if (lowering.convertType(dstVectorType) == nullptr) 77 return matchFailure(); 78 // Rewrite when the full vector type can be lowered (which 79 // implies all 'reduced' types can be lowered too). 80 VectorType srcVectorType = 81 broadcastOp.getSourceType().dyn_cast<VectorType>(); 82 rewriter.replaceOp( 83 op, expandRanks(operands[0], // source value to be expanded 84 op->getLoc(), // location of original broadcast 85 srcVectorType, dstVectorType, rewriter)); 86 return matchSuccess(); 87 } 88 89 private: 90 // Expands the given source value over all the ranks, as defined 91 // by the source and destination type (a null source type denotes 92 // expansion from a scalar value into a vector). 93 // 94 // TODO(ajcbik): consider replacing this one-pattern lowering 95 // with a two-pattern lowering using other vector 96 // ops once all insert/extract/shuffle operations 97 // are available with lowering implemention. 98 // 99 Value *expandRanks(Value *value, Location loc, VectorType srcVectorType, 100 VectorType dstVectorType, 101 ConversionPatternRewriter &rewriter) const { 102 assert((dstVectorType != nullptr) && "invalid result type in broadcast"); 103 // Determine rank of source and destination. 104 int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; 105 int64_t dstRank = dstVectorType.getRank(); 106 int64_t curDim = dstVectorType.getDimSize(0); 107 if (srcRank < dstRank) 108 // Duplicate this rank. 109 return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 110 curDim, rewriter); 111 // If all trailing dimensions are the same, the broadcast consists of 112 // simply passing through the source value and we are done. Otherwise, 113 // any non-matching dimension forces a stretch along this rank. 114 assert((srcVectorType != nullptr) && (srcRank > 0) && 115 (srcRank == dstRank) && "invalid rank in broadcast"); 116 for (int64_t r = 0; r < dstRank; r++) { 117 if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { 118 return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, 119 curDim, rewriter); 120 } 121 } 122 return value; 123 } 124 125 // Picks the best way to duplicate a single rank. For the 1-D case, a 126 // single insert-elt/shuffle is the most efficient expansion. For higher 127 // dimensions, however, we need dim x insert-values on a new broadcast 128 // with one less leading dimension, which will be lowered "recursively" 129 // to matching LLVM IR. 130 // For example: 131 // v = broadcast s : f32 to vector<4x2xf32> 132 // becomes: 133 // x = broadcast s : f32 to vector<2xf32> 134 // v = [x,x,x,x] 135 // becomes: 136 // x = [s,s] 137 // v = [x,x,x,x] 138 Value *duplicateOneRank(Value *value, Location loc, VectorType srcVectorType, 139 VectorType dstVectorType, int64_t rank, int64_t dim, 140 ConversionPatternRewriter &rewriter) const { 141 Type llvmType = lowering.convertType(dstVectorType); 142 assert((llvmType != nullptr) && "unlowerable vector type"); 143 if (rank == 1) { 144 Value *undef = rewriter.create<LLVM::UndefOp>(loc, llvmType); 145 Value *expand = insertOne(undef, value, loc, llvmType, rank, 0, rewriter); 146 SmallVector<int32_t, 4> zeroValues(dim, 0); 147 return rewriter.create<LLVM::ShuffleVectorOp>( 148 loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); 149 } 150 Value *expand = 151 expandRanks(value, loc, srcVectorType, 152 reducedVectorTypeFront(dstVectorType), rewriter); 153 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 154 for (int64_t d = 0; d < dim; ++d) { 155 result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); 156 } 157 return result; 158 } 159 160 // Picks the best way to stretch a single rank. For the 1-D case, a 161 // single insert-elt/shuffle is the most efficient expansion when at 162 // a stretch. Otherwise, every dimension needs to be expanded 163 // individually and individually inserted in the resulting vector. 164 // For example: 165 // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> 166 // becomes: 167 // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> 168 // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> 169 // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> 170 // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> 171 // v = [a,b,c,d] 172 // becomes: 173 // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> 174 // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> 175 // a = [x, y] 176 // etc. 177 Value *stretchOneRank(Value *value, Location loc, VectorType srcVectorType, 178 VectorType dstVectorType, int64_t rank, int64_t dim, 179 ConversionPatternRewriter &rewriter) const { 180 Type llvmType = lowering.convertType(dstVectorType); 181 assert((llvmType != nullptr) && "unlowerable vector type"); 182 Value *result = rewriter.create<LLVM::UndefOp>(loc, llvmType); 183 bool atStretch = dim != srcVectorType.getDimSize(0); 184 if (rank == 1) { 185 Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); 186 if (atStretch) { 187 Value *one = extractOne(value, loc, redLlvmType, rank, 0, rewriter); 188 Value *expand = 189 insertOne(result, one, loc, llvmType, rank, 0, rewriter); 190 SmallVector<int32_t, 4> zeroValues(dim, 0); 191 return rewriter.create<LLVM::ShuffleVectorOp>( 192 loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); 193 } 194 for (int64_t d = 0; d < dim; ++d) { 195 Value *one = extractOne(value, loc, redLlvmType, rank, d, rewriter); 196 result = insertOne(result, one, loc, llvmType, rank, d, rewriter); 197 } 198 } else { 199 VectorType redSrcType = reducedVectorTypeFront(srcVectorType); 200 VectorType redDstType = reducedVectorTypeFront(dstVectorType); 201 Type redLlvmType = lowering.convertType(redSrcType); 202 for (int64_t d = 0; d < dim; ++d) { 203 int64_t pos = atStretch ? 0 : d; 204 Value *one = extractOne(value, loc, redLlvmType, rank, pos, rewriter); 205 Value *expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); 206 result = insertOne(result, expand, loc, llvmType, rank, d, rewriter); 207 } 208 } 209 return result; 210 } 211 212 // Picks the proper sequence for inserting. 213 Value *insertOne(Value *val1, Value *val2, Location loc, Type llvmType, 214 int64_t rank, int64_t pos, 215 ConversionPatternRewriter &rewriter) const { 216 if (rank == 1) { 217 auto idxType = rewriter.getIndexType(); 218 auto constant = rewriter.create<LLVM::ConstantOp>( 219 loc, lowering.convertType(idxType), 220 rewriter.getIntegerAttr(idxType, pos)); 221 return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2, 222 constant); 223 } 224 return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2, 225 rewriter.getI64ArrayAttr(pos)); 226 } 227 228 // Picks the proper sequence for extracting. 229 Value *extractOne(Value *value, Location loc, Type llvmType, int64_t rank, 230 int64_t pos, ConversionPatternRewriter &rewriter) const { 231 if (rank == 1) { 232 auto idxType = rewriter.getIndexType(); 233 auto constant = rewriter.create<LLVM::ConstantOp>( 234 loc, lowering.convertType(idxType), 235 rewriter.getIntegerAttr(idxType, pos)); 236 return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, value, 237 constant); 238 } 239 return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, value, 240 rewriter.getI64ArrayAttr(pos)); 241 } 242 }; 243 244 class VectorExtractOpConversion : public LLVMOpLowering { 245 public: 246 explicit VectorExtractOpConversion(MLIRContext *context, 247 LLVMTypeConverter &typeConverter) 248 : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, 249 typeConverter) {} 250 251 PatternMatchResult 252 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 253 ConversionPatternRewriter &rewriter) const override { 254 auto loc = op->getLoc(); 255 auto adaptor = vector::ExtractOpOperandAdaptor(operands); 256 auto extractOp = cast<vector::ExtractOp>(op); 257 auto vectorType = extractOp.getVectorType(); 258 auto resultType = extractOp.getResult()->getType(); 259 auto llvmResultType = lowering.convertType(resultType); 260 auto positionArrayAttr = extractOp.position(); 261 262 // Bail if result type cannot be lowered. 263 if (!llvmResultType) 264 return matchFailure(); 265 266 // One-shot extraction of vector from array (only requires extractvalue). 267 if (resultType.isa<VectorType>()) { 268 Value *extracted = rewriter.create<LLVM::ExtractValueOp>( 269 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 270 rewriter.replaceOp(op, extracted); 271 return matchSuccess(); 272 } 273 274 // Potential extraction of 1-D vector from array. 275 auto *context = op->getContext(); 276 Value *extracted = adaptor.vector(); 277 auto positionAttrs = positionArrayAttr.getValue(); 278 if (positionAttrs.size() > 1) { 279 auto oneDVectorType = reducedVectorTypeBack(vectorType); 280 auto nMinusOnePositionAttrs = 281 ArrayAttr::get(positionAttrs.drop_back(), context); 282 extracted = rewriter.create<LLVM::ExtractValueOp>( 283 loc, lowering.convertType(oneDVectorType), extracted, 284 nMinusOnePositionAttrs); 285 } 286 287 // Remaining extraction of element from 1-D LLVM vector 288 auto position = positionAttrs.back().cast<IntegerAttr>(); 289 auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); 290 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position); 291 extracted = 292 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 293 rewriter.replaceOp(op, extracted); 294 295 return matchSuccess(); 296 } 297 }; 298 299 class VectorInsertOpConversion : public LLVMOpLowering { 300 public: 301 explicit VectorInsertOpConversion(MLIRContext *context, 302 LLVMTypeConverter &typeConverter) 303 : LLVMOpLowering(vector::InsertOp::getOperationName(), context, 304 typeConverter) {} 305 306 PatternMatchResult 307 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 308 ConversionPatternRewriter &rewriter) const override { 309 auto loc = op->getLoc(); 310 auto adaptor = vector::InsertOpOperandAdaptor(operands); 311 auto insertOp = cast<vector::InsertOp>(op); 312 auto sourceType = insertOp.getSourceType(); 313 auto destVectorType = insertOp.getDestVectorType(); 314 auto llvmResultType = lowering.convertType(destVectorType); 315 auto positionArrayAttr = insertOp.position(); 316 317 // Bail if result type cannot be lowered. 318 if (!llvmResultType) 319 return matchFailure(); 320 321 // One-shot insertion of a vector into an array (only requires insertvalue). 322 if (sourceType.isa<VectorType>()) { 323 Value *inserted = rewriter.create<LLVM::InsertValueOp>( 324 loc, llvmResultType, adaptor.dest(), adaptor.source(), 325 positionArrayAttr); 326 rewriter.replaceOp(op, inserted); 327 return matchSuccess(); 328 } 329 330 // Potential extraction of 1-D vector from array. 331 auto *context = op->getContext(); 332 Value *extracted = adaptor.dest(); 333 auto positionAttrs = positionArrayAttr.getValue(); 334 auto position = positionAttrs.back().cast<IntegerAttr>(); 335 auto oneDVectorType = destVectorType; 336 if (positionAttrs.size() > 1) { 337 oneDVectorType = reducedVectorTypeBack(destVectorType); 338 auto nMinusOnePositionAttrs = 339 ArrayAttr::get(positionAttrs.drop_back(), context); 340 extracted = rewriter.create<LLVM::ExtractValueOp>( 341 loc, lowering.convertType(oneDVectorType), extracted, 342 nMinusOnePositionAttrs); 343 } 344 345 // Insertion of an element into a 1-D LLVM vector. 346 auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); 347 auto constant = rewriter.create<LLVM::ConstantOp>(loc, i32Type, position); 348 Value *inserted = rewriter.create<LLVM::InsertElementOp>( 349 loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), 350 constant); 351 352 // Potential insertion of resulting 1-D vector into array. 353 if (positionAttrs.size() > 1) { 354 auto nMinusOnePositionAttrs = 355 ArrayAttr::get(positionAttrs.drop_back(), context); 356 inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType, 357 adaptor.dest(), inserted, 358 nMinusOnePositionAttrs); 359 } 360 361 rewriter.replaceOp(op, inserted); 362 return matchSuccess(); 363 } 364 }; 365 366 class VectorOuterProductOpConversion : public LLVMOpLowering { 367 public: 368 explicit VectorOuterProductOpConversion(MLIRContext *context, 369 LLVMTypeConverter &typeConverter) 370 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 371 typeConverter) {} 372 373 PatternMatchResult 374 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 375 ConversionPatternRewriter &rewriter) const override { 376 auto loc = op->getLoc(); 377 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 378 auto *ctx = op->getContext(); 379 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 380 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 381 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 382 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 383 auto llvmArrayOfVectType = lowering.convertType( 384 cast<vector::OuterProductOp>(op).getResult()->getType()); 385 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 386 Value *a = adaptor.lhs(), *b = adaptor.rhs(); 387 Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 388 SmallVector<Value *, 8> lhs, accs; 389 lhs.reserve(rankLHS); 390 accs.reserve(rankLHS); 391 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 392 // shufflevector explicitly requires i32. 393 auto attr = rewriter.getI32IntegerAttr(d); 394 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 395 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 396 Value *aD = nullptr, *accD = nullptr; 397 // 1. Broadcast the element a[d] into vector aD. 398 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 399 // 2. If acc is present, extract 1-d vector acc[d] into accD. 400 if (acc) 401 accD = rewriter.create<LLVM::ExtractValueOp>( 402 loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); 403 // 3. Compute aD outer b (plus accD, if relevant). 404 Value *aOuterbD = 405 accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD) 406 .getResult() 407 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 408 // 4. Insert as value `d` in the descriptor. 409 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType, 410 desc, aOuterbD, 411 rewriter.getI64ArrayAttr(d)); 412 } 413 rewriter.replaceOp(op, desc); 414 return matchSuccess(); 415 } 416 }; 417 418 class VectorTypeCastOpConversion : public LLVMOpLowering { 419 public: 420 explicit VectorTypeCastOpConversion(MLIRContext *context, 421 LLVMTypeConverter &typeConverter) 422 : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, 423 typeConverter) {} 424 425 PatternMatchResult 426 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 427 ConversionPatternRewriter &rewriter) const override { 428 auto loc = op->getLoc(); 429 vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op); 430 MemRefType sourceMemRefType = 431 castOp.getOperand()->getType().cast<MemRefType>(); 432 MemRefType targetMemRefType = 433 castOp.getResult()->getType().cast<MemRefType>(); 434 435 // Only static shape casts supported atm. 436 if (!sourceMemRefType.hasStaticShape() || 437 !targetMemRefType.hasStaticShape()) 438 return matchFailure(); 439 440 auto llvmSourceDescriptorTy = 441 operands[0]->getType().dyn_cast<LLVM::LLVMType>(); 442 if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) 443 return matchFailure(); 444 MemRefDescriptor sourceMemRef(operands[0]); 445 446 auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) 447 .dyn_cast_or_null<LLVM::LLVMType>(); 448 if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) 449 return matchFailure(); 450 451 int64_t offset; 452 SmallVector<int64_t, 4> strides; 453 auto successStrides = 454 getStridesAndOffset(sourceMemRefType, strides, offset); 455 bool isContiguous = (strides.back() == 1); 456 if (isContiguous) { 457 auto sizes = sourceMemRefType.getShape(); 458 for (int index = 0, e = strides.size() - 2; index < e; ++index) { 459 if (strides[index] != strides[index + 1] * sizes[index + 1]) { 460 isContiguous = false; 461 break; 462 } 463 } 464 } 465 // Only contiguous source tensors supported atm. 466 if (failed(successStrides) || !isContiguous) 467 return matchFailure(); 468 469 auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); 470 471 // Create descriptor. 472 auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); 473 Type llvmTargetElementTy = desc.getElementType(); 474 // Set allocated ptr. 475 Value *allocated = sourceMemRef.allocatedPtr(rewriter, loc); 476 allocated = 477 rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated); 478 desc.setAllocatedPtr(rewriter, loc, allocated); 479 // Set aligned ptr. 480 Value *ptr = sourceMemRef.alignedPtr(rewriter, loc); 481 ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr); 482 desc.setAlignedPtr(rewriter, loc, ptr); 483 // Fill offset 0. 484 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0); 485 auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr); 486 desc.setOffset(rewriter, loc, zero); 487 488 // Fill size and stride descriptors in memref. 489 for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) { 490 int64_t index = indexedSize.index(); 491 auto sizeAttr = 492 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value()); 493 auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr); 494 desc.setSize(rewriter, loc, index, size); 495 auto strideAttr = 496 rewriter.getIntegerAttr(rewriter.getIndexType(), strides[index]); 497 auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr); 498 desc.setStride(rewriter, loc, index, stride); 499 } 500 501 rewriter.replaceOp(op, {desc}); 502 return matchSuccess(); 503 } 504 }; 505 506 /// Populate the given list with patterns that convert from Vector to LLVM. 507 void mlir::populateVectorToLLVMConversionPatterns( 508 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 509 patterns.insert<VectorBroadcastOpConversion, VectorExtractOpConversion, 510 VectorInsertOpConversion, VectorOuterProductOpConversion, 511 VectorTypeCastOpConversion>( 512 converter.getDialect()->getContext(), converter); 513 } 514 515 namespace { 516 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 517 void runOnModule() override; 518 }; 519 } // namespace 520 521 void LowerVectorToLLVMPass::runOnModule() { 522 // Convert to the LLVM IR dialect using the converter defined above. 523 OwningRewritePatternList patterns; 524 LLVMTypeConverter converter(&getContext()); 525 populateVectorToLLVMConversionPatterns(converter, patterns); 526 populateStdToLLVMConversionPatterns(converter, patterns); 527 528 ConversionTarget target(getContext()); 529 target.addLegalDialect<LLVM::LLVMDialect>(); 530 target.addDynamicallyLegalOp<FuncOp>( 531 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 532 if (failed( 533 applyPartialConversion(getModule(), target, patterns, &converter))) { 534 signalPassFailure(); 535 } 536 } 537 538 OpPassBase<ModuleOp> *mlir::createLowerVectorToLLVMPass() { 539 return new LowerVectorToLLVMPass(); 540 } 541 542 static PassRegistration<LowerVectorToLLVMPass> 543 pass("convert-vector-to-llvm", 544 "Lower the operations from the vector dialect into the LLVM dialect"); 545