xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision 6b5ce2cffe96d532a498b21a55f5b4c3a0570609)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Transform/IR/TransformOps.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
19 #include "mlir/Dialect/X86Vector/Transforms.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // Apply...PatternsOp
28 //===----------------------------------------------------------------------===//
29 
30 void transform::ApplyCastAwayVectorLeadingOneDimPatternsOp::populatePatterns(
31     RewritePatternSet &patterns) {
32   vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
33 }
34 
35 void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
36     RewritePatternSet &patterns) {
37   vector::populateFoldArithExtensionPatterns(patterns);
38 }
39 
40 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
41     RewritePatternSet &patterns) {
42   vector::populateVectorReductionToContractPatterns(patterns);
43 }
44 
45 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
46     RewritePatternSet &patterns) {
47   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
48 }
49 
50 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
51     RewritePatternSet &patterns) {
52   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
53 }
54 
55 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
56     RewritePatternSet &patterns) {
57   populateVectorBroadcastLoweringPatterns(patterns);
58 }
59 
60 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
61     RewritePatternSet &patterns) {
62   vector::VectorTransformsOptions vectorTransformOptions;
63   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
64   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
65                                          /*benefit=*/1,
66                                          /*disableOuterProductLowering=*/true);
67 }
68 
69 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
70     RewritePatternSet &patterns) {
71   populateVectorMaskOpLoweringPatterns(patterns);
72 }
73 
74 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
75     RewritePatternSet &patterns) {
76   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
77 }
78 
79 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
80     RewritePatternSet &patterns) {
81   populateVectorMaskMaterializationPatterns(patterns,
82                                             /*force32BitVectorIndices=*/false);
83 }
84 
85 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
86     RewritePatternSet &patterns) {
87   vector::VectorTransformsOptions vectorTransformOptions;
88   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
89   vector::populateVectorMultiReductionLoweringPatterns(
90       patterns, vectorTransformOptions.vectorMultiReductionLowering);
91 }
92 
93 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
94     RewritePatternSet &patterns) {
95   populateVectorOuterProductLoweringPatterns(patterns);
96 }
97 
98 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   vector::populateVectorGatherLoweringPatterns(patterns);
101 }
102 
103 void transform::ApplyLowerScanPatternsOp::populatePatterns(
104     RewritePatternSet &patterns) {
105   vector::populateVectorScanLoweringPatterns(patterns);
106 }
107 
108 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
109     RewritePatternSet &patterns) {
110   vector::populateVectorShapeCastLoweringPatterns(patterns);
111 }
112 
113 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
114     RewritePatternSet &patterns) {
115   vector::populateVectorTransferLoweringPatterns(patterns,
116                                                  getMaxTransferRank());
117 }
118 
119 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
120     RewritePatternSet &patterns) {
121   vector::populateVectorTransposeLoweringPatterns(
122       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
123                     getLoweringStrategy()));
124   if (getAvx2LoweringStrategy()) {
125     auto avx2LoweringOptions =
126         x86vector::avx2::LoweringOptions().setTransposeOptions(
127             x86vector::avx2::TransposeLoweringOptions()
128                 .lower4x8xf32(true)
129                 .lower8x8xf32(true));
130     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
131         patterns, avx2LoweringOptions, /*benefit=*/10);
132   }
133 }
134 
135 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
136     RewritePatternSet &patterns) {
137   vector::VectorTransformsOptions vectorTransformOptions;
138   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
139   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
140 }
141 
142 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
143     RewritePatternSet &patterns) {
144   VectorTransferToSCFOptions vectorTransferToSCFOptions =
145       VectorTransferToSCFOptions()
146           .enableFullUnroll(getFullUnroll())
147           .setTargetRank(getMaxTransferRank());
148   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Transform op registration
153 //===----------------------------------------------------------------------===//
154 
155 namespace {
156 /// Registers new ops and declares PDL as dependent dialect since the additional
157 /// ops are using PDL types for operands and results.
158 class VectorTransformDialectExtension
159     : public transform::TransformDialectExtension<
160           VectorTransformDialectExtension> {
161 public:
162   VectorTransformDialectExtension() {
163     declareGeneratedDialect<vector::VectorDialect>();
164     declareGeneratedDialect<LLVM::LLVMDialect>();
165     registerTransformOps<
166 #define GET_OP_LIST
167 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
168         >();
169   }
170 };
171 } // namespace
172 
173 #define GET_OP_CLASSES
174 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
175 
176 void mlir::vector::registerTransformDialectExtension(
177     DialectRegistry &registry) {
178   registry.addExtensions<VectorTransformDialectExtension>();
179 }
180