1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "../PassDetail.h" 12 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Dialect/AMX/AMXDialect.h" 16 #include "mlir/Dialect/AMX/Transforms.h" 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 19 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 20 #include "mlir/Dialect/ArmSVE/Transforms.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/MemRef/IR/MemRef.h" 23 #include "mlir/Dialect/StandardOps/IR/Ops.h" 24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 25 #include "mlir/Dialect/X86Vector/Transforms.h" 26 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 using namespace mlir; 30 using namespace mlir::vector; 31 32 namespace { 33 struct LowerVectorToLLVMPass 34 : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 35 LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 36 this->reassociateFPReductions = options.reassociateFPReductions; 37 this->indexOptimizations = options.indexOptimizations; 38 this->armNeon = options.armNeon; 39 this->armSVE = options.armSVE; 40 this->amx = options.amx; 41 this->x86Vector = options.x86Vector; 42 } 43 // Override explicitly to allow conditional dialect dependence. 44 void getDependentDialects(DialectRegistry ®istry) const override { 45 registry.insert<LLVM::LLVMDialect>(); 46 registry.insert<arith::ArithmeticDialect>(); 47 registry.insert<memref::MemRefDialect>(); 48 if (armNeon) 49 registry.insert<arm_neon::ArmNeonDialect>(); 50 if (armSVE) 51 registry.insert<arm_sve::ArmSVEDialect>(); 52 if (amx) 53 registry.insert<amx::AMXDialect>(); 54 if (x86Vector) 55 registry.insert<x86vector::X86VectorDialect>(); 56 } 57 void runOnOperation() override; 58 }; 59 } // namespace 60 61 void LowerVectorToLLVMPass::runOnOperation() { 62 // Perform progressive lowering of operations on slices and 63 // all contraction operations. Also applies folding and DCE. 64 { 65 RewritePatternSet patterns(&getContext()); 66 populateVectorToVectorCanonicalizationPatterns(patterns); 67 populateVectorBroadcastLoweringPatterns(patterns); 68 populateVectorContractLoweringPatterns(patterns); 69 populateVectorMaskOpLoweringPatterns(patterns); 70 populateVectorShapeCastLoweringPatterns(patterns); 71 populateVectorTransposeLoweringPatterns(patterns); 72 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 73 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 74 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 75 } 76 77 // Convert to the LLVM IR dialect. 78 LLVMTypeConverter converter(&getContext()); 79 RewritePatternSet patterns(&getContext()); 80 populateVectorMaskMaterializationPatterns(patterns, indexOptimizations); 81 populateVectorTransferLoweringPatterns(patterns); 82 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 83 populateVectorToLLVMConversionPatterns(converter, patterns, 84 reassociateFPReductions); 85 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 86 87 // Architecture specific augmentations. 88 LLVMConversionTarget target(getContext()); 89 target.addLegalDialect<arith::ArithmeticDialect>(); 90 target.addLegalDialect<memref::MemRefDialect>(); 91 target.addLegalDialect<StandardOpsDialect>(); 92 target.addLegalOp<UnrealizedConversionCastOp>(); 93 if (armNeon) { 94 // TODO: we may or may not want to include in-dialect lowering to 95 // LLVM-compatible operations here. So far, all operations in the dialect 96 // can be translated to LLVM IR so there is no conversion necessary. 97 target.addLegalDialect<arm_neon::ArmNeonDialect>(); 98 } 99 if (armSVE) { 100 configureArmSVELegalizeForExportTarget(target); 101 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); 102 } 103 if (amx) { 104 configureAMXLegalizeForExportTarget(target); 105 populateAMXLegalizeForLLVMExportPatterns(converter, patterns); 106 } 107 if (x86Vector) { 108 configureX86VectorLegalizeForExportTarget(target); 109 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); 110 } 111 112 if (failed( 113 applyPartialConversion(getOperation(), target, std::move(patterns)))) 114 signalPassFailure(); 115 } 116 117 std::unique_ptr<OperationPass<ModuleOp>> 118 mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 119 return std::make_unique<LowerVectorToLLVMPass>(options); 120 } 121