1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/AMX/Transforms.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" 17 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 18 #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 25 #include "mlir/Dialect/X86Vector/Transforms.h" 26 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 27 #include "mlir/Pass/Pass.h" 28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 29 30 namespace mlir { 31 #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS 32 #include "mlir/Conversion/Passes.h.inc" 33 } // namespace mlir 34 35 using namespace mlir; 36 using namespace mlir::vector; 37 38 namespace { 39 struct ConvertVectorToLLVMPass 40 : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> { 41 42 using Base::Base; 43 44 // Override explicitly to allow conditional dialect dependence. 45 void getDependentDialects(DialectRegistry ®istry) const override { 46 registry.insert<LLVM::LLVMDialect>(); 47 registry.insert<arith::ArithDialect>(); 48 registry.insert<memref::MemRefDialect>(); 49 registry.insert<tensor::TensorDialect>(); 50 if (armNeon) 51 registry.insert<arm_neon::ArmNeonDialect>(); 52 if (armSVE) 53 registry.insert<arm_sve::ArmSVEDialect>(); 54 if (amx) 55 registry.insert<amx::AMXDialect>(); 56 if (x86Vector) 57 registry.insert<x86vector::X86VectorDialect>(); 58 } 59 void runOnOperation() override; 60 }; 61 } // namespace 62 63 void ConvertVectorToLLVMPass::runOnOperation() { 64 // Perform progressive lowering of operations on slices and all contraction 65 // operations. Also materializes masks, lowers vector.step, rank-reduces FMA, 66 // applies folding and DCE. 67 { 68 RewritePatternSet patterns(&getContext()); 69 populateVectorToVectorCanonicalizationPatterns(patterns); 70 populateVectorBitCastLoweringPatterns(patterns); 71 populateVectorBroadcastLoweringPatterns(patterns); 72 populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions); 73 populateVectorMaskOpLoweringPatterns(patterns); 74 populateVectorShapeCastLoweringPatterns(patterns); 75 populateVectorInterleaveLoweringPatterns(patterns); 76 populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions); 77 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 78 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); 79 populateVectorMaskMaterializationPatterns(patterns, 80 force32BitVectorIndices); 81 populateVectorInsertExtractStridedSliceTransforms(patterns); 82 populateVectorStepLoweringPatterns(patterns); 83 populateVectorRankReducingFMAPattern(patterns); 84 (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 85 } 86 87 // Convert to the LLVM IR dialect. 88 LowerToLLVMOptions options(&getContext()); 89 LLVMTypeConverter converter(&getContext(), options); 90 RewritePatternSet patterns(&getContext()); 91 populateVectorTransferLoweringPatterns(patterns); 92 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 93 populateVectorToLLVMConversionPatterns( 94 converter, patterns, reassociateFPReductions, force32BitVectorIndices); 95 populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 96 97 // Architecture specific augmentations. 98 LLVMConversionTarget target(getContext()); 99 target.addLegalDialect<arith::ArithDialect>(); 100 target.addLegalDialect<memref::MemRefDialect>(); 101 target.addLegalOp<UnrealizedConversionCastOp>(); 102 103 if (armNeon) { 104 // TODO: we may or may not want to include in-dialect lowering to 105 // LLVM-compatible operations here. So far, all operations in the dialect 106 // can be translated to LLVM IR so there is no conversion necessary. 107 target.addLegalDialect<arm_neon::ArmNeonDialect>(); 108 } 109 if (armSVE) { 110 configureArmSVELegalizeForExportTarget(target); 111 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); 112 } 113 if (amx) { 114 configureAMXLegalizeForExportTarget(target); 115 populateAMXLegalizeForLLVMExportPatterns(converter, patterns); 116 } 117 if (x86Vector) { 118 configureX86VectorLegalizeForExportTarget(target); 119 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); 120 } 121 122 if (failed( 123 applyPartialConversion(getOperation(), target, std::move(patterns)))) 124 signalPassFailure(); 125 } 126