xref: /llvm-project/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (revision cc7f52432bca6938d748c6730943586f879f841f)
1 //===- VectorTransformOps.cpp - Implementation of Vector transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
10 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
11 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
14 #include "mlir/Dialect/Transform/IR/TransformOps.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
18 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
19 #include "mlir/Dialect/X86Vector/Transforms.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::vector;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // Apply...PatternsOp
28 //===----------------------------------------------------------------------===//
29 
30 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
31     RewritePatternSet &patterns) {
32   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
33 }
34 
35 void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
36     RewritePatternSet &patterns) {
37   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
38 }
39 
40 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
41     RewritePatternSet &patterns) {
42   populateVectorBroadcastLoweringPatterns(patterns);
43 }
44 
45 void transform::ApplyLowerContractionPatternsOp::populatePatterns(
46     RewritePatternSet &patterns) {
47   vector::VectorTransformsOptions vectorTransformOptions;
48   vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
49   populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
50                                          /*benefit=*/1,
51                                          /*disableOuterProductLowering=*/true);
52 }
53 
54 void transform::ApplyLowerMasksPatternsOp::populatePatterns(
55     RewritePatternSet &patterns) {
56   populateVectorMaskOpLoweringPatterns(patterns);
57 }
58 
59 void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns(
60     RewritePatternSet &patterns) {
61   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
62 }
63 
64 void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
65     RewritePatternSet &patterns) {
66   populateVectorMaskMaterializationPatterns(patterns,
67                                             /*force32BitVectorIndices=*/false);
68 }
69 
70 void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns(
71     RewritePatternSet &patterns) {
72   vector::VectorTransformsOptions vectorTransformOptions;
73   vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
74   vector::populateVectorMultiReductionLoweringPatterns(
75       patterns, vectorTransformOptions.vectorMultiReductionLowering);
76 }
77 
78 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
79     RewritePatternSet &patterns) {
80   populateVectorOuterProductLoweringPatterns(patterns);
81 }
82 
83 void transform::ApplyLowerGatherPatternsOp::populatePatterns(
84     RewritePatternSet &patterns) {
85   vector::populateVectorGatherLoweringPatterns(patterns);
86 }
87 
88 void transform::ApplyLowerScanPatternsOp::populatePatterns(
89     RewritePatternSet &patterns) {
90   vector::populateVectorScanLoweringPatterns(patterns);
91 }
92 
93 void transform::ApplyLowerShapeCastPatternsOp::populatePatterns(
94     RewritePatternSet &patterns) {
95   vector::populateVectorShapeCastLoweringPatterns(patterns);
96 }
97 
98 void transform::ApplyLowerTransferPatternsOp::populatePatterns(
99     RewritePatternSet &patterns) {
100   vector::populateVectorTransferLoweringPatterns(patterns,
101                                                  getMaxTransferRank());
102 }
103 
104 void transform::ApplyLowerTransposePatternsOp::populatePatterns(
105     RewritePatternSet &patterns) {
106   vector::populateVectorTransposeLoweringPatterns(
107       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
108                     getLoweringStrategy()));
109   if (getAvx2LoweringStrategy()) {
110     auto avx2LoweringOptions =
111         x86vector::avx2::LoweringOptions().setTransposeOptions(
112             x86vector::avx2::TransposeLoweringOptions()
113                 .lower4x8xf32(true)
114                 .lower8x8xf32(true));
115     x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
116         patterns, avx2LoweringOptions, /*benefit=*/10);
117   }
118 }
119 
120 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
121     RewritePatternSet &patterns) {
122   vector::VectorTransformsOptions vectorTransformOptions;
123   vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy());
124   populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
125 }
126 
127 void transform::ApplyTransferToScfPatternsOp::populatePatterns(
128     RewritePatternSet &patterns) {
129   VectorTransferToSCFOptions vectorTransferToSCFOptions =
130       VectorTransferToSCFOptions()
131           .enableFullUnroll(getFullUnroll())
132           .setTargetRank(getMaxTransferRank());
133   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Transform op registration
138 //===----------------------------------------------------------------------===//
139 
140 namespace {
141 /// Registers new ops and declares PDL as dependent dialect since the additional
142 /// ops are using PDL types for operands and results.
143 class VectorTransformDialectExtension
144     : public transform::TransformDialectExtension<
145           VectorTransformDialectExtension> {
146 public:
147   VectorTransformDialectExtension() {
148     declareGeneratedDialect<vector::VectorDialect>();
149     declareGeneratedDialect<LLVM::LLVMDialect>();
150     registerTransformOps<
151 #define GET_OP_LIST
152 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
153         >();
154   }
155 };
156 } // namespace
157 
158 #define GET_OP_CLASSES
159 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc"
160 
161 void mlir::vector::registerTransformDialectExtension(
162     DialectRegistry &registry) {
163   registry.addExtensions<VectorTransformDialectExtension>();
164 }
165