xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision 0ff1048409da484de1fc912ea1c4552256cb6e6c)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Transform/IR/TransformOps.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
19 #include "mlir/Dialect/X86Vector/Transforms.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // Apply...PatternsOp
28 //===----------------------------------------------------------------------===//
29 
30 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
31     RewritePatternSet &patterns) {
32   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
33 }
34 
35 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
36     RewritePatternSet &patterns) {
37   vector::populateFoldArithExtensionPatterns(patterns);
38 }
39 
40 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
41     RewritePatternSet &patterns) {
42   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
43 }
44 
45 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
46     RewritePatternSet &patterns) {
47   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
48 }
49 
50 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
51     RewritePatternSet &patterns) {
52   populateVectorBroadcastLoweringPatterns(patterns);
53 }
54 
55 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
56     RewritePatternSet &patterns) {
57   vector::VectorTransformsOptions vectorTransformOptions;
58   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
59   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
60                                          /*benefit=*/1,
61                                          /*disableOuterProductLowering=*/true);
62 }
63 
64 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
65     RewritePatternSet &patterns) {
66   populateVectorMaskOpLoweringPatterns(patterns);
67 }
68 
69 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
70     RewritePatternSet &patterns) {
71   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
72 }
73 
74 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
75     RewritePatternSet &patterns) {
76   populateVectorMaskMaterializationPatterns(patterns,
77                                             /*force32BitVectorIndices=*/false);
78 }
79 
80 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
81     RewritePatternSet &patterns) {
82   vector::VectorTransformsOptions vectorTransformOptions;
83   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
84   vector::populateVectorMultiReductionLoweringPatterns(
85       patterns, vectorTransformOptions.vectorMultiReductionLowering);
86 }
87 
88 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
89     RewritePatternSet &patterns) {
90   populateVectorOuterProductLoweringPatterns(patterns);
91 }
92 
93 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
94     RewritePatternSet &patterns) {
95   vector::populateVectorGatherLoweringPatterns(patterns);
96 }
97 
98 void transform::ApplyLowerScanPatternsOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   vector::populateVectorScanLoweringPatterns(patterns);
101 }
102 
103 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
104     RewritePatternSet &patterns) {
105   vector::populateVectorShapeCastLoweringPatterns(patterns);
106 }
107 
108 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
109     RewritePatternSet &patterns) {
110   vector::populateVectorTransferLoweringPatterns(patterns,
111                                                  getMaxTransferRank());
112 }
113 
114 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
115     RewritePatternSet &patterns) {
116   vector::populateVectorTransposeLoweringPatterns(
117       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
118                     getLoweringStrategy()));
119   if (getAvx2LoweringStrategy()) {
120     auto avx2LoweringOptions =
121         x86vector::avx2::LoweringOptions().setTransposeOptions(
122             x86vector::avx2::TransposeLoweringOptions()
123                 .lower4x8xf32(true)
124                 .lower8x8xf32(true));
125     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
126         patterns, avx2LoweringOptions, /*benefit=*/10);
127   }
128 }
129 
130 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
131     RewritePatternSet &patterns) {
132   vector::VectorTransformsOptions vectorTransformOptions;
133   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
134   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
135 }
136 
137 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
138     RewritePatternSet &patterns) {
139   VectorTransferToSCFOptions vectorTransferToSCFOptions =
140       VectorTransferToSCFOptions()
141           .enableFullUnroll(getFullUnroll())
142           .setTargetRank(getMaxTransferRank());
143   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Transform op registration
148 //===----------------------------------------------------------------------===//
149 
150 namespace {
151 /// Registers new ops and declares PDL as dependent dialect since the additional
152 /// ops are using PDL types for operands and results.
153 class VectorTransformDialectExtension
154     : public transform::TransformDialectExtension<
155           VectorTransformDialectExtension> {
156 public:
157   VectorTransformDialectExtension() {
158     declareGeneratedDialect<vector::VectorDialect>();
159     declareGeneratedDialect<LLVM::LLVMDialect>();
160     registerTransformOps<
161 #define GET_OP_LIST
162 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
163         >();
164   }
165 };
166 } // namespace
167 
168 #define GET_OP_CLASSES
169 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
170 
171 void mlir::vector::registerTransformDialectExtension(
172     DialectRegistry &registry) {
173   registry.addExtensions<VectorTransformDialectExtension>();
174 }
175