xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision d1dffa402384af63496979905536dd180c51eedc)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18 #include "mlir/Dialect/X86Vector/Transforms.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 using namespace mlir::transform;
24 
25 //===----------------------------------------------------------------------===//
26 // ApplyRankReducingSubviewPatternsOp
27 //===----------------------------------------------------------------------===//
28 
29 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
30     RewritePatternSet &patterns) {
31   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // ApplyTransferPermutationPatternsOp
36 //===----------------------------------------------------------------------===//
37 
38 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
39     RewritePatternSet &patterns) {
40   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // LowerBroadcastOp
45 //===----------------------------------------------------------------------===//
46 
47 void transform::LowerBroadcastOp::populatePatterns(
48     RewritePatternSet &patterns) {
49   populateVectorBroadcastLoweringPatterns(patterns);
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // LowerContractionOp
54 //===----------------------------------------------------------------------===//
55 
56 void transform::LowerContractionOp::populatePatterns(
57     RewritePatternSet &patterns) {
58   vector::VectorTransformsOptions vectorTransformOptions;
59   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
60   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
61                                          /*benefit=*/1,
62                                          /*disableOuterProductLowering=*/true);
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // LowerMasksOp
67 //===----------------------------------------------------------------------===//
68 
69 void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
70   populateVectorMaskOpLoweringPatterns(patterns);
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // LowerMaskedTransfersOp
75 //===----------------------------------------------------------------------===//
76 
77 void transform::LowerMaskedTransfersOp::populatePatterns(
78     RewritePatternSet &patterns) {
79   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // MaterializeMasksOp
84 //===----------------------------------------------------------------------===//
85 
86 void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) {
87   populateVectorMaskMaterializationPatterns(patterns,
88                                             /*force32BitVectorIndices=*/false);
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // LowerMultiReductionOp
93 //===----------------------------------------------------------------------===//
94 
95 void transform::LowerMultiReductionOp::populatePatterns(
96     RewritePatternSet &patterns) {
97   vector::VectorTransformsOptions vectorTransformOptions;
98   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
99   vector::populateVectorMultiReductionLoweringPatterns(
100       patterns, vectorTransformOptions.vectorMultiReductionLowering);
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // LowerOuterProductOp
105 //===----------------------------------------------------------------------===//
106 
107 void transform::LowerOuterProductOp::populatePatterns(
108     RewritePatternSet &patterns) {
109   populateVectorOuterProductLoweringPatterns(patterns);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // LowerShapeCastOp
114 //===----------------------------------------------------------------------===//
115 
116 void transform::LowerShapeCastOp::populatePatterns(
117     RewritePatternSet &patterns) {
118   vector::populateVectorShapeCastLoweringPatterns(patterns);
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // LowerTransferOp
123 //===----------------------------------------------------------------------===//
124 
125 void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) {
126   vector::populateVectorTransferLoweringPatterns(patterns,
127                                                  getMaxTransferRank());
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // LowerTransposeOp
132 //===----------------------------------------------------------------------===//
133 
134 void transform::LowerTransposeOp::populatePatterns(
135     RewritePatternSet &patterns) {
136   vector::populateVectorTransposeLoweringPatterns(
137       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
138                     getLoweringStrategy()));
139   if (getAvx2LoweringStrategy()) {
140     auto avx2LoweringOptions =
141         x86vector::avx2::LoweringOptions().setTransposeOptions(
142             x86vector::avx2::TransposeLoweringOptions()
143                 .lower4x8xf32(true)
144                 .lower8x8xf32(true));
145     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
146         patterns, avx2LoweringOptions, /*benefit=*/10);
147   }
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // SplitTransferFullPartialOp
152 //===----------------------------------------------------------------------===//
153 
154 void transform::SplitTransferFullPartialOp::populatePatterns(
155     RewritePatternSet &patterns) {
156   vector::VectorTransformsOptions vectorTransformOptions;
157   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
158   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // TransferToScfOp
163 //===----------------------------------------------------------------------===//
164 
165 void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) {
166   VectorTransferToSCFOptions vectorTransferToSCFOptions =
167       VectorTransferToSCFOptions()
168           .enableFullUnroll(getFullUnroll())
169           .setTargetRank(getMaxTransferRank());
170   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Transform op registration
175 //===----------------------------------------------------------------------===//
176 
177 namespace {
178 /// Registers new ops and declares PDL as dependent dialect since the additional
179 /// ops are using PDL types for operands and results.
180 class VectorTransformDialectExtension
181     : public transform::TransformDialectExtension<
182           VectorTransformDialectExtension> {
183 public:
184   VectorTransformDialectExtension() {
185     declareGeneratedDialect<vector::VectorDialect>();
186     declareGeneratedDialect<LLVM::LLVMDialect>();
187     registerTransformOps<
188 #define GET_OP_LIST
189 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
190         >();
191   }
192 };
193 } // namespace
194 
195 #define GET_OP_CLASSES
196 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
197 
198 void mlir::vector::registerTransformDialectExtension(
199     DialectRegistry &registry) {
200   registry.addExtensions<VectorTransformDialectExtension>();
201 }
202