xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision e4e0bf63d0b3615b9a2481be6769a3c876763ec6)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Dialect/X86Vector/Transforms.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 using namespace mlir::transform;
24 
25 //===----------------------------------------------------------------------===//
26 // ApplyRankReducingSubviewPatternsOp
27 //===----------------------------------------------------------------------===//
28 
29 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
30     RewritePatternSet &patterns) {
31   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // ApplyTransferPermutationPatternsOp
36 //===----------------------------------------------------------------------===//
37 
38 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
39     RewritePatternSet &patterns) {
40   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // LowerBroadcastOp
45 //===----------------------------------------------------------------------===//
46 
47 void transform::LowerBroadcastOp::populatePatterns(
48     RewritePatternSet &patterns) {
49   populateVectorBroadcastLoweringPatterns(patterns);
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // LowerContractionOp
54 //===----------------------------------------------------------------------===//
55 
56 void transform::LowerContractionOp::populatePatterns(
57     RewritePatternSet &patterns) {
58   vector::VectorTransformsOptions vectorTransformOptions;
59   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
60   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
61                                          /*benefit=*/1,
62                                          /*disableOuterProductLowering=*/true);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // LowerMasksOp
67 //===----------------------------------------------------------------------===//
68 
69 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
70   populateVectorMaskOpLoweringPatterns(patterns);
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // LowerMaskedTransfersOp
75 //===----------------------------------------------------------------------===//
76 
77 void transform::LowerMaskedTransfersOp::populatePatterns(
78     RewritePatternSet &patterns) {
79   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // LowerMultiReductionOp
84 //===----------------------------------------------------------------------===//
85 
86 void transform::LowerMultiReductionOp::populatePatterns(
87     RewritePatternSet &patterns) {
88   vector::VectorTransformsOptions vectorTransformOptions;
89   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
90   vector::populateVectorMultiReductionLoweringPatterns(
91       patterns, vectorTransformOptions.vectorMultiReductionLowering);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // LowerOuterProductOp
96 //===----------------------------------------------------------------------===//
97 
98 void transform::LowerOuterProductOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   populateVectorOuterProductLoweringPatterns(patterns);
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // LowerShapeCastOp
105 //===----------------------------------------------------------------------===//
106 
107 void transform::LowerShapeCastOp::populatePatterns(
108     RewritePatternSet &patterns) {
109   vector::populateVectorShapeCastLoweringPatterns(patterns);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // LowerTransferOp
114 //===----------------------------------------------------------------------===//
115 
116 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) {
117   vector::populateVectorTransferLoweringPatterns(patterns,
118                                                  getMaxTransferRank());
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // LowerTransposeOp
123 //===----------------------------------------------------------------------===//
124 
125 void transform::LowerTransposeOp::populatePatterns(
126     RewritePatternSet &patterns) {
127   vector::populateVectorTransposeLoweringPatterns(
128       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
129                     getLoweringStrategy()));
130   if (getAvx2LoweringStrategy()) {
131     auto avx2LoweringOptions =
132         x86vector::avx2::LoweringOptions().setTransposeOptions(
133             x86vector::avx2::TransposeLoweringOptions()
134                 .lower4x8xf32(true)
135                 .lower8x8xf32(true));
136     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
137         patterns, avx2LoweringOptions, /*benefit=*/10);
138   }
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // SplitTransferFullPartialOp
143 //===----------------------------------------------------------------------===//
144 
145 void transform::SplitTransferFullPartialOp::populatePatterns(
146     RewritePatternSet &patterns) {
147   vector::VectorTransformsOptions vectorTransformOptions;
148   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
149   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // TransferToScfOp
154 //===----------------------------------------------------------------------===//
155 
156 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) {
157   VectorTransferToSCFOptions vectorTransferToSCFOptions =
158       VectorTransferToSCFOptions()
159           .enableFullUnroll(getFullUnroll())
160           .setTargetRank(getMaxTransferRank());
161   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
162 }
163 
164 //===----------------------------------------------------------------------===//
165 // Transform op registration
166 //===----------------------------------------------------------------------===//
167 
168 namespace {
169 /// Registers new ops and declares PDL as dependent dialect since the additional
170 /// ops are using PDL types for operands and results.
171 class VectorTransformDialectExtension
172     : public transform::TransformDialectExtension<
173           VectorTransformDialectExtension> {
174 public:
175   VectorTransformDialectExtension() {
176     declareGeneratedDialect<vector::VectorDialect>();
177     declareGeneratedDialect<LLVM::LLVMDialect>();
178     registerTransformOps<
179 #define GET_OP_LIST
180 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
181         >();
182   }
183 };
184 } // namespace
185 
186 #define GET_OP_CLASSES
187 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
188 
189 void mlir::vector::registerTransformDialectExtension(
190     DialectRegistry &registry) {
191   registry.addExtensions<VectorTransformDialectExtension>();
192 }
193