xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision 1f5e8263b920f591c517a5dc562cccad39dd6ec7)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Dialect/Transform/IR/TransformOps.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22 #include "mlir/Dialect/X86Vector/Transforms.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 using namespace mlir::transform;
28 
29 //===----------------------------------------------------------------------===//
30 // Apply...ConversionPatternsOp
31 //===----------------------------------------------------------------------===//
32 
33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
34     TypeConverter &typeConverter, RewritePatternSet &patterns) {
35   populateVectorToLLVMConversionPatterns(
36       static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37       getReassociateFpReductions(), getForce_32bitVectorIndices());
38 }
39 
40 LogicalResult
41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
42     transform::TypeConverterBuilderOpInterface builder) {
43   if (builder.getTypeConverterType() != "LLVMTypeConverter")
44     return emitOpError("expected LLVMTypeConverter");
45   return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Apply...PatternsOp
50 //===----------------------------------------------------------------------===//
51 
52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
53     RewritePatternSet &patterns) {
54   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
55 }
56 
57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
58     RewritePatternSet &patterns) {
59   vector::populateFoldArithExtensionPatterns(patterns);
60 }
61 
62 void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
63     RewritePatternSet &patterns) {
64   vector::populateElementwiseToVectorOpsPatterns(patterns);
65 }
66 
67 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
68     RewritePatternSet &patterns) {
69   vector::populateVectorReductionToContractPatterns(patterns);
70   vector::populateSinkVectorOpsPatterns(patterns);
71 }
72 
73 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
74     RewritePatternSet &patterns) {
75   vector::populateVectorMaskOpLoweringPatterns(patterns);
76 }
77 
78 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
79     RewritePatternSet &patterns) {
80   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
81 }
82 
83 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
84     RewritePatternSet &patterns) {
85   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
86 }
87 
88 void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
89     RewritePatternSet &patterns) {
90   vector::populateDropUnitDimWithShapeCastPatterns(patterns);
91 }
92 
93 void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
94     RewritePatternSet &patterns) {
95   vector::populateVectorBitCastLoweringPatterns(patterns);
96 }
97 
98 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   populateVectorBroadcastLoweringPatterns(patterns);
101 }
102 
103 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
104     RewritePatternSet &patterns) {
105   vector::VectorTransformsOptions vectorTransformOptions;
106   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
107   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
108                                          /*benefit=*/1,
109                                          /*disableOuterProductLowering=*/true);
110 }
111 
112 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
113     RewritePatternSet &patterns) {
114   populateVectorMaskOpLoweringPatterns(patterns);
115 }
116 
117 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
118     RewritePatternSet &patterns) {
119   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
120 }
121 
122 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
123     RewritePatternSet &patterns) {
124   populateVectorMaskMaterializationPatterns(patterns,
125                                             /*force32BitVectorIndices=*/false);
126 }
127 
128 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
129     RewritePatternSet &patterns) {
130   vector::VectorTransformsOptions vectorTransformOptions;
131   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
132   vector::populateVectorMultiReductionLoweringPatterns(
133       patterns, vectorTransformOptions.vectorMultiReductionLowering);
134 }
135 
136 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
137     RewritePatternSet &patterns) {
138   populateVectorOuterProductLoweringPatterns(patterns);
139 }
140 
141 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
142     RewritePatternSet &patterns) {
143   vector::populateVectorGatherLoweringPatterns(patterns);
144 }
145 
146 void transform::ApplyLowerScanPatternsOp::populatePatterns(
147     RewritePatternSet &patterns) {
148   vector::populateVectorScanLoweringPatterns(patterns);
149 }
150 
151 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
152     RewritePatternSet &patterns) {
153   vector::populateVectorShapeCastLoweringPatterns(patterns);
154 }
155 
156 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
157     RewritePatternSet &patterns) {
158   vector::populateVectorTransferLoweringPatterns(patterns,
159                                                  getMaxTransferRank());
160 }
161 
162 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
163     RewritePatternSet &patterns) {
164   vector::populateVectorTransposeLoweringPatterns(
165       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
166                     getLoweringStrategy()));
167   if (getAvx2LoweringStrategy()) {
168     auto avx2LoweringOptions =
169         x86vector::avx2::LoweringOptions().setTransposeOptions(
170             x86vector::avx2::TransposeLoweringOptions()
171                 .lower4x8xf32(true)
172                 .lower8x8xf32(true));
173     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
174         patterns, avx2LoweringOptions, /*benefit=*/10);
175   }
176 }
177 
178 void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
179     RewritePatternSet &patterns) {
180   vector::populateVectorInterleaveLoweringPatterns(patterns);
181 }
182 
183 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
184     RewritePatternSet &patterns) {
185   vector::populateVectorInterleaveToShufflePatterns(patterns);
186 }
187 
188 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
189     RewritePatternSet &patterns) {
190   populateVectorNarrowTypeRewritePatterns(patterns);
191   populateVectorTransposeNarrowTypeRewritePatterns(patterns);
192 }
193 
194 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
195     RewritePatternSet &patterns) {
196   vector::VectorTransformsOptions vectorTransformOptions;
197   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
198   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
199 }
200 
201 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
202     RewritePatternSet &patterns) {
203   VectorTransferToSCFOptions vectorTransferToSCFOptions =
204       VectorTransferToSCFOptions()
205           .enableFullUnroll(getFullUnroll())
206           .setTargetRank(getMaxTransferRank());
207   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // Transform op registration
212 //===----------------------------------------------------------------------===//
213 
214 namespace {
215 /// Registers new ops and declares PDL as dependent dialect since the additional
216 /// ops are using PDL types for operands and results.
217 class VectorTransformDialectExtension
218     : public transform::TransformDialectExtension<
219           VectorTransformDialectExtension> {
220 public:
221   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
222 
223   VectorTransformDialectExtension() {
224     declareGeneratedDialect<vector::VectorDialect>();
225     declareGeneratedDialect<LLVM::LLVMDialect>();
226     registerTransformOps<
227 #define GET_OP_LIST
228 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
229         >();
230   }
231 };
232 } // namespace
233 
234 #define GET_OP_CLASSES
235 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
236 
237 void mlir::vector::registerTransformDialectExtension(
238     DialectRegistry &registry) {
239   registry.addExtensions<VectorTransformDialectExtension>();
240 }
241