1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10 11 #include "../PassDetail.h" 12 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 15 #include "mlir/Dialect/AMX/AMXDialect.h" 16 #include "mlir/Dialect/AMX/Transforms.h" 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 19 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 20 #include "mlir/Dialect/ArmSVE/Transforms.h" 21 #include "mlir/Dialect/Func/IR/FuncOps.h" 22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 23 #include "mlir/Dialect/MemRef/IR/MemRef.h" 24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 25 #include "mlir/Dialect/X86Vector/Transforms.h" 26 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 using namespace mlir; 30 using namespace mlir::vector; 31 32 namespace { 33 struct LowerVectorToLLVMPass 34 : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 35 LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 36 this->reassociateFPReductions = options.reassociateFPReductions; 37 this->force32BitVectorIndices = options.force32BitVectorIndices; 38 this->armNeon = options.armNeon; 39 this->armSVE = options.armSVE; 40 this->amx = options.amx; 41 this->x86Vector = options.x86Vector; 42 } 43 // Override explicitly to allow conditional dialect dependence. 44 void getDependentDialects(DialectRegistry ®istry) const override { 45 registry.insert<LLVM::LLVMDialect>(); 46 registry.insert<arith::ArithmeticDialect>(); 47 registry.insert<memref::MemRefDialect>(); 48 if (armNeon) 49 registry.insert<arm_neon::ArmNeonDialect>(); 50 if (armSVE) 51 registry.insert<arm_sve::ArmSVEDialect>(); 52 if (amx) 53 registry.insert<amx::AMXDialect>(); 54 if (x86Vector) 55 registry.insert<x86vector::X86VectorDialect>(); 56 } 57 void runOnOperation() override; 58 }; 59 } // namespace 60 61 void LowerVectorToLLVMPass::runOnOperation() { 62 // Perform progressive lowering of operations on slices and 63 // all contraction operations. Also applies folding and DCE. 64 { 65 RewritePatternSet patterns(&getContext()); 66 populateVectorToVectorCanonicalizationPatterns(patterns); 67 populateVectorBroadcastLoweringPatterns(patterns); 68 populateVectorContractLoweringPatterns(patterns); 69 populateVectorMaskOpLoweringPatterns(patterns); 70 populateVectorShapeCastLoweringPatterns(patterns); 71 populateVectorTransposeLoweringPatterns(patterns); 72 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 73 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 74 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 75 } 76 77 // Convert to the LLVM IR dialect. 78 LLVMTypeConverter converter(&getContext()); 79 RewritePatternSet patterns(&getContext()); 80 populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices); 81 populateVectorTransferLoweringPatterns(patterns); 82 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 83 populateVectorToLLVMConversionPatterns( 84 converter, patterns, reassociateFPReductions, force32BitVectorIndices); 85 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 86 87 // Architecture specific augmentations. 88 LLVMConversionTarget target(getContext()); 89 target.addLegalDialect<arith::ArithmeticDialect>(); 90 target.addLegalDialect<memref::MemRefDialect>(); 91 target.addLegalOp<UnrealizedConversionCastOp>(); 92 if (armNeon) { 93 // TODO: we may or may not want to include in-dialect lowering to 94 // LLVM-compatible operations here. So far, all operations in the dialect 95 // can be translated to LLVM IR so there is no conversion necessary. 96 target.addLegalDialect<arm_neon::ArmNeonDialect>(); 97 } 98 if (armSVE) { 99 configureArmSVELegalizeForExportTarget(target); 100 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); 101 } 102 if (amx) { 103 configureAMXLegalizeForExportTarget(target); 104 populateAMXLegalizeForLLVMExportPatterns(converter, patterns); 105 } 106 if (x86Vector) { 107 configureX86VectorLegalizeForExportTarget(target); 108 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); 109 } 110 111 if (failed( 112 applyPartialConversion(getOperation(), target, std::move(patterns)))) 113 signalPassFailure(); 114 } 115 116 std::unique_ptr<OperationPass<ModuleOp>> 117 mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 118 return std::make_unique<LowerVectorToLLVMPass>(options); 119 } 120