xref: /llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (revision d25a1f8887e59cb770749766af9b0c7caf88326e)
//===- ConvertVectorToLLVMPass.cpp - Conversion from Vector to LLVM -------===//
2c95acf05SAart Bik //
3c95acf05SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c95acf05SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5c95acf05SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c95acf05SAart Bik //
7c95acf05SAart Bik //===----------------------------------------------------------------------===//
8c95acf05SAart Bik 
9df852599SKrzysztof Drewniak #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
10c95acf05SAart Bik 
1175e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
136ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMXDialect.h"
146ad7b97eSAart Bik #include "mlir/Dialect/AMX/Transforms.h"
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
167310501fSNicolas Vasilache #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
177bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
187bbfd2aeSBenjamin Maxwell #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
20c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
22a8f3d303SLongsheng Mou #include "mlir/Dialect/Tensor/IR/Tensor.h"
232bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
258508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/Transforms.h"
268508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
2767d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
28c95acf05SAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29c95acf05SAart Bik 
3067d0d7acSMichele Scuttari namespace mlir {
31cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
3267d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3367d0d7acSMichele Scuttari } // namespace mlir
3467d0d7acSMichele Scuttari 
35c95acf05SAart Bik using namespace mlir;
36c95acf05SAart Bik using namespace mlir::vector;
37c95acf05SAart Bik 
38c95acf05SAart Bik namespace {
39cb9267f0SHugo Trachino struct ConvertVectorToLLVMPass
40cb9267f0SHugo Trachino     : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
41cd4ca2d7SMarkus Böck 
42cd4ca2d7SMarkus Böck   using Base::Base;
43cd4ca2d7SMarkus Böck 
447310501fSNicolas Vasilache   // Override explicitly to allow conditional dialect dependence.
457310501fSNicolas Vasilache   void getDependentDialects(DialectRegistry &registry) const override {
467310501fSNicolas Vasilache     registry.insert<LLVM::LLVMDialect>();
47abc362a1SJakub Kuderski     registry.insert<arith::ArithDialect>();
48e2310704SJulian Gross     registry.insert<memref::MemRefDialect>();
49a8f3d303SLongsheng Mou     registry.insert<tensor::TensorDialect>();
50cd392c0eSNicolas Vasilache     if (armNeon)
516410ee0dSAlex Zinenko       registry.insert<arm_neon::ArmNeonDialect>();
52cd392c0eSNicolas Vasilache     if (armSVE)
53b739badaSJavier Setoain       registry.insert<arm_sve::ArmSVEDialect>();
54cd392c0eSNicolas Vasilache     if (amx)
556ad7b97eSAart Bik       registry.insert<amx::AMXDialect>();
56cd392c0eSNicolas Vasilache     if (x86Vector)
578508a63bSEmilio Cota       registry.insert<x86vector::X86VectorDialect>();
587310501fSNicolas Vasilache   }
59c95acf05SAart Bik   void runOnOperation() override;
60c95acf05SAart Bik };
61c95acf05SAart Bik } // namespace
62c95acf05SAart Bik 
63cb9267f0SHugo Trachino void ConvertVectorToLLVMPass::runOnOperation() {
648cd8b507SMatthias Springer   // Perform progressive lowering of operations on slices and all contraction
650693b9e9SMatthias Springer   // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
660693b9e9SMatthias Springer   // applies folding and DCE.
67c95acf05SAart Bik   {
68dc4e913bSChris Lattner     RewritePatternSet patterns(&getContext());
693a506b31SChris Lattner     populateVectorToVectorCanonicalizationPatterns(patterns);
700ea1271eSHan-Chung Wang     populateVectorBitCastLoweringPatterns(patterns);
713964c1dbSLei Zhang     populateVectorBroadcastLoweringPatterns(patterns);
72*d25a1f88SDiego Caballero     populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
733964c1dbSLei Zhang     populateVectorMaskOpLoweringPatterns(patterns);
743964c1dbSLei Zhang     populateVectorShapeCastLoweringPatterns(patterns);
75a1a68603SBenjamin Maxwell     populateVectorInterleaveLoweringPatterns(patterns);
76*d25a1f88SDiego Caballero     populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
77d1a9e9a7SMatthias Springer     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
78d1a9e9a7SMatthias Springer     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
798cd8b507SMatthias Springer     populateVectorMaskMaterializationPatterns(patterns,
808cd8b507SMatthias Springer                                               force32BitVectorIndices);
810693b9e9SMatthias Springer     populateVectorInsertExtractStridedSliceTransforms(patterns);
820693b9e9SMatthias Springer     populateVectorStepLoweringPatterns(patterns);
830693b9e9SMatthias Springer     populateVectorRankReducingFMAPattern(patterns);
8409dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
85c95acf05SAart Bik   }
86c95acf05SAart Bik 
87c95acf05SAart Bik   // Convert to the LLVM IR dialect.
884a2d4588SMarkus Böck   LowerToLLVMOptions options(&getContext());
894a2d4588SMarkus Böck   LLVMTypeConverter converter(&getContext(), options);
90dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
91d1a9e9a7SMatthias Springer   populateVectorTransferLoweringPatterns(patterns);
92c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
93a75a46dbSJavier Setoain   populateVectorToLLVMConversionPatterns(
947bc8ad51SJavier Setoain       converter, patterns, reassociateFPReductions, force32BitVectorIndices);
95c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
96c95acf05SAart Bik 
97c95acf05SAart Bik   // Architecture specific augmentations.
98c95acf05SAart Bik   LLVMConversionTarget target(getContext());
99abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect>();
100e2310704SJulian Gross   target.addLegalDialect<memref::MemRefDialect>();
101ba87f991SAlex Zinenko   target.addLegalOp<UnrealizedConversionCastOp>();
1023fa5ee67SAndrzej Warzynski 
103cd392c0eSNicolas Vasilache   if (armNeon) {
1046410ee0dSAlex Zinenko     // TODO: we may or may not want to include in-dialect lowering to
1056410ee0dSAlex Zinenko     // LLVM-compatible operations here. So far, all operations in the dialect
1066410ee0dSAlex Zinenko     // can be translated to LLVM IR so there is no conversion necessary.
1076410ee0dSAlex Zinenko     target.addLegalDialect<arm_neon::ArmNeonDialect>();
1087310501fSNicolas Vasilache   }
109cd392c0eSNicolas Vasilache   if (armSVE) {
110b739badaSJavier Setoain     configureArmSVELegalizeForExportTarget(target);
111b739badaSJavier Setoain     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
112aece4e27SJavier Setoain   }
113cd392c0eSNicolas Vasilache   if (amx) {
1146ad7b97eSAart Bik     configureAMXLegalizeForExportTarget(target);
1156ad7b97eSAart Bik     populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
1166ad7b97eSAart Bik   }
117cd392c0eSNicolas Vasilache   if (x86Vector) {
1188508a63bSEmilio Cota     configureX86VectorLegalizeForExportTarget(target);
1198508a63bSEmilio Cota     populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
120c95acf05SAart Bik   }
121c95acf05SAart Bik 
122c95acf05SAart Bik   if (failed(
123c95acf05SAart Bik           applyPartialConversion(getOperation(), target, std::move(patterns))))
124c95acf05SAart Bik     signalPassFailure();
125c95acf05SAart Bik }
126