xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Passes.h"
20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #include "llvm/ADT/SmallVector.h"
28 
29 using namespace mlir;
30 using namespace mlir::linalg;
31 
32 namespace {
33 struct TestLinalgTransforms
34     : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
36 
37   TestLinalgTransforms() = default;
38   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
39 
40   void getDependentDialects(DialectRegistry &registry) const override {
41     // clang-format off
42     registry.insert<affine::AffineDialect,
43                     bufferization::BufferizationDialect,
44                     memref::MemRefDialect,
45                     scf::SCFDialect,
46                     linalg::LinalgDialect,
47                     vector::VectorDialect,
48                     gpu::GPUDialect>();
49     // clang-format on
50   }
51   StringRef getArgument() const final {
52     return "test-linalg-transform-patterns";
53   }
54   StringRef getDescription() const final {
55     return "Test Linalg transformation patterns by applying them greedily.";
56   }
57 
58   void runOnOperation() override;
59 
60   Option<bool> testPatterns{*this, "test-patterns",
61                             llvm::cl::desc("Test a mixed set of patterns"),
62                             llvm::cl::init(false)};
63   Option<bool> testVectorTransferForwardingPatterns{
64       *this, "test-vector-transfer-forwarding-patterns",
65       llvm::cl::desc(
66           "Test a fused pass that forwards memref.copy to vector.transfer"),
67       llvm::cl::init(false)};
68   Option<bool> testGenericToVectorPattern{
69       *this, "test-linalg-to-vector-patterns",
70       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
71                      "in vector.contract form"),
72       llvm::cl::init(false)};
73   Option<bool> testDecomposePadTensor{
74       *this, "test-decompose-pad-tensor",
75       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
76       llvm::cl::init(false)};
77   Option<bool> testDecomposeTensorPackOp{
78       *this, "test-decompose-tensor-pack",
79       llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
80                      "of tensor and Linalg ops"),
81       llvm::cl::init(false)};
82   Option<bool> testDecomposeTensorUnPackOp{
83       *this, "test-decompose-tensor-unpack",
84       llvm::cl::desc(
85           "Test transform that generalizes unpack ops into a sequence "
86           "of tensor and Linalg ops"),
87       llvm::cl::init(false)};
88   Option<bool> testSwapSubTensorPadTensor{
89       *this, "test-swap-subtensor-padtensor",
90       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
91                      "tensor.pad(subtensor)"),
92       llvm::cl::init(false)};
93   ListOption<int64_t> peeledLoops{
94       *this, "peeled-loops",
95       llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
96   ListOption<int64_t> tileSizes{
97       *this, "tile-sizes",
98       llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
99   Option<bool> skipPartial{
100       *this, "skip-partial",
101       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
102       llvm::cl::init(false)};
103   Option<std::string> loopType{
104       *this, "loop-type",
105       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
106                      "tiled_loop"),
107       llvm::cl::init("for")};
108   Option<bool> testBubbleUpExtractSliceOpPattern{
109       *this, "test-bubble-up-extract-slice-op-pattern",
110       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
111                      "extract_slice + linalgOp"),
112       llvm::cl::init(false)};
113   Option<bool> testSwapExtractSliceWithFill{
114       *this, "test-swap-extract-slice-with-fill-pattern",
115       llvm::cl::desc(
116           "Test patterns to swap tensor.extract_slice(linalg.fill())"),
117       llvm::cl::init(false)};
118   Option<bool> testEraseUnusedOperandsAndResults{
119       *this, "test-erase-unused-operands-and-results",
120       llvm::cl::desc("Test patterns to erase unused operands and results"),
121       llvm::cl::init(false)};
122   Option<bool> testEraseUnnecessaryInputs{
123       *this, "test-erase-unnecessary-inputs",
124       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
125       llvm::cl::init(false)};
126   Option<bool> testWinogradConv2D{
127       *this, "test-winograd-conv2d",
128       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
129       llvm::cl::init(false)};
130   Option<bool> testDecomposeWinogradOps{
131       *this, "test-decompose-winograd-ops",
132       llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
133 };
134 } // namespace
135 
136 static void applyPatterns(func::FuncOp funcOp) {
137   MLIRContext *ctx = funcOp.getContext();
138   RewritePatternSet patterns(ctx);
139 
140   //===--------------------------------------------------------------------===//
141   // Linalg distribution patterns.
142   //===--------------------------------------------------------------------===//
143   LinalgLoopDistributionOptions distributionOptions;
144 
145   //===--------------------------------------------------------------------===//
146   // Linalg to vector contraction patterns.
147   //===--------------------------------------------------------------------===//
148   patterns.add<CopyVectorizationPattern>(ctx);
149 
150   (void)applyPatternsGreedily(funcOp, std::move(patterns));
151 }
152 
153 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
154   RewritePatternSet forwardPattern(funcOp.getContext());
155   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
156   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
157   (void)applyPatternsGreedily(funcOp, std::move(forwardPattern));
158 }
159 
160 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
161   RewritePatternSet patterns(funcOp.getContext());
162   auto *ctx = funcOp.getContext();
163   patterns.add<CopyVectorizationPattern>(ctx);
164   populatePadOpVectorizationPatterns(patterns);
165   populateConvolutionVectorizationPatterns(patterns);
166   (void)applyPatternsGreedily(funcOp, std::move(patterns));
167 }
168 
169 static void applyDecomposePadPatterns(func::FuncOp funcOp) {
170   RewritePatternSet patterns(funcOp.getContext());
171   patterns.add<DecomposePadOpPattern>(funcOp.getContext());
172   (void)applyPatternsGreedily(funcOp, std::move(patterns));
173 }
174 
175 static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
176   RewritePatternSet patterns(funcOp.getContext());
177   patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
178   (void)applyPatternsGreedily(funcOp, std::move(patterns));
179 }
180 
181 static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
182   RewritePatternSet patterns(funcOp.getContext());
183   patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
184   (void)applyPatternsGreedily(funcOp, std::move(patterns));
185 }
186 
187 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
188   RewritePatternSet patterns(funcOp.getContext());
189   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
190   (void)applyPatternsGreedily(funcOp, std::move(patterns));
191 }
192 
193 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
194   RewritePatternSet patterns(funcOp.getContext());
195   populateBubbleUpExtractSliceOpPatterns(patterns);
196   (void)applyPatternsGreedily(funcOp, std::move(patterns));
197 }
198 
199 static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
200   RewritePatternSet patterns(funcOp.getContext());
201   populateSwapExtractSliceWithFillPatterns(patterns);
202   (void)applyPatternsGreedily(funcOp, std::move(patterns));
203 }
204 
205 static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
206   RewritePatternSet patterns(funcOp.getContext());
207   populateEraseUnusedOperandsAndResultsPatterns(patterns);
208   (void)applyPatternsGreedily(funcOp, std::move(patterns));
209 }
210 
211 static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
212   RewritePatternSet patterns(funcOp.getContext());
213   populateEraseUnnecessaryInputsPatterns(patterns);
214   (void)applyPatternsGreedily(funcOp, std::move(patterns));
215 }
216 
217 static void applyWinogradConv2D(func::FuncOp funcOp) {
218   RewritePatternSet patterns(funcOp.getContext());
219   populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
220   populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
221   (void)applyPatternsGreedily(funcOp, std::move(patterns));
222 }
223 
224 static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
225   RewritePatternSet patterns(funcOp.getContext());
226   populateDecomposeWinogradOpsPatterns(patterns);
227   (void)applyPatternsGreedily(funcOp, std::move(patterns));
228 }
229 
230 /// Apply transformations specified as patterns.
231 void TestLinalgTransforms::runOnOperation() {
232   if (testPatterns)
233     return applyPatterns(getOperation());
234   if (testVectorTransferForwardingPatterns)
235     return applyVectorTransferForwardingPatterns(getOperation());
236   if (testGenericToVectorPattern)
237     return applyLinalgToVectorPatterns(getOperation());
238   if (testDecomposePadTensor)
239     return applyDecomposePadPatterns(getOperation());
240   if (testDecomposeTensorPackOp)
241     return applyDecomposeTensorPackPatterns(getOperation());
242   if (testDecomposeTensorUnPackOp)
243     return applyDecomposeTensorUnPackPatterns(getOperation());
244   if (testSwapSubTensorPadTensor)
245     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
246   if (testBubbleUpExtractSliceOpPattern)
247     return applyBubbleUpExtractSliceOpPattern(getOperation());
248   if (testSwapExtractSliceWithFill)
249     return applySwapExtractSliceWithFillPattern(getOperation());
250   if (testEraseUnusedOperandsAndResults)
251     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
252   if (testEraseUnnecessaryInputs)
253     return applyEraseUnnecessaryInputs(getOperation());
254   if (testWinogradConv2D)
255     return applyWinogradConv2D(getOperation());
256   if (testDecomposeWinogradOps)
257     return applyDecomposeWinogradOps(getOperation());
258 }
259 
260 namespace mlir {
261 namespace test {
262 void registerTestLinalgTransforms() {
263   PassRegistration<TestLinalgTransforms>();
264 }
265 } // namespace test
266 } // namespace mlir
267