xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision cf40c93b5be5cd0011ebbf3a9eead224f7b7079a)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 
11 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
13 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Dialect/Transform/IR/TransformOps.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
20 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
22 #include "mlir/Dialect/X86Vector/Transforms.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 using namespace mlir::transform;
28 
29 //===----------------------------------------------------------------------===//
30 // Apply...ConversionPatternsOp
31 //===----------------------------------------------------------------------===//
32 
33 void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
34     TypeConverter &typeConverter, RewritePatternSet &patterns) {
35   populateVectorToLLVMConversionPatterns(
36       static_cast<LLVMTypeConverter &>(typeConverter), patterns,
37       getReassociateFpReductions(), getForce_32bitVectorIndices());
38 }
39 
40 LogicalResult
41 transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
42     transform::TypeConverterBuilderOpInterface builder) {
43   if (builder.getTypeConverterType() != "LLVMTypeConverter")
44     return emitOpError("expected LLVMTypeConverter");
45   return success();
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Apply...PatternsOp
50 //===----------------------------------------------------------------------===//
51 
52 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
53     RewritePatternSet &patterns) {
54   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
55 }
56 
57 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
58     RewritePatternSet &patterns) {
59   vector::populateFoldArithExtensionPatterns(patterns);
60 }
61 
62 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
63     RewritePatternSet &patterns) {
64   vector::populateVectorReductionToContractPatterns(patterns);
65 }
66 
67 void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
68     RewritePatternSet &patterns) {
69   vector::populateVectorMaskOpLoweringPatterns(patterns);
70 }
71 
72 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
73     RewritePatternSet &patterns) {
74   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
75 }
76 
77 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
78     RewritePatternSet &patterns) {
79   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
80 }
81 
82 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
83     RewritePatternSet &patterns) {
84   populateVectorBroadcastLoweringPatterns(patterns);
85 }
86 
87 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
88     RewritePatternSet &patterns) {
89   vector::VectorTransformsOptions vectorTransformOptions;
90   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
91   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
92                                          /*benefit=*/1,
93                                          /*disableOuterProductLowering=*/true);
94 }
95 
96 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
97     RewritePatternSet &patterns) {
98   populateVectorMaskOpLoweringPatterns(patterns);
99 }
100 
101 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
102     RewritePatternSet &patterns) {
103   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
104 }
105 
106 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
107     RewritePatternSet &patterns) {
108   populateVectorMaskMaterializationPatterns(patterns,
109                                             /*force32BitVectorIndices=*/false);
110 }
111 
112 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
113     RewritePatternSet &patterns) {
114   vector::VectorTransformsOptions vectorTransformOptions;
115   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
116   vector::populateVectorMultiReductionLoweringPatterns(
117       patterns, vectorTransformOptions.vectorMultiReductionLowering);
118 }
119 
120 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
121     RewritePatternSet &patterns) {
122   populateVectorOuterProductLoweringPatterns(patterns);
123 }
124 
125 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
126     RewritePatternSet &patterns) {
127   vector::populateVectorGatherLoweringPatterns(patterns);
128 }
129 
130 void transform::ApplyLowerScanPatternsOp::populatePatterns(
131     RewritePatternSet &patterns) {
132   vector::populateVectorScanLoweringPatterns(patterns);
133 }
134 
135 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
136     RewritePatternSet &patterns) {
137   vector::populateVectorShapeCastLoweringPatterns(patterns);
138 }
139 
140 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
141     RewritePatternSet &patterns) {
142   vector::populateVectorTransferLoweringPatterns(patterns,
143                                                  getMaxTransferRank());
144 }
145 
146 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
147     RewritePatternSet &patterns) {
148   vector::populateVectorTransposeLoweringPatterns(
149       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
150                     getLoweringStrategy()));
151   if (getAvx2LoweringStrategy()) {
152     auto avx2LoweringOptions =
153         x86vector::avx2::LoweringOptions().setTransposeOptions(
154             x86vector::avx2::TransposeLoweringOptions()
155                 .lower4x8xf32(true)
156                 .lower8x8xf32(true));
157     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
158         patterns, avx2LoweringOptions, /*benefit=*/10);
159   }
160 }
161 
162 void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
163     RewritePatternSet &patterns) {
164   vector::populateVectorInterleaveLoweringPatterns(patterns);
165 }
166 
167 void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
168     RewritePatternSet &patterns) {
169   vector::populateVectorInterleaveToShufflePatterns(patterns);
170 }
171 
172 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
173     RewritePatternSet &patterns) {
174   populateVectorNarrowTypeRewritePatterns(patterns);
175   populateVectorTransposeNarrowTypeRewritePatterns(patterns);
176 }
177 
178 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
179     RewritePatternSet &patterns) {
180   vector::VectorTransformsOptions vectorTransformOptions;
181   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
182   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
183 }
184 
185 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
186     RewritePatternSet &patterns) {
187   VectorTransferToSCFOptions vectorTransferToSCFOptions =
188       VectorTransferToSCFOptions()
189           .enableFullUnroll(getFullUnroll())
190           .setTargetRank(getMaxTransferRank());
191   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // Transform op registration
196 //===----------------------------------------------------------------------===//
197 
198 namespace {
199 /// Registers new ops and declares PDL as dependent dialect since the additional
200 /// ops are using PDL types for operands and results.
201 class VectorTransformDialectExtension
202     : public transform::TransformDialectExtension<
203           VectorTransformDialectExtension> {
204 public:
205   VectorTransformDialectExtension() {
206     declareGeneratedDialect<vector::VectorDialect>();
207     declareGeneratedDialect<LLVM::LLVMDialect>();
208     registerTransformOps<
209 #define GET_OP_LIST
210 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
211         >();
212   }
213 };
214 } // namespace
215 
216 #define GET_OP_CLASSES
217 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
218 
219 void mlir::vector::registerTransformDialectExtension(
220     DialectRegistry &registry) {
221   registry.addExtensions<VectorTransformDialectExtension>();
222 }
223