1c4ce8a40SNicolas Vasilache //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===// 2c4ce8a40SNicolas Vasilache // 3c4ce8a40SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c4ce8a40SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 5c4ce8a40SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c4ce8a40SNicolas Vasilache // 7c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===// 8c4ce8a40SNicolas Vasilache 9c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" 107ec88f06SMatthias Springer 117ec88f06SMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 127ec88f06SMatthias Springer #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 136ff12333SNicolas Vasilache #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 148b513407SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h" 164b48063bSMatthias Springer #include "mlir/Dialect/Transform/IR/TransformOps.h" 175a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h" 192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 202bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 21c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 226ff12333SNicolas Vasilache #include "mlir/Dialect/X86Vector/Transforms.h" 23c4ce8a40SNicolas Vasilache #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24c4ce8a40SNicolas Vasilache 25c4ce8a40SNicolas Vasilache using namespace mlir; 26c4ce8a40SNicolas Vasilache using namespace mlir::vector; 27c4ce8a40SNicolas Vasilache using namespace mlir::transform; 28c4ce8a40SNicolas Vasilache 29c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===// 307ec88f06SMatthias Springer // Apply...ConversionPatternsOp 317ec88f06SMatthias Springer //===----------------------------------------------------------------------===// 327ec88f06SMatthias Springer 337ec88f06SMatthias Springer void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( 347ec88f06SMatthias Springer TypeConverter &typeConverter, RewritePatternSet &patterns) { 357ec88f06SMatthias Springer populateVectorToLLVMConversionPatterns( 367ec88f06SMatthias Springer static_cast<LLVMTypeConverter &>(typeConverter), patterns, 377ec88f06SMatthias Springer getReassociateFpReductions(), getForce_32bitVectorIndices()); 387ec88f06SMatthias Springer } 397ec88f06SMatthias Springer 407ec88f06SMatthias Springer LogicalResult 417ec88f06SMatthias Springer transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter( 427ec88f06SMatthias Springer transform::TypeConverterBuilderOpInterface builder) { 437ec88f06SMatthias Springer if (builder.getTypeConverterType() != "LLVMTypeConverter") 447ec88f06SMatthias Springer return emitOpError("expected LLVMTypeConverter"); 457ec88f06SMatthias Springer return success(); 467ec88f06SMatthias Springer } 477ec88f06SMatthias Springer 487ec88f06SMatthias Springer //===----------------------------------------------------------------------===// 49cc7f5243SMatthias Springer // Apply...PatternsOp 506ff12333SNicolas Vasilache //===----------------------------------------------------------------------===// 516ff12333SNicolas Vasilache 521d153ea9SMatthias Springer void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns( 531d153ea9SMatthias Springer RewritePatternSet &patterns) { 541d153ea9SMatthias Springer vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); 551d153ea9SMatthias Springer } 561d153ea9SMatthias Springer 570ff10484SGroverkss void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns( 580ff10484SGroverkss RewritePatternSet &patterns) { 590ff10484SGroverkss vector::populateFoldArithExtensionPatterns(patterns); 600ff10484SGroverkss } 610ff10484SGroverkss 629f0aa05bSHugo Trachino void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns( 639f0aa05bSHugo Trachino RewritePatternSet &patterns) { 649f0aa05bSHugo Trachino vector::populateElementwiseToVectorOpsPatterns(patterns); 659f0aa05bSHugo Trachino } 669f0aa05bSHugo Trachino 676b5ce2cfSNicolas Vasilache void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns( 686b5ce2cfSNicolas Vasilache RewritePatternSet &patterns) { 696b5ce2cfSNicolas Vasilache vector::populateVectorReductionToContractPatterns(patterns); 7074a512dfSBenjamin Maxwell vector::populateSinkVectorOpsPatterns(patterns); 716b5ce2cfSNicolas Vasilache } 726b5ce2cfSNicolas Vasilache 73ccef726dSBenjamin Maxwell void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns( 74ccef726dSBenjamin Maxwell RewritePatternSet &patterns) { 75ccef726dSBenjamin Maxwell vector::populateVectorMaskOpLoweringPatterns(patterns); 76ccef726dSBenjamin Maxwell } 77ccef726dSBenjamin Maxwell 788b513407SNicolas Vasilache void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( 798b513407SNicolas Vasilache RewritePatternSet &patterns) { 808b513407SNicolas Vasilache vector::populateVectorTransferDropUnitDimsPatterns(patterns); 816ff12333SNicolas Vasilache } 826ff12333SNicolas Vasilache 838b513407SNicolas Vasilache void transform::ApplyTransferPermutationPatternsOp::populatePatterns( 848b513407SNicolas Vasilache RewritePatternSet &patterns) { 858b513407SNicolas Vasilache vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); 866ff12333SNicolas Vasilache } 876ff12333SNicolas Vasilache 88*1f5e8263SAndrzej Warzyński void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns( 89*1f5e8263SAndrzej Warzyński RewritePatternSet &patterns) { 90*1f5e8263SAndrzej Warzyński vector::populateDropUnitDimWithShapeCastPatterns(patterns); 91*1f5e8263SAndrzej Warzyński } 92*1f5e8263SAndrzej Warzyński 930ea1271eSHan-Chung Wang void transform::ApplyLowerBitCastPatternsOp::populatePatterns( 940ea1271eSHan-Chung Wang RewritePatternSet &patterns) { 950ea1271eSHan-Chung Wang vector::populateVectorBitCastLoweringPatterns(patterns); 960ea1271eSHan-Chung Wang } 970ea1271eSHan-Chung Wang 98cc7f5243SMatthias Springer void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( 998b513407SNicolas Vasilache RewritePatternSet &patterns) { 1008b513407SNicolas Vasilache populateVectorBroadcastLoweringPatterns(patterns); 1018b513407SNicolas Vasilache } 1028b513407SNicolas Vasilache 103cc7f5243SMatthias Springer void transform::ApplyLowerContractionPatternsOp::populatePatterns( 1048b513407SNicolas Vasilache RewritePatternSet &patterns) { 1056ff12333SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions; 1068b513407SNicolas Vasilache vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); 1078b513407SNicolas Vasilache populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, 1088b513407SNicolas Vasilache /*benefit=*/1, 1098b513407SNicolas Vasilache /*disableOuterProductLowering=*/true); 1108b513407SNicolas Vasilache } 1116ff12333SNicolas Vasilache 112cc7f5243SMatthias Springer void transform::ApplyLowerMasksPatternsOp::populatePatterns( 113cc7f5243SMatthias Springer RewritePatternSet &patterns) { 1148b513407SNicolas Vasilache populateVectorMaskOpLoweringPatterns(patterns); 115e4e0bf63SNicolas Vasilache } 116e4e0bf63SNicolas Vasilache 117cc7f5243SMatthias Springer void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( 118e4e0bf63SNicolas Vasilache RewritePatternSet &patterns) { 1198b513407SNicolas Vasilache populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); 1208b513407SNicolas Vasilache } 1216ff12333SNicolas Vasilache 122cc7f5243SMatthias Springer void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( 123cc7f5243SMatthias Springer RewritePatternSet &patterns) { 124d1dffa40SNicolas Vasilache populateVectorMaskMaterializationPatterns(patterns, 125d1dffa40SNicolas Vasilache /*force32BitVectorIndices=*/false); 126d1dffa40SNicolas Vasilache } 127d1dffa40SNicolas Vasilache 128cc7f5243SMatthias Springer void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( 1298b513407SNicolas Vasilache RewritePatternSet &patterns) { 1308b513407SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions; 1318b513407SNicolas Vasilache vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); 1328b513407SNicolas Vasilache vector::populateVectorMultiReductionLoweringPatterns( 1338b513407SNicolas Vasilache patterns, vectorTransformOptions.vectorMultiReductionLowering); 1348b513407SNicolas Vasilache } 1358b513407SNicolas Vasilache 136cc7f5243SMatthias Springer void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( 1378b513407SNicolas Vasilache RewritePatternSet &patterns) { 1388b513407SNicolas Vasilache populateVectorOuterProductLoweringPatterns(patterns); 1398b513407SNicolas Vasilache } 1408b513407SNicolas Vasilache 141cc7f5243SMatthias Springer void transform::ApplyLowerGatherPatternsOp::populatePatterns( 142cc7f5243SMatthias Springer RewritePatternSet &patterns) { 143cc7f5243SMatthias Springer vector::populateVectorGatherLoweringPatterns(patterns); 144cc7f5243SMatthias Springer } 1458b513407SNicolas Vasilache 146cc7f5243SMatthias Springer void transform::ApplyLowerScanPatternsOp::populatePatterns( 147cc7f5243SMatthias Springer RewritePatternSet &patterns) { 148cc7f5243SMatthias Springer vector::populateVectorScanLoweringPatterns(patterns); 149cc7f5243SMatthias Springer } 150cc7f5243SMatthias Springer 151cc7f5243SMatthias Springer void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( 1528b513407SNicolas Vasilache RewritePatternSet &patterns) { 1538b513407SNicolas Vasilache vector::populateVectorShapeCastLoweringPatterns(patterns); 1548b513407SNicolas Vasilache } 1558b513407SNicolas Vasilache 156cc7f5243SMatthias Springer void transform::ApplyLowerTransferPatternsOp::populatePatterns( 157cc7f5243SMatthias Springer RewritePatternSet &patterns) { 1588b513407SNicolas Vasilache vector::populateVectorTransferLoweringPatterns(patterns, 1598b513407SNicolas Vasilache getMaxTransferRank()); 1608b513407SNicolas Vasilache } 1618b513407SNicolas Vasilache 162cc7f5243SMatthias Springer void transform::ApplyLowerTransposePatternsOp::populatePatterns( 1638b513407SNicolas Vasilache RewritePatternSet &patterns) { 1648b513407SNicolas Vasilache vector::populateVectorTransposeLoweringPatterns( 1658b513407SNicolas Vasilache patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( 1668b513407SNicolas Vasilache getLoweringStrategy())); 1678b513407SNicolas Vasilache if (getAvx2LoweringStrategy()) { 1686ff12333SNicolas Vasilache auto avx2LoweringOptions = 1696ff12333SNicolas Vasilache x86vector::avx2::LoweringOptions().setTransposeOptions( 1706ff12333SNicolas Vasilache x86vector::avx2::TransposeLoweringOptions() 1718b513407SNicolas Vasilache .lower4x8xf32(true) 1728b513407SNicolas Vasilache .lower8x8xf32(true)); 1736ff12333SNicolas Vasilache x86vector::avx2::populateSpecializedTransposeLoweringPatterns( 1746ff12333SNicolas Vasilache patterns, avx2LoweringOptions, /*benefit=*/10); 1758b513407SNicolas Vasilache } 176015cd84dSNicolas Vasilache } 177015cd84dSNicolas Vasilache 178a1a68603SBenjamin Maxwell void transform::ApplyLowerInterleavePatternsOp::populatePatterns( 179a1a68603SBenjamin Maxwell RewritePatternSet &patterns) { 180a1a68603SBenjamin Maxwell vector::populateVectorInterleaveLoweringPatterns(patterns); 181a1a68603SBenjamin Maxwell } 182a1a68603SBenjamin Maxwell 183a1d43c14SBenoit Jacob void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns( 184a1d43c14SBenoit Jacob RewritePatternSet &patterns) { 185a1d43c14SBenoit Jacob vector::populateVectorInterleaveToShufflePatterns(patterns); 186a1d43c14SBenoit Jacob } 187a1d43c14SBenoit Jacob 188bf7c490aSNicolas Vasilache void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( 189bf7c490aSNicolas Vasilache RewritePatternSet &patterns) { 190bf7c490aSNicolas Vasilache populateVectorNarrowTypeRewritePatterns(patterns); 1918ba018d7SDiego Caballero populateVectorTransposeNarrowTypeRewritePatterns(patterns); 192bf7c490aSNicolas Vasilache } 193bf7c490aSNicolas Vasilache 194cc7f5243SMatthias Springer void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( 1958b513407SNicolas Vasilache RewritePatternSet &patterns) { 1968b513407SNicolas Vasilache vector::VectorTransformsOptions vectorTransformOptions; 1978b513407SNicolas Vasilache vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); 1988b513407SNicolas Vasilache populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); 1998b513407SNicolas Vasilache } 2008b513407SNicolas Vasilache 201cc7f5243SMatthias Springer void transform::ApplyTransferToScfPatternsOp::populatePatterns( 202cc7f5243SMatthias Springer RewritePatternSet &patterns) { 2038b513407SNicolas Vasilache VectorTransferToSCFOptions vectorTransferToSCFOptions = 2048b513407SNicolas Vasilache VectorTransferToSCFOptions() 2058b513407SNicolas Vasilache .enableFullUnroll(getFullUnroll()) 2068b513407SNicolas Vasilache .setTargetRank(getMaxTransferRank()); 2078b513407SNicolas Vasilache populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); 2086ff12333SNicolas Vasilache } 2096ff12333SNicolas Vasilache 2106ff12333SNicolas Vasilache //===----------------------------------------------------------------------===// 211c4ce8a40SNicolas Vasilache // Transform op registration 212c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===// 213c4ce8a40SNicolas Vasilache 214c4ce8a40SNicolas Vasilache namespace { 215c4ce8a40SNicolas Vasilache /// Registers new ops and declares PDL as dependent dialect since the additional 216c4ce8a40SNicolas Vasilache /// ops are using PDL types for operands and results. 217c4ce8a40SNicolas Vasilache class VectorTransformDialectExtension 218c4ce8a40SNicolas Vasilache : public transform::TransformDialectExtension< 219c4ce8a40SNicolas Vasilache VectorTransformDialectExtension> { 220c4ce8a40SNicolas Vasilache public: 22184cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension) 22284cc1865SNikhil Kalra 223c4ce8a40SNicolas Vasilache VectorTransformDialectExtension() { 2248b513407SNicolas Vasilache declareGeneratedDialect<vector::VectorDialect>(); 2258b513407SNicolas Vasilache declareGeneratedDialect<LLVM::LLVMDialect>(); 226c4ce8a40SNicolas Vasilache registerTransformOps< 227c4ce8a40SNicolas Vasilache #define GET_OP_LIST 228c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 229c4ce8a40SNicolas Vasilache >(); 230c4ce8a40SNicolas Vasilache } 231c4ce8a40SNicolas Vasilache }; 232c4ce8a40SNicolas Vasilache } // namespace 233c4ce8a40SNicolas Vasilache 234c4ce8a40SNicolas Vasilache #define GET_OP_CLASSES 235c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" 236c4ce8a40SNicolas Vasilache 237c4ce8a40SNicolas Vasilache void mlir::vector::registerTransformDialectExtension( 238c4ce8a40SNicolas Vasilache DialectRegistry ®istry) { 239c4ce8a40SNicolas Vasilache registry.addExtensions<VectorTransformDialectExtension>(); 240c4ce8a40SNicolas Vasilache } 241