13fef2d26SRiver Riddle //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file implements logic for testing Linalg transformations. 103fef2d26SRiver Riddle // 113fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 123fef2d26SRiver Riddle 133fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 14abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 15178f9bd6SNicolas Vasilache #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1636550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 17d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 19800694a6SNicolas Vasilache #include "mlir/Dialect/Linalg/Passes.h" 203fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 213fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 223fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Utils/Utils.h" 2399ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 24800694a6SNicolas Vasilache #include "mlir/Pass/PassManager.h" 253fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 263fef2d26SRiver Riddle 27c57c4f88SMatthias Springer #include "llvm/ADT/SmallVector.h" 283fef2d26SRiver Riddle 293fef2d26SRiver Riddle using namespace mlir; 303fef2d26SRiver Riddle using namespace mlir::linalg; 313fef2d26SRiver Riddle 323fef2d26SRiver Riddle namespace { 333fef2d26SRiver Riddle struct TestLinalgTransforms 3458ceae95SRiver Riddle : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { 355e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) 365e50dd04SRiver Riddle 373fef2d26SRiver Riddle TestLinalgTransforms() = default; 383bab9d4eSMehdi Amini TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 393fef2d26SRiver Riddle 403fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 413fef2d26SRiver Riddle // clang-format off 424c48f016SMatthias Springer registry.insert<affine::AffineDialect, 43178f9bd6SNicolas Vasilache bufferization::BufferizationDialect, 443fef2d26SRiver Riddle memref::MemRefDialect, 453fef2d26SRiver Riddle scf::SCFDialect, 46fd0c6f53SAlexander Belyaev linalg::LinalgDialect, 473fef2d26SRiver Riddle vector::VectorDialect, 483fef2d26SRiver Riddle gpu::GPUDialect>(); 493fef2d26SRiver Riddle // clang-format on 503fef2d26SRiver Riddle } 51b5e22e6dSMehdi Amini StringRef getArgument() const final { 52b5e22e6dSMehdi Amini return "test-linalg-transform-patterns"; 53b5e22e6dSMehdi Amini } 54b5e22e6dSMehdi Amini StringRef getDescription() const final { 55b5e22e6dSMehdi Amini return "Test Linalg transformation patterns by applying them greedily."; 56b5e22e6dSMehdi Amini } 573fef2d26SRiver Riddle 5841574554SRiver Riddle void runOnOperation() override; 593fef2d26SRiver Riddle 603fef2d26SRiver Riddle Option<bool> testPatterns{*this, "test-patterns", 613fef2d26SRiver Riddle llvm::cl::desc("Test a mixed set of patterns"), 623fef2d26SRiver Riddle llvm::cl::init(false)}; 633fef2d26SRiver Riddle Option<bool> testVectorTransferForwardingPatterns{ 643fef2d26SRiver Riddle *this, "test-vector-transfer-forwarding-patterns", 653fef2d26SRiver Riddle llvm::cl::desc( 66ebc81537SAlexander Belyaev "Test a fused pass that forwards memref.copy to vector.transfer"), 673fef2d26SRiver Riddle llvm::cl::init(false)}; 683fef2d26SRiver Riddle Option<bool> testGenericToVectorPattern{ 693fef2d26SRiver Riddle *this, "test-linalg-to-vector-patterns", 703fef2d26SRiver Riddle llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 713fef2d26SRiver Riddle "in vector.contract form"), 723fef2d26SRiver Riddle llvm::cl::init(false)}; 731b2c8f10SAndrzej Warzyński Option<bool> testDecomposePadTensor{ 741b2c8f10SAndrzej Warzyński *this, "test-decompose-pad-tensor", 7535df2f6fSYi Zhang llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 7635df2f6fSYi Zhang llvm::cl::init(false)}; 7707750882SAndrzej Warzyński Option<bool> testDecomposeTensorPackOp{ 7807750882SAndrzej Warzyński *this, "test-decompose-tensor-pack", 793ebc6beeSHanhan Wang llvm::cl::desc("Test transform that generalizes pack ops into a sequence " 803ebc6beeSHanhan Wang "of tensor and Linalg ops"), 813ebc6beeSHanhan Wang llvm::cl::init(false)}; 8207750882SAndrzej Warzyński Option<bool> testDecomposeTensorUnPackOp{ 8307750882SAndrzej Warzyński *this, "test-decompose-tensor-unpack", 843ebc6beeSHanhan Wang llvm::cl::desc( 853ebc6beeSHanhan Wang "Test transform that generalizes unpack ops into a sequence " 86644f0f83SHanhan Wang "of tensor and Linalg ops"), 87644f0f83SHanhan Wang llvm::cl::init(false)}; 8824199f53SMatthias Springer Option<bool> testSwapSubTensorPadTensor{ 8924199f53SMatthias Springer *this, "test-swap-subtensor-padtensor", 901ad9b266Slorenzo chelini llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " 911ad9b266Slorenzo chelini "tensor.pad(subtensor)"), 9224199f53SMatthias Springer llvm::cl::init(false)}; 938faf35c0SMatthias Springer ListOption<int64_t> peeledLoops{ 948faf35c0SMatthias Springer *this, "peeled-loops", 9562a4e6abSFangrui Song llvm::cl::desc("Loops to be peeled when test-tile-pattern")}; 968faf35c0SMatthias Springer ListOption<int64_t> tileSizes{ 978faf35c0SMatthias Springer *this, "tile-sizes", 9862a4e6abSFangrui Song llvm::cl::desc("Linalg tile sizes for test-tile-pattern")}; 99a4a654d3SMatthias Springer Option<bool> skipPartial{ 100a4a654d3SMatthias Springer *this, "skip-partial", 101a4a654d3SMatthias Springer llvm::cl::desc("Skip loops inside partial iterations during peeling"), 102a4a654d3SMatthias Springer llvm::cl::init(false)}; 1032190f8a8SMatthias Springer Option<std::string> loopType{ 1042190f8a8SMatthias Springer *this, "loop-type", 1052190f8a8SMatthias Springer llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 1062190f8a8SMatthias Springer "tiled_loop"), 1072190f8a8SMatthias Springer llvm::cl::init("for")}; 10865bdeddbSOkwan Kwon Option<bool> testBubbleUpExtractSliceOpPattern{ 10965bdeddbSOkwan Kwon *this, "test-bubble-up-extract-slice-op-pattern", 11065bdeddbSOkwan Kwon llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 11165bdeddbSOkwan Kwon "extract_slice + linalgOp"), 11265bdeddbSOkwan Kwon llvm::cl::init(false)}; 113c325e978SLei Zhang Option<bool> testSwapExtractSliceWithFill{ 114c325e978SLei Zhang *this, "test-swap-extract-slice-with-fill-pattern", 115c325e978SLei Zhang llvm::cl::desc( 116c325e978SLei Zhang "Test patterns to swap tensor.extract_slice(linalg.fill())"), 117c325e978SLei Zhang llvm::cl::init(false)}; 118da8a8e92SMahesh Ravishankar Option<bool> testEraseUnusedOperandsAndResults{ 119da8a8e92SMahesh Ravishankar *this, "test-erase-unused-operands-and-results", 120da8a8e92SMahesh Ravishankar llvm::cl::desc("Test patterns to erase unused operands and results"), 121da8a8e92SMahesh Ravishankar llvm::cl::init(false)}; 122e7328a9eSMatthias Springer Option<bool> testEraseUnnecessaryInputs{ 123e7328a9eSMatthias Springer *this, "test-erase-unnecessary-inputs", 124e7328a9eSMatthias Springer llvm::cl::desc("Test patterns to erase unnecessary inputs"), 125e7328a9eSMatthias Springer llvm::cl::init(false)}; 1267d246e84SHsiangkai Wang Option<bool> testWinogradConv2D{ 1277d246e84SHsiangkai Wang *this, "test-winograd-conv2d", 1287d246e84SHsiangkai Wang llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), 1297d246e84SHsiangkai Wang llvm::cl::init(false)}; 13027ee33d1SHsiangkai Wang Option<bool> testDecomposeWinogradOps{ 13127ee33d1SHsiangkai Wang *this, "test-decompose-winograd-ops", 13227ee33d1SHsiangkai Wang llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)}; 1333fef2d26SRiver Riddle }; 134be0a7e9fSMehdi Amini } // namespace 1353fef2d26SRiver Riddle 13658ceae95SRiver Riddle static void applyPatterns(func::FuncOp funcOp) { 1373fef2d26SRiver Riddle MLIRContext *ctx = funcOp.getContext(); 1383fef2d26SRiver Riddle RewritePatternSet patterns(ctx); 1393fef2d26SRiver Riddle 1403fef2d26SRiver Riddle //===--------------------------------------------------------------------===// 1413fef2d26SRiver Riddle // Linalg distribution patterns. 1423fef2d26SRiver Riddle //===--------------------------------------------------------------------===// 1433fef2d26SRiver Riddle LinalgLoopDistributionOptions distributionOptions; 1443fef2d26SRiver Riddle 1453fef2d26SRiver Riddle //===--------------------------------------------------------------------===// 1463fef2d26SRiver Riddle // Linalg to vector contraction patterns. 1473fef2d26SRiver Riddle //===--------------------------------------------------------------------===// 148ebc81537SAlexander Belyaev patterns.add<CopyVectorizationPattern>(ctx); 1493fef2d26SRiver Riddle 150*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 1513fef2d26SRiver Riddle } 1523fef2d26SRiver Riddle 15358ceae95SRiver Riddle static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { 1543fef2d26SRiver Riddle RewritePatternSet forwardPattern(funcOp.getContext()); 1553fef2d26SRiver Riddle forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 1563fef2d26SRiver Riddle forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 157*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(forwardPattern)); 1583fef2d26SRiver Riddle } 1593fef2d26SRiver Riddle 16058ceae95SRiver Riddle static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { 1613fef2d26SRiver Riddle RewritePatternSet patterns(funcOp.getContext()); 162ebc81537SAlexander Belyaev auto *ctx = funcOp.getContext(); 163ebc81537SAlexander Belyaev patterns.add<CopyVectorizationPattern>(ctx); 164fd0c6f53SAlexander Belyaev populatePadOpVectorizationPatterns(patterns); 1656bb7d247SNicolas Vasilache populateConvolutionVectorizationPatterns(patterns); 166*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 1673fef2d26SRiver Riddle } 1683fef2d26SRiver Riddle 1691b2c8f10SAndrzej Warzyński static void applyDecomposePadPatterns(func::FuncOp funcOp) { 17035df2f6fSYi Zhang RewritePatternSet patterns(funcOp.getContext()); 1711b2c8f10SAndrzej Warzyński patterns.add<DecomposePadOpPattern>(funcOp.getContext()); 172*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 17335df2f6fSYi Zhang } 17435df2f6fSYi Zhang 17507750882SAndrzej Warzyński static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) { 176644f0f83SHanhan Wang RewritePatternSet patterns(funcOp.getContext()); 17707750882SAndrzej Warzyński patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext()); 178*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 179644f0f83SHanhan Wang } 180644f0f83SHanhan Wang 18107750882SAndrzej Warzyński static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) { 1823ebc6beeSHanhan Wang RewritePatternSet patterns(funcOp.getContext()); 18307750882SAndrzej Warzyński patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext()); 184*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 1853ebc6beeSHanhan Wang } 1863ebc6beeSHanhan Wang 18758ceae95SRiver Riddle static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { 18824199f53SMatthias Springer RewritePatternSet patterns(funcOp.getContext()); 189060208b4SMatthias Springer patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 190*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 19124199f53SMatthias Springer } 19224199f53SMatthias Springer 19358ceae95SRiver Riddle static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { 19465bdeddbSOkwan Kwon RewritePatternSet patterns(funcOp.getContext()); 19565bdeddbSOkwan Kwon populateBubbleUpExtractSliceOpPatterns(patterns); 196*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 19765bdeddbSOkwan Kwon } 19865bdeddbSOkwan Kwon 199c325e978SLei Zhang static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { 200c325e978SLei Zhang RewritePatternSet patterns(funcOp.getContext()); 201c325e978SLei Zhang populateSwapExtractSliceWithFillPatterns(patterns); 202*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 203c325e978SLei Zhang } 204c325e978SLei Zhang 205da8a8e92SMahesh Ravishankar static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { 206da8a8e92SMahesh Ravishankar RewritePatternSet patterns(funcOp.getContext()); 207da8a8e92SMahesh Ravishankar populateEraseUnusedOperandsAndResultsPatterns(patterns); 208*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 209da8a8e92SMahesh Ravishankar } 210da8a8e92SMahesh Ravishankar 211e7328a9eSMatthias Springer static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { 212e7328a9eSMatthias Springer RewritePatternSet patterns(funcOp.getContext()); 213e7328a9eSMatthias Springer populateEraseUnnecessaryInputsPatterns(patterns); 214*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 215e7328a9eSMatthias Springer } 216e7328a9eSMatthias Springer 2177d246e84SHsiangkai Wang static void applyWinogradConv2D(func::FuncOp funcOp) { 2187d246e84SHsiangkai Wang RewritePatternSet patterns(funcOp.getContext()); 2197d246e84SHsiangkai Wang populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); 2207d246e84SHsiangkai Wang populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); 221*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 2227d246e84SHsiangkai Wang } 2237d246e84SHsiangkai Wang 22427ee33d1SHsiangkai Wang static void applyDecomposeWinogradOps(func::FuncOp funcOp) { 22527ee33d1SHsiangkai Wang RewritePatternSet patterns(funcOp.getContext()); 22627ee33d1SHsiangkai Wang populateDecomposeWinogradOpsPatterns(patterns); 227*09dfc571SJacques Pienaar (void)applyPatternsGreedily(funcOp, std::move(patterns)); 22827ee33d1SHsiangkai Wang } 22927ee33d1SHsiangkai Wang 2303fef2d26SRiver Riddle /// Apply transformations specified as patterns. 23141574554SRiver Riddle void TestLinalgTransforms::runOnOperation() { 2323fef2d26SRiver Riddle if (testPatterns) 23341574554SRiver Riddle return applyPatterns(getOperation()); 2343fef2d26SRiver Riddle if (testVectorTransferForwardingPatterns) 23541574554SRiver Riddle return applyVectorTransferForwardingPatterns(getOperation()); 2363fef2d26SRiver Riddle if (testGenericToVectorPattern) 23741574554SRiver Riddle return applyLinalgToVectorPatterns(getOperation()); 2381b2c8f10SAndrzej Warzyński if (testDecomposePadTensor) 2391b2c8f10SAndrzej Warzyński return applyDecomposePadPatterns(getOperation()); 24007750882SAndrzej Warzyński if (testDecomposeTensorPackOp) 24107750882SAndrzej Warzyński return applyDecomposeTensorPackPatterns(getOperation()); 24207750882SAndrzej Warzyński if (testDecomposeTensorUnPackOp) 24307750882SAndrzej Warzyński return applyDecomposeTensorUnPackPatterns(getOperation()); 24424199f53SMatthias Springer if (testSwapSubTensorPadTensor) 24541574554SRiver Riddle return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 24665bdeddbSOkwan Kwon if (testBubbleUpExtractSliceOpPattern) 24765bdeddbSOkwan Kwon return applyBubbleUpExtractSliceOpPattern(getOperation()); 248c325e978SLei Zhang if (testSwapExtractSliceWithFill) 249c325e978SLei Zhang return applySwapExtractSliceWithFillPattern(getOperation()); 250da8a8e92SMahesh Ravishankar if (testEraseUnusedOperandsAndResults) 251da8a8e92SMahesh Ravishankar return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); 252e7328a9eSMatthias Springer if (testEraseUnnecessaryInputs) 253e7328a9eSMatthias Springer return applyEraseUnnecessaryInputs(getOperation()); 2547d246e84SHsiangkai Wang if (testWinogradConv2D) 2557d246e84SHsiangkai Wang return applyWinogradConv2D(getOperation()); 25627ee33d1SHsiangkai Wang if (testDecomposeWinogradOps) 25727ee33d1SHsiangkai Wang return applyDecomposeWinogradOps(getOperation()); 2583fef2d26SRiver Riddle } 2593fef2d26SRiver Riddle 2603fef2d26SRiver Riddle namespace mlir { 2613fef2d26SRiver Riddle namespace test { 2623fef2d26SRiver Riddle void registerTestLinalgTransforms() { 263b5e22e6dSMehdi Amini PassRegistration<TestLinalgTransforms>(); 2643fef2d26SRiver Riddle } 2653fef2d26SRiver Riddle } // namespace test 2663fef2d26SRiver Riddle } // namespace mlir 267