xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
13fef2d26SRiver Riddle //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file implements logic for testing Linalg transformations.
103fef2d26SRiver Riddle //
113fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
133fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
15178f9bd6SNicolas Vasilache #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1636550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
17d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
19800694a6SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h"
203fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
213fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
223fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Utils/Utils.h"
2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
24800694a6SNicolas Vasilache #include "mlir/Pass/PassManager.h"
253fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
263fef2d26SRiver Riddle 
27c57c4f88SMatthias Springer #include "llvm/ADT/SmallVector.h"
283fef2d26SRiver Riddle 
293fef2d26SRiver Riddle using namespace mlir;
303fef2d26SRiver Riddle using namespace mlir::linalg;
313fef2d26SRiver Riddle 
323fef2d26SRiver Riddle namespace {
333fef2d26SRiver Riddle struct TestLinalgTransforms
3458ceae95SRiver Riddle     : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> {
355e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms)
365e50dd04SRiver Riddle 
373fef2d26SRiver Riddle   TestLinalgTransforms() = default;
383bab9d4eSMehdi Amini   TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {}
393fef2d26SRiver Riddle 
403fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
413fef2d26SRiver Riddle     // clang-format off
424c48f016SMatthias Springer     registry.insert<affine::AffineDialect,
43178f9bd6SNicolas Vasilache                     bufferization::BufferizationDialect,
443fef2d26SRiver Riddle                     memref::MemRefDialect,
453fef2d26SRiver Riddle                     scf::SCFDialect,
46fd0c6f53SAlexander Belyaev                     linalg::LinalgDialect,
473fef2d26SRiver Riddle                     vector::VectorDialect,
483fef2d26SRiver Riddle                     gpu::GPUDialect>();
493fef2d26SRiver Riddle     // clang-format on
503fef2d26SRiver Riddle   }
51b5e22e6dSMehdi Amini   StringRef getArgument() const final {
52b5e22e6dSMehdi Amini     return "test-linalg-transform-patterns";
53b5e22e6dSMehdi Amini   }
54b5e22e6dSMehdi Amini   StringRef getDescription() const final {
55b5e22e6dSMehdi Amini     return "Test Linalg transformation patterns by applying them greedily.";
56b5e22e6dSMehdi Amini   }
573fef2d26SRiver Riddle 
5841574554SRiver Riddle   void runOnOperation() override;
593fef2d26SRiver Riddle 
603fef2d26SRiver Riddle   Option<bool> testPatterns{*this, "test-patterns",
613fef2d26SRiver Riddle                             llvm::cl::desc("Test a mixed set of patterns"),
623fef2d26SRiver Riddle                             llvm::cl::init(false)};
633fef2d26SRiver Riddle   Option<bool> testVectorTransferForwardingPatterns{
643fef2d26SRiver Riddle       *this, "test-vector-transfer-forwarding-patterns",
653fef2d26SRiver Riddle       llvm::cl::desc(
66ebc81537SAlexander Belyaev           "Test a fused pass that forwards memref.copy to vector.transfer"),
673fef2d26SRiver Riddle       llvm::cl::init(false)};
683fef2d26SRiver Riddle   Option<bool> testGenericToVectorPattern{
693fef2d26SRiver Riddle       *this, "test-linalg-to-vector-patterns",
703fef2d26SRiver Riddle       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
713fef2d26SRiver Riddle                      "in vector.contract form"),
723fef2d26SRiver Riddle       llvm::cl::init(false)};
731b2c8f10SAndrzej Warzyński   Option<bool> testDecomposePadTensor{
741b2c8f10SAndrzej Warzyński       *this, "test-decompose-pad-tensor",
7535df2f6fSYi Zhang       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
7635df2f6fSYi Zhang       llvm::cl::init(false)};
7707750882SAndrzej Warzyński   Option<bool> testDecomposeTensorPackOp{
7807750882SAndrzej Warzyński       *this, "test-decompose-tensor-pack",
793ebc6beeSHanhan Wang       llvm::cl::desc("Test transform that generalizes pack ops into a sequence "
803ebc6beeSHanhan Wang                      "of tensor and Linalg ops"),
813ebc6beeSHanhan Wang       llvm::cl::init(false)};
8207750882SAndrzej Warzyński   Option<bool> testDecomposeTensorUnPackOp{
8307750882SAndrzej Warzyński       *this, "test-decompose-tensor-unpack",
843ebc6beeSHanhan Wang       llvm::cl::desc(
853ebc6beeSHanhan Wang           "Test transform that generalizes unpack ops into a sequence "
86644f0f83SHanhan Wang           "of tensor and Linalg ops"),
87644f0f83SHanhan Wang       llvm::cl::init(false)};
8824199f53SMatthias Springer   Option<bool> testSwapSubTensorPadTensor{
8924199f53SMatthias Springer       *this, "test-swap-subtensor-padtensor",
901ad9b266Slorenzo chelini       llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
911ad9b266Slorenzo chelini                      "tensor.pad(subtensor)"),
9224199f53SMatthias Springer       llvm::cl::init(false)};
938faf35c0SMatthias Springer   ListOption<int64_t> peeledLoops{
948faf35c0SMatthias Springer       *this, "peeled-loops",
9562a4e6abSFangrui Song       llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
968faf35c0SMatthias Springer   ListOption<int64_t> tileSizes{
978faf35c0SMatthias Springer       *this, "tile-sizes",
9862a4e6abSFangrui Song       llvm::cl::desc("Linalg tile sizes for test-tile-pattern")};
99a4a654d3SMatthias Springer   Option<bool> skipPartial{
100a4a654d3SMatthias Springer       *this, "skip-partial",
101a4a654d3SMatthias Springer       llvm::cl::desc("Skip loops inside partial iterations during peeling"),
102a4a654d3SMatthias Springer       llvm::cl::init(false)};
1032190f8a8SMatthias Springer   Option<std::string> loopType{
1042190f8a8SMatthias Springer       *this, "loop-type",
1052190f8a8SMatthias Springer       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
1062190f8a8SMatthias Springer                      "tiled_loop"),
1072190f8a8SMatthias Springer       llvm::cl::init("for")};
10865bdeddbSOkwan Kwon   Option<bool> testBubbleUpExtractSliceOpPattern{
10965bdeddbSOkwan Kwon       *this, "test-bubble-up-extract-slice-op-pattern",
11065bdeddbSOkwan Kwon       llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
11165bdeddbSOkwan Kwon                      "extract_slice + linalgOp"),
11265bdeddbSOkwan Kwon       llvm::cl::init(false)};
113c325e978SLei Zhang   Option<bool> testSwapExtractSliceWithFill{
114c325e978SLei Zhang       *this, "test-swap-extract-slice-with-fill-pattern",
115c325e978SLei Zhang       llvm::cl::desc(
116c325e978SLei Zhang           "Test patterns to swap tensor.extract_slice(linalg.fill())"),
117c325e978SLei Zhang       llvm::cl::init(false)};
118da8a8e92SMahesh Ravishankar   Option<bool> testEraseUnusedOperandsAndResults{
119da8a8e92SMahesh Ravishankar       *this, "test-erase-unused-operands-and-results",
120da8a8e92SMahesh Ravishankar       llvm::cl::desc("Test patterns to erase unused operands and results"),
121da8a8e92SMahesh Ravishankar       llvm::cl::init(false)};
122e7328a9eSMatthias Springer   Option<bool> testEraseUnnecessaryInputs{
123e7328a9eSMatthias Springer       *this, "test-erase-unnecessary-inputs",
124e7328a9eSMatthias Springer       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
125e7328a9eSMatthias Springer       llvm::cl::init(false)};
1267d246e84SHsiangkai Wang   Option<bool> testWinogradConv2D{
1277d246e84SHsiangkai Wang       *this, "test-winograd-conv2d",
1287d246e84SHsiangkai Wang       llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
1297d246e84SHsiangkai Wang       llvm::cl::init(false)};
13027ee33d1SHsiangkai Wang   Option<bool> testDecomposeWinogradOps{
13127ee33d1SHsiangkai Wang       *this, "test-decompose-winograd-ops",
13227ee33d1SHsiangkai Wang       llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)};
1333fef2d26SRiver Riddle };
134be0a7e9fSMehdi Amini } // namespace
1353fef2d26SRiver Riddle 
13658ceae95SRiver Riddle static void applyPatterns(func::FuncOp funcOp) {
1373fef2d26SRiver Riddle   MLIRContext *ctx = funcOp.getContext();
1383fef2d26SRiver Riddle   RewritePatternSet patterns(ctx);
1393fef2d26SRiver Riddle 
1403fef2d26SRiver Riddle   //===--------------------------------------------------------------------===//
1413fef2d26SRiver Riddle   // Linalg distribution patterns.
1423fef2d26SRiver Riddle   //===--------------------------------------------------------------------===//
1433fef2d26SRiver Riddle   LinalgLoopDistributionOptions distributionOptions;
1443fef2d26SRiver Riddle 
1453fef2d26SRiver Riddle   //===--------------------------------------------------------------------===//
1463fef2d26SRiver Riddle   // Linalg to vector contraction patterns.
1473fef2d26SRiver Riddle   //===--------------------------------------------------------------------===//
148ebc81537SAlexander Belyaev   patterns.add<CopyVectorizationPattern>(ctx);
1493fef2d26SRiver Riddle 
150*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
1513fef2d26SRiver Riddle }
1523fef2d26SRiver Riddle 
15358ceae95SRiver Riddle static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
1543fef2d26SRiver Riddle   RewritePatternSet forwardPattern(funcOp.getContext());
1553fef2d26SRiver Riddle   forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
1563fef2d26SRiver Riddle   forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
157*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(forwardPattern));
1583fef2d26SRiver Riddle }
1593fef2d26SRiver Riddle 
16058ceae95SRiver Riddle static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
1613fef2d26SRiver Riddle   RewritePatternSet patterns(funcOp.getContext());
162ebc81537SAlexander Belyaev   auto *ctx = funcOp.getContext();
163ebc81537SAlexander Belyaev   patterns.add<CopyVectorizationPattern>(ctx);
164fd0c6f53SAlexander Belyaev   populatePadOpVectorizationPatterns(patterns);
1656bb7d247SNicolas Vasilache   populateConvolutionVectorizationPatterns(patterns);
166*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
1673fef2d26SRiver Riddle }
1683fef2d26SRiver Riddle 
1691b2c8f10SAndrzej Warzyński static void applyDecomposePadPatterns(func::FuncOp funcOp) {
17035df2f6fSYi Zhang   RewritePatternSet patterns(funcOp.getContext());
1711b2c8f10SAndrzej Warzyński   patterns.add<DecomposePadOpPattern>(funcOp.getContext());
172*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
17335df2f6fSYi Zhang }
17435df2f6fSYi Zhang 
17507750882SAndrzej Warzyński static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
176644f0f83SHanhan Wang   RewritePatternSet patterns(funcOp.getContext());
17707750882SAndrzej Warzyński   patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
178*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
179644f0f83SHanhan Wang }
180644f0f83SHanhan Wang 
18107750882SAndrzej Warzyński static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
1823ebc6beeSHanhan Wang   RewritePatternSet patterns(funcOp.getContext());
18307750882SAndrzej Warzyński   patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
184*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
1853ebc6beeSHanhan Wang }
1863ebc6beeSHanhan Wang 
18758ceae95SRiver Riddle static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
18824199f53SMatthias Springer   RewritePatternSet patterns(funcOp.getContext());
189060208b4SMatthias Springer   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
190*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
19124199f53SMatthias Springer }
19224199f53SMatthias Springer 
19358ceae95SRiver Riddle static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
19465bdeddbSOkwan Kwon   RewritePatternSet patterns(funcOp.getContext());
19565bdeddbSOkwan Kwon   populateBubbleUpExtractSliceOpPatterns(patterns);
196*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
19765bdeddbSOkwan Kwon }
19865bdeddbSOkwan Kwon 
199c325e978SLei Zhang static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
200c325e978SLei Zhang   RewritePatternSet patterns(funcOp.getContext());
201c325e978SLei Zhang   populateSwapExtractSliceWithFillPatterns(patterns);
202*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
203c325e978SLei Zhang }
204c325e978SLei Zhang 
205da8a8e92SMahesh Ravishankar static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
206da8a8e92SMahesh Ravishankar   RewritePatternSet patterns(funcOp.getContext());
207da8a8e92SMahesh Ravishankar   populateEraseUnusedOperandsAndResultsPatterns(patterns);
208*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
209da8a8e92SMahesh Ravishankar }
210da8a8e92SMahesh Ravishankar 
211e7328a9eSMatthias Springer static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
212e7328a9eSMatthias Springer   RewritePatternSet patterns(funcOp.getContext());
213e7328a9eSMatthias Springer   populateEraseUnnecessaryInputsPatterns(patterns);
214*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
215e7328a9eSMatthias Springer }
216e7328a9eSMatthias Springer 
2177d246e84SHsiangkai Wang static void applyWinogradConv2D(func::FuncOp funcOp) {
2187d246e84SHsiangkai Wang   RewritePatternSet patterns(funcOp.getContext());
2197d246e84SHsiangkai Wang   populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
2207d246e84SHsiangkai Wang   populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
221*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
2227d246e84SHsiangkai Wang }
2237d246e84SHsiangkai Wang 
22427ee33d1SHsiangkai Wang static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
22527ee33d1SHsiangkai Wang   RewritePatternSet patterns(funcOp.getContext());
22627ee33d1SHsiangkai Wang   populateDecomposeWinogradOpsPatterns(patterns);
227*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(funcOp, std::move(patterns));
22827ee33d1SHsiangkai Wang }
22927ee33d1SHsiangkai Wang 
2303fef2d26SRiver Riddle /// Apply transformations specified as patterns.
23141574554SRiver Riddle void TestLinalgTransforms::runOnOperation() {
2323fef2d26SRiver Riddle   if (testPatterns)
23341574554SRiver Riddle     return applyPatterns(getOperation());
2343fef2d26SRiver Riddle   if (testVectorTransferForwardingPatterns)
23541574554SRiver Riddle     return applyVectorTransferForwardingPatterns(getOperation());
2363fef2d26SRiver Riddle   if (testGenericToVectorPattern)
23741574554SRiver Riddle     return applyLinalgToVectorPatterns(getOperation());
2381b2c8f10SAndrzej Warzyński   if (testDecomposePadTensor)
2391b2c8f10SAndrzej Warzyński     return applyDecomposePadPatterns(getOperation());
24007750882SAndrzej Warzyński   if (testDecomposeTensorPackOp)
24107750882SAndrzej Warzyński     return applyDecomposeTensorPackPatterns(getOperation());
24207750882SAndrzej Warzyński   if (testDecomposeTensorUnPackOp)
24307750882SAndrzej Warzyński     return applyDecomposeTensorUnPackPatterns(getOperation());
24424199f53SMatthias Springer   if (testSwapSubTensorPadTensor)
24541574554SRiver Riddle     return applyExtractSliceOfPadTensorSwapPattern(getOperation());
24665bdeddbSOkwan Kwon   if (testBubbleUpExtractSliceOpPattern)
24765bdeddbSOkwan Kwon     return applyBubbleUpExtractSliceOpPattern(getOperation());
248c325e978SLei Zhang   if (testSwapExtractSliceWithFill)
249c325e978SLei Zhang     return applySwapExtractSliceWithFillPattern(getOperation());
250da8a8e92SMahesh Ravishankar   if (testEraseUnusedOperandsAndResults)
251da8a8e92SMahesh Ravishankar     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
252e7328a9eSMatthias Springer   if (testEraseUnnecessaryInputs)
253e7328a9eSMatthias Springer     return applyEraseUnnecessaryInputs(getOperation());
2547d246e84SHsiangkai Wang   if (testWinogradConv2D)
2557d246e84SHsiangkai Wang     return applyWinogradConv2D(getOperation());
25627ee33d1SHsiangkai Wang   if (testDecomposeWinogradOps)
25727ee33d1SHsiangkai Wang     return applyDecomposeWinogradOps(getOperation());
2583fef2d26SRiver Riddle }
2593fef2d26SRiver Riddle 
2603fef2d26SRiver Riddle namespace mlir {
2613fef2d26SRiver Riddle namespace test {
2623fef2d26SRiver Riddle void registerTestLinalgTransforms() {
263b5e22e6dSMehdi Amini   PassRegistration<TestLinalgTransforms>();
2643fef2d26SRiver Riddle }
2653fef2d26SRiver Riddle } // namespace test
2663fef2d26SRiver Riddle } // namespace mlir
267