1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/AMX/Transforms.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 17 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 18 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 23 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 24 #include "mlir/Dialect/X86Vector/Transforms.h" 25 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 29 namespace mlir { 30 #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 using namespace mlir; 35 using namespace mlir::vector; 36 37 namespace { 38 struct LowerVectorToLLVMPass 39 : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> { 40 41 using Base::Base; 42 43 // Override explicitly to allow conditional dialect dependence. 44 void getDependentDialects(DialectRegistry ®istry) const override { 45 registry.insert<LLVM::LLVMDialect>(); 46 registry.insert<arith::ArithDialect>(); 47 registry.insert<memref::MemRefDialect>(); 48 if (armNeon) 49 registry.insert<arm_neon::ArmNeonDialect>(); 50 if (armSVE) 51 registry.insert<arm_sve::ArmSVEDialect>(); 52 if (amx) 53 registry.insert<amx::AMXDialect>(); 54 if (x86Vector) 55 registry.insert<x86vector::X86VectorDialect>(); 56 } 57 void runOnOperation() override; 58 }; 59 } // namespace 60 61 void LowerVectorToLLVMPass::runOnOperation() { 62 // Perform progressive lowering of operations on slices and 63 // all contraction operations. Also applies folding and DCE. 64 { 65 RewritePatternSet patterns(&getContext()); 66 populateVectorToVectorCanonicalizationPatterns(patterns); 67 populateVectorBroadcastLoweringPatterns(patterns); 68 populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions()); 69 populateVectorMaskOpLoweringPatterns(patterns); 70 populateVectorShapeCastLoweringPatterns(patterns); 71 populateVectorTransposeLoweringPatterns(patterns, 72 VectorTransformsOptions()); 73 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 74 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 75 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 76 } 77 78 // Convert to the LLVM IR dialect. 79 LowerToLLVMOptions options(&getContext()); 80 LLVMTypeConverter converter(&getContext(), options); 81 RewritePatternSet patterns(&getContext()); 82 populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices); 83 populateVectorTransferLoweringPatterns(patterns); 84 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 85 populateVectorToLLVMConversionPatterns( 86 converter, patterns, reassociateFPReductions, force32BitVectorIndices); 87 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 88 89 // Architecture specific augmentations. 90 LLVMConversionTarget target(getContext()); 91 target.addLegalDialect<arith::ArithDialect>(); 92 target.addLegalDialect<memref::MemRefDialect>(); 93 target.addLegalOp<UnrealizedConversionCastOp>(); 94 95 if (armNeon) { 96 // TODO: we may or may not want to include in-dialect lowering to 97 // LLVM-compatible operations here. So far, all operations in the dialect 98 // can be translated to LLVM IR so there is no conversion necessary. 99 target.addLegalDialect<arm_neon::ArmNeonDialect>(); 100 } 101 if (armSVE) { 102 configureArmSVELegalizeForExportTarget(target); 103 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); 104 } 105 if (amx) { 106 configureAMXLegalizeForExportTarget(target); 107 populateAMXLegalizeForLLVMExportPatterns(converter, patterns); 108 } 109 if (x86Vector) { 110 configureX86VectorLegalizeForExportTarget(target); 111 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); 112 } 113 114 if (failed( 115 applyPartialConversion(getOperation(), target, std::move(patterns)))) 116 signalPassFailure(); 117 } 118