xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision 1f5e8263b920f591c517a5dc562cccad39dd6ec7)
1c4ce8a40SNicolas Vasilache //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2c4ce8a40SNicolas Vasilache //
3c4ce8a40SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c4ce8a40SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
5c4ce8a40SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c4ce8a40SNicolas Vasilache //
7c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===//
8c4ce8a40SNicolas Vasilache 
9c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
107ec88f06SMatthias Springer 
117ec88f06SMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
127ec88f06SMatthias Springer #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
136ff12333SNicolas Vasilache #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
148b513407SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h"
164b48063bSMatthias Springer #include "mlir/Dialect/Transform/IR/TransformOps.h"
175a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
202bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
226ff12333SNicolas Vasilache #include "mlir/Dialect/X86Vector/Transforms.h"
23c4ce8a40SNicolas Vasilache #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24c4ce8a40SNicolas Vasilache 
25c4ce8a40SNicolas Vasilache using namespace mlir;
26c4ce8a40SNicolas Vasilache using namespace mlir::vector;
27c4ce8a40SNicolas Vasilache using namespace mlir::transform;
28c4ce8a40SNicolas Vasilache 
29c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===//
307ec88f06SMatthias Springer // Apply...ConversionPatternsOp
317ec88f06SMatthias Springer //===----------------------------------------------------------------------===//
327ec88f06SMatthias Springer 
337ec88f06SMatthias Springer void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
347ec88f06SMatthias Springer     TypeConverter &typeConverter, RewritePatternSet &patterns) {
357ec88f06SMatthias Springer   populateVectorToLLVMConversionPatterns(
367ec88f06SMatthias Springer       static_cast<LLVMTypeConverter &>(typeConverter), patterns,
377ec88f06SMatthias Springer       getReassociateFpReductions(), getForce_32bitVectorIndices());
387ec88f06SMatthias Springer }
397ec88f06SMatthias Springer 
407ec88f06SMatthias Springer LogicalResult
417ec88f06SMatthias Springer transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
427ec88f06SMatthias Springer     transform::TypeConverterBuilderOpInterface builder) {
437ec88f06SMatthias Springer   if (builder.getTypeConverterType() != "LLVMTypeConverter")
447ec88f06SMatthias Springer     return emitOpError("expected LLVMTypeConverter");
457ec88f06SMatthias Springer   return success();
467ec88f06SMatthias Springer }
477ec88f06SMatthias Springer 
487ec88f06SMatthias Springer //===----------------------------------------------------------------------===//
49cc7f5243SMatthias Springer // Apply...PatternsOp
506ff12333SNicolas Vasilache //===----------------------------------------------------------------------===//
516ff12333SNicolas Vasilache 
521d153ea9SMatthias Springer void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
531d153ea9SMatthias Springer     RewritePatternSet &patterns) {
541d153ea9SMatthias Springer   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
551d153ea9SMatthias Springer }
561d153ea9SMatthias Springer 
570ff10484SGroverkss void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
580ff10484SGroverkss     RewritePatternSet &patterns) {
590ff10484SGroverkss   vector::populateFoldArithExtensionPatterns(patterns);
600ff10484SGroverkss }
610ff10484SGroverkss 
629f0aa05bSHugo Trachino void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
639f0aa05bSHugo Trachino     RewritePatternSet &patterns) {
649f0aa05bSHugo Trachino   vector::populateElementwiseToVectorOpsPatterns(patterns);
659f0aa05bSHugo Trachino }
669f0aa05bSHugo Trachino 
676b5ce2cfSNicolas Vasilache void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
686b5ce2cfSNicolas Vasilache     RewritePatternSet &patterns) {
696b5ce2cfSNicolas Vasilache   vector::populateVectorReductionToContractPatterns(patterns);
7074a512dfSBenjamin Maxwell   vector::populateSinkVectorOpsPatterns(patterns);
716b5ce2cfSNicolas Vasilache }
726b5ce2cfSNicolas Vasilache 
73ccef726dSBenjamin Maxwell void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
74ccef726dSBenjamin Maxwell     RewritePatternSet &patterns) {
75ccef726dSBenjamin Maxwell   vector::populateVectorMaskOpLoweringPatterns(patterns);
76ccef726dSBenjamin Maxwell }
77ccef726dSBenjamin Maxwell 
788b513407SNicolas Vasilache void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
798b513407SNicolas Vasilache     RewritePatternSet &patterns) {
808b513407SNicolas Vasilache   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
816ff12333SNicolas Vasilache }
826ff12333SNicolas Vasilache 
838b513407SNicolas Vasilache void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
848b513407SNicolas Vasilache     RewritePatternSet &patterns) {
858b513407SNicolas Vasilache   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
866ff12333SNicolas Vasilache }
876ff12333SNicolas Vasilache 
88*1f5e8263SAndrzej Warzyński void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
89*1f5e8263SAndrzej Warzyński     RewritePatternSet &patterns) {
90*1f5e8263SAndrzej Warzyński   vector::populateDropUnitDimWithShapeCastPatterns(patterns);
91*1f5e8263SAndrzej Warzyński }
92*1f5e8263SAndrzej Warzyński 
930ea1271eSHan-Chung Wang void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
940ea1271eSHan-Chung Wang     RewritePatternSet &patterns) {
950ea1271eSHan-Chung Wang   vector::populateVectorBitCastLoweringPatterns(patterns);
960ea1271eSHan-Chung Wang }
970ea1271eSHan-Chung Wang 
98cc7f5243SMatthias Springer void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
998b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1008b513407SNicolas Vasilache   populateVectorBroadcastLoweringPatterns(patterns);
1018b513407SNicolas Vasilache }
1028b513407SNicolas Vasilache 
103cc7f5243SMatthias Springer void transform::ApplyLowerContractionPatternsOp::populatePatterns(
1048b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1056ff12333SNicolas Vasilache   vector::VectorTransformsOptions vectorTransformOptions;
1068b513407SNicolas Vasilache   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
1078b513407SNicolas Vasilache   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
1088b513407SNicolas Vasilache                                          /*benefit=*/1,
1098b513407SNicolas Vasilache                                          /*disableOuterProductLowering=*/true);
1108b513407SNicolas Vasilache }
1116ff12333SNicolas Vasilache 
112cc7f5243SMatthias Springer void transform::ApplyLowerMasksPatternsOp::populatePatterns(
113cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
1148b513407SNicolas Vasilache   populateVectorMaskOpLoweringPatterns(patterns);
115e4e0bf63SNicolas Vasilache }
116e4e0bf63SNicolas Vasilache 
117cc7f5243SMatthias Springer void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
118e4e0bf63SNicolas Vasilache     RewritePatternSet &patterns) {
1198b513407SNicolas Vasilache   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
1208b513407SNicolas Vasilache }
1216ff12333SNicolas Vasilache 
122cc7f5243SMatthias Springer void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
123cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
124d1dffa40SNicolas Vasilache   populateVectorMaskMaterializationPatterns(patterns,
125d1dffa40SNicolas Vasilache                                             /*force32BitVectorIndices=*/false);
126d1dffa40SNicolas Vasilache }
127d1dffa40SNicolas Vasilache 
128cc7f5243SMatthias Springer void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
1298b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1308b513407SNicolas Vasilache   vector::VectorTransformsOptions vectorTransformOptions;
1318b513407SNicolas Vasilache   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
1328b513407SNicolas Vasilache   vector::populateVectorMultiReductionLoweringPatterns(
1338b513407SNicolas Vasilache       patterns, vectorTransformOptions.vectorMultiReductionLowering);
1348b513407SNicolas Vasilache }
1358b513407SNicolas Vasilache 
136cc7f5243SMatthias Springer void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
1378b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1388b513407SNicolas Vasilache   populateVectorOuterProductLoweringPatterns(patterns);
1398b513407SNicolas Vasilache }
1408b513407SNicolas Vasilache 
141cc7f5243SMatthias Springer void transform::ApplyLowerGatherPatternsOp::populatePatterns(
142cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
143cc7f5243SMatthias Springer   vector::populateVectorGatherLoweringPatterns(patterns);
144cc7f5243SMatthias Springer }
1458b513407SNicolas Vasilache 
146cc7f5243SMatthias Springer void transform::ApplyLowerScanPatternsOp::populatePatterns(
147cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
148cc7f5243SMatthias Springer   vector::populateVectorScanLoweringPatterns(patterns);
149cc7f5243SMatthias Springer }
150cc7f5243SMatthias Springer 
151cc7f5243SMatthias Springer void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
1528b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1538b513407SNicolas Vasilache   vector::populateVectorShapeCastLoweringPatterns(patterns);
1548b513407SNicolas Vasilache }
1558b513407SNicolas Vasilache 
156cc7f5243SMatthias Springer void transform::ApplyLowerTransferPatternsOp::populatePatterns(
157cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
1588b513407SNicolas Vasilache   vector::populateVectorTransferLoweringPatterns(patterns,
1598b513407SNicolas Vasilache                                                  getMaxTransferRank());
1608b513407SNicolas Vasilache }
1618b513407SNicolas Vasilache 
162cc7f5243SMatthias Springer void transform::ApplyLowerTransposePatternsOp::populatePatterns(
1638b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1648b513407SNicolas Vasilache   vector::populateVectorTransposeLoweringPatterns(
1658b513407SNicolas Vasilache       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
1668b513407SNicolas Vasilache                     getLoweringStrategy()));
1678b513407SNicolas Vasilache   if (getAvx2LoweringStrategy()) {
1686ff12333SNicolas Vasilache     auto avx2LoweringOptions =
1696ff12333SNicolas Vasilache         x86vector::avx2::LoweringOptions().setTransposeOptions(
1706ff12333SNicolas Vasilache             x86vector::avx2::TransposeLoweringOptions()
1718b513407SNicolas Vasilache                 .lower4x8xf32(true)
1728b513407SNicolas Vasilache                 .lower8x8xf32(true));
1736ff12333SNicolas Vasilache     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
1746ff12333SNicolas Vasilache         patterns, avx2LoweringOptions, /*benefit=*/10);
1758b513407SNicolas Vasilache   }
176015cd84dSNicolas Vasilache }
177015cd84dSNicolas Vasilache 
178a1a68603SBenjamin Maxwell void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
179a1a68603SBenjamin Maxwell     RewritePatternSet &patterns) {
180a1a68603SBenjamin Maxwell   vector::populateVectorInterleaveLoweringPatterns(patterns);
181a1a68603SBenjamin Maxwell }
182a1a68603SBenjamin Maxwell 
183a1d43c14SBenoit Jacob void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
184a1d43c14SBenoit Jacob     RewritePatternSet &patterns) {
185a1d43c14SBenoit Jacob   vector::populateVectorInterleaveToShufflePatterns(patterns);
186a1d43c14SBenoit Jacob }
187a1d43c14SBenoit Jacob 
188bf7c490aSNicolas Vasilache void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
189bf7c490aSNicolas Vasilache     RewritePatternSet &patterns) {
190bf7c490aSNicolas Vasilache   populateVectorNarrowTypeRewritePatterns(patterns);
1918ba018d7SDiego Caballero   populateVectorTransposeNarrowTypeRewritePatterns(patterns);
192bf7c490aSNicolas Vasilache }
193bf7c490aSNicolas Vasilache 
194cc7f5243SMatthias Springer void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
1958b513407SNicolas Vasilache     RewritePatternSet &patterns) {
1968b513407SNicolas Vasilache   vector::VectorTransformsOptions vectorTransformOptions;
1978b513407SNicolas Vasilache   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
1988b513407SNicolas Vasilache   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
1998b513407SNicolas Vasilache }
2008b513407SNicolas Vasilache 
201cc7f5243SMatthias Springer void transform::ApplyTransferToScfPatternsOp::populatePatterns(
202cc7f5243SMatthias Springer     RewritePatternSet &patterns) {
2038b513407SNicolas Vasilache   VectorTransferToSCFOptions vectorTransferToSCFOptions =
2048b513407SNicolas Vasilache       VectorTransferToSCFOptions()
2058b513407SNicolas Vasilache           .enableFullUnroll(getFullUnroll())
2068b513407SNicolas Vasilache           .setTargetRank(getMaxTransferRank());
2078b513407SNicolas Vasilache   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
2086ff12333SNicolas Vasilache }
2096ff12333SNicolas Vasilache 
2106ff12333SNicolas Vasilache //===----------------------------------------------------------------------===//
211c4ce8a40SNicolas Vasilache // Transform op registration
212c4ce8a40SNicolas Vasilache //===----------------------------------------------------------------------===//
213c4ce8a40SNicolas Vasilache 
214c4ce8a40SNicolas Vasilache namespace {
215c4ce8a40SNicolas Vasilache /// Registers new ops and declares PDL as dependent dialect since the additional
216c4ce8a40SNicolas Vasilache /// ops are using PDL types for operands and results.
217c4ce8a40SNicolas Vasilache class VectorTransformDialectExtension
218c4ce8a40SNicolas Vasilache     : public transform::TransformDialectExtension<
219c4ce8a40SNicolas Vasilache           VectorTransformDialectExtension> {
220c4ce8a40SNicolas Vasilache public:
22184cc1865SNikhil Kalra   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
22284cc1865SNikhil Kalra 
223c4ce8a40SNicolas Vasilache   VectorTransformDialectExtension() {
2248b513407SNicolas Vasilache     declareGeneratedDialect<vector::VectorDialect>();
2258b513407SNicolas Vasilache     declareGeneratedDialect<LLVM::LLVMDialect>();
226c4ce8a40SNicolas Vasilache     registerTransformOps<
227c4ce8a40SNicolas Vasilache #define GET_OP_LIST
228c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
229c4ce8a40SNicolas Vasilache         >();
230c4ce8a40SNicolas Vasilache   }
231c4ce8a40SNicolas Vasilache };
232c4ce8a40SNicolas Vasilache } // namespace
233c4ce8a40SNicolas Vasilache 
234c4ce8a40SNicolas Vasilache #define GET_OP_CLASSES
235c4ce8a40SNicolas Vasilache #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
236c4ce8a40SNicolas Vasilache 
237c4ce8a40SNicolas Vasilache void mlir::vector::registerTransformDialectExtension(
238c4ce8a40SNicolas Vasilache     DialectRegistry &registry) {
239c4ce8a40SNicolas Vasilache   registry.addExtensions<VectorTransformDialectExtension>();
240c4ce8a40SNicolas Vasilache }
241