1c95acf05SAart Bik //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2c95acf05SAart Bik // 3c95acf05SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c95acf05SAart Bik // See https://llvm.org/LICENSE.txt for license information. 5c95acf05SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c95acf05SAart Bik // 7c95acf05SAart Bik //===----------------------------------------------------------------------===// 8c95acf05SAart Bik 9df852599SKrzysztof Drewniak #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" 10c95acf05SAart Bik 1175e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 136ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMXDialect.h" 146ad7b97eSAart Bik #include "mlir/Dialect/AMX/Transforms.h" 15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 167310501fSNicolas Vasilache #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 177bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 187bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" 1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 20c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 22a8f3d303SLongsheng Mou #include "mlir/Dialect/Tensor/IR/Tensor.h" 232bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 2499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 258508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/Transforms.h" 268508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 2767d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 28c95acf05SAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29c95acf05SAart Bik 3067d0d7acSMichele Scuttari namespace mlir { 31cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS 3267d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3367d0d7acSMichele Scuttari } // namespace mlir 3467d0d7acSMichele Scuttari 35c95acf05SAart Bik using namespace mlir; 36c95acf05SAart Bik using namespace mlir::vector; 37c95acf05SAart Bik 38c95acf05SAart Bik namespace { 39cb9267f0SHugo Trachino struct ConvertVectorToLLVMPass 40cb9267f0SHugo Trachino : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> { 41cd4ca2d7SMarkus Böck 42cd4ca2d7SMarkus Böck using Base::Base; 43cd4ca2d7SMarkus Böck 447310501fSNicolas Vasilache // Override explicitly to allow conditional dialect dependence. 457310501fSNicolas Vasilache void getDependentDialects(DialectRegistry ®istry) const override { 467310501fSNicolas Vasilache registry.insert<LLVM::LLVMDialect>(); 47abc362a1SJakub Kuderski registry.insert<arith::ArithDialect>(); 48e2310704SJulian Gross registry.insert<memref::MemRefDialect>(); 49a8f3d303SLongsheng Mou registry.insert<tensor::TensorDialect>(); 50cd392c0eSNicolas Vasilache if (armNeon) 516410ee0dSAlex Zinenko registry.insert<arm_neon::ArmNeonDialect>(); 52cd392c0eSNicolas Vasilache if (armSVE) 53b739badaSJavier Setoain registry.insert<arm_sve::ArmSVEDialect>(); 54cd392c0eSNicolas Vasilache if (amx) 556ad7b97eSAart Bik registry.insert<amx::AMXDialect>(); 56cd392c0eSNicolas Vasilache if (x86Vector) 578508a63bSEmilio Cota registry.insert<x86vector::X86VectorDialect>(); 587310501fSNicolas Vasilache } 59c95acf05SAart Bik void runOnOperation() override; 60c95acf05SAart Bik }; 61c95acf05SAart Bik } // namespace 62c95acf05SAart Bik 63cb9267f0SHugo Trachino void ConvertVectorToLLVMPass::runOnOperation() { 648cd8b507SMatthias Springer // Perform progressive lowering of operations on slices and all contraction 650693b9e9SMatthias Springer // operations. Also materializes masks, lowers vector.step, rank-reduces FMA, 660693b9e9SMatthias Springer // applies folding and DCE. 67c95acf05SAart Bik { 68dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 693a506b31SChris Lattner populateVectorToVectorCanonicalizationPatterns(patterns); 700ea1271eSHan-Chung Wang populateVectorBitCastLoweringPatterns(patterns); 713964c1dbSLei Zhang populateVectorBroadcastLoweringPatterns(patterns); 72*d25a1f88SDiego Caballero populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions); 733964c1dbSLei Zhang populateVectorMaskOpLoweringPatterns(patterns); 743964c1dbSLei Zhang populateVectorShapeCastLoweringPatterns(patterns); 75a1a68603SBenjamin Maxwell populateVectorInterleaveLoweringPatterns(patterns); 76*d25a1f88SDiego Caballero populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions); 77d1a9e9a7SMatthias Springer // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 78d1a9e9a7SMatthias Springer populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 798cd8b507SMatthias Springer populateVectorMaskMaterializationPatterns(patterns, 808cd8b507SMatthias Springer force32BitVectorIndices); 810693b9e9SMatthias Springer populateVectorInsertExtractStridedSliceTransforms(patterns); 820693b9e9SMatthias Springer populateVectorStepLoweringPatterns(patterns); 830693b9e9SMatthias Springer populateVectorRankReducingFMAPattern(patterns); 8409dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 85c95acf05SAart Bik } 86c95acf05SAart Bik 87c95acf05SAart Bik // Convert to the LLVM IR dialect. 884a2d4588SMarkus Böck LowerToLLVMOptions options(&getContext()); 894a2d4588SMarkus Böck LLVMTypeConverter converter(&getContext(), options); 90dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 91d1a9e9a7SMatthias Springer populateVectorTransferLoweringPatterns(patterns); 92c95acf05SAart Bik populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 93a75a46dbSJavier Setoain populateVectorToLLVMConversionPatterns( 947bc8ad51SJavier Setoain converter, patterns, reassociateFPReductions, force32BitVectorIndices); 95c95acf05SAart Bik populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 96c95acf05SAart Bik 97c95acf05SAart Bik // Architecture specific augmentations. 98c95acf05SAart Bik LLVMConversionTarget target(getContext()); 99abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect>(); 100e2310704SJulian Gross target.addLegalDialect<memref::MemRefDialect>(); 101ba87f991SAlex Zinenko target.addLegalOp<UnrealizedConversionCastOp>(); 1023fa5ee67SAndrzej Warzynski 103cd392c0eSNicolas Vasilache if (armNeon) { 1046410ee0dSAlex Zinenko // TODO: we may or may not want to include in-dialect lowering to 1056410ee0dSAlex Zinenko // LLVM-compatible operations here. So far, all operations in the dialect 1066410ee0dSAlex Zinenko // can be translated to LLVM IR so there is no conversion necessary. 1076410ee0dSAlex Zinenko target.addLegalDialect<arm_neon::ArmNeonDialect>(); 1087310501fSNicolas Vasilache } 109cd392c0eSNicolas Vasilache if (armSVE) { 110b739badaSJavier Setoain configureArmSVELegalizeForExportTarget(target); 111b739badaSJavier Setoain populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); 112aece4e27SJavier Setoain } 113cd392c0eSNicolas Vasilache if (amx) { 1146ad7b97eSAart Bik configureAMXLegalizeForExportTarget(target); 1156ad7b97eSAart Bik populateAMXLegalizeForLLVMExportPatterns(converter, patterns); 1166ad7b97eSAart Bik } 117cd392c0eSNicolas Vasilache if (x86Vector) { 1188508a63bSEmilio Cota configureX86VectorLegalizeForExportTarget(target); 1198508a63bSEmilio Cota populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); 120c95acf05SAart Bik } 121c95acf05SAart Bik 122c95acf05SAart Bik if (failed( 123c95acf05SAart Bik applyPartialConversion(getOperation(), target, std::move(patterns)))) 124c95acf05SAart Bik signalPassFailure(); 125c95acf05SAart Bik } 126