xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision 4b48063b521bfcc9835269c729de79459d93229e)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Transform/IR/TransformOps.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
19 #include "mlir/Dialect/X86Vector/Transforms.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // ApplyRankReducingSubviewPatternsOp
28 //===----------------------------------------------------------------------===//
29 
30 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
31     RewritePatternSet &patterns) {
32   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // ApplyTransferPermutationPatternsOp
37 //===----------------------------------------------------------------------===//
38 
39 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
40     RewritePatternSet &patterns) {
41   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
42 }
43 
44 //===----------------------------------------------------------------------===//
45 // LowerBroadcastOp
46 //===----------------------------------------------------------------------===//
47 
48 void transform::LowerBroadcastOp::populatePatterns(
49     RewritePatternSet &patterns) {
50   populateVectorBroadcastLoweringPatterns(patterns);
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // LowerContractionOp
55 //===----------------------------------------------------------------------===//
56 
57 void transform::LowerContractionOp::populatePatterns(
58     RewritePatternSet &patterns) {
59   vector::VectorTransformsOptions vectorTransformOptions;
60   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
61   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
62                                          /*benefit=*/1,
63                                          /*disableOuterProductLowering=*/true);
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // LowerMasksOp
68 //===----------------------------------------------------------------------===//
69 
70 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
71   populateVectorMaskOpLoweringPatterns(patterns);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // LowerMaskedTransfersOp
76 //===----------------------------------------------------------------------===//
77 
78 void transform::LowerMaskedTransfersOp::populatePatterns(
79     RewritePatternSet &patterns) {
80   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // MaterializeMasksOp
85 //===----------------------------------------------------------------------===//
86 
87 void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) {
88   populateVectorMaskMaterializationPatterns(patterns,
89                                             /*force32BitVectorIndices=*/false);
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // LowerMultiReductionOp
94 //===----------------------------------------------------------------------===//
95 
96 void transform::LowerMultiReductionOp::populatePatterns(
97     RewritePatternSet &patterns) {
98   vector::VectorTransformsOptions vectorTransformOptions;
99   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
100   vector::populateVectorMultiReductionLoweringPatterns(
101       patterns, vectorTransformOptions.vectorMultiReductionLowering);
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // LowerOuterProductOp
106 //===----------------------------------------------------------------------===//
107 
108 void transform::LowerOuterProductOp::populatePatterns(
109     RewritePatternSet &patterns) {
110   populateVectorOuterProductLoweringPatterns(patterns);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // LowerShapeCastOp
115 //===----------------------------------------------------------------------===//
116 
117 void transform::LowerShapeCastOp::populatePatterns(
118     RewritePatternSet &patterns) {
119   vector::populateVectorShapeCastLoweringPatterns(patterns);
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // LowerTransferOp
124 //===----------------------------------------------------------------------===//
125 
126 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) {
127   vector::populateVectorTransferLoweringPatterns(patterns,
128                                                  getMaxTransferRank());
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // LowerTransposeOp
133 //===----------------------------------------------------------------------===//
134 
135 void transform::LowerTransposeOp::populatePatterns(
136     RewritePatternSet &patterns) {
137   vector::populateVectorTransposeLoweringPatterns(
138       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
139                     getLoweringStrategy()));
140   if (getAvx2LoweringStrategy()) {
141     auto avx2LoweringOptions =
142         x86vector::avx2::LoweringOptions().setTransposeOptions(
143             x86vector::avx2::TransposeLoweringOptions()
144                 .lower4x8xf32(true)
145                 .lower8x8xf32(true));
146     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
147         patterns, avx2LoweringOptions, /*benefit=*/10);
148   }
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // SplitTransferFullPartialOp
153 //===----------------------------------------------------------------------===//
154 
155 void transform::SplitTransferFullPartialOp::populatePatterns(
156     RewritePatternSet &patterns) {
157   vector::VectorTransformsOptions vectorTransformOptions;
158   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
159   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // TransferToScfOp
164 //===----------------------------------------------------------------------===//
165 
166 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) {
167   VectorTransferToSCFOptions vectorTransferToSCFOptions =
168       VectorTransferToSCFOptions()
169           .enableFullUnroll(getFullUnroll())
170           .setTargetRank(getMaxTransferRank());
171   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Transform op registration
176 //===----------------------------------------------------------------------===//
177 
178 namespace {
179 /// Registers new ops and declares PDL as dependent dialect since the additional
180 /// ops are using PDL types for operands and results.
181 class VectorTransformDialectExtension
182     : public transform::TransformDialectExtension<
183           VectorTransformDialectExtension> {
184 public:
185   VectorTransformDialectExtension() {
186     declareGeneratedDialect<vector::VectorDialect>();
187     declareGeneratedDialect<LLVM::LLVMDialect>();
188     registerTransformOps<
189 #define GET_OP_LIST
190 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
191         >();
192 
193     addDialectDataInitializer<transform::PatternRegistry>(
194         [&](transform::PatternRegistry &registry) {
195           registry.registerPatterns("vector.outer_product_lowering",
196                                     populateVectorOuterProductLoweringPatterns);
197           registry.registerPatterns("vector.broadcast_lowering",
198                                     populateVectorBroadcastLoweringPatterns);
199           registry.registerPatterns("vector.mask_op_lowering",
200                                     populateVectorMaskOpLoweringPatterns);
201           registry.registerPatterns("vector.shape_cast_lowering",
202                                     populateVectorShapeCastLoweringPatterns);
203           registry.registerPatterns(
204               "vector.transfer_lowering",
205               [&](RewritePatternSet &set, PatternBenefit benefit) {
206                 return populateVectorTransferLoweringPatterns(
207                     set, /*maxTransferRank=*/std::nullopt, benefit);
208               });
209           registry.registerPatterns(
210               "vector.transfer_permutation_map_lowering",
211               populateVectorTransferPermutationMapLoweringPatterns);
212           registry.registerPatterns("vector.scan_lowering",
213                                     populateVectorScanLoweringPatterns);
214           registry.registerPatterns("vector.vector_gather_lowering",
215                                     populateVectorGatherLoweringPatterns);
216           registry.registerPatterns(
217               "vector.mask_lowering_for_side_effecting_ops",
218               populateVectorMaskLoweringPatternsForSideEffectingOps);
219         });
220   }
221 };
222 } // namespace
223 
224 #define GET_OP_CLASSES
225 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
226 
227 void mlir::vector::registerTransformDialectExtension(
228     DialectRegistry &registry) {
229   registry.addExtensions<VectorTransformDialectExtension>();
230 }
231