1 //===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg transformations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/Linalg/Utils/Utils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #include "llvm/ADT/SmallVector.h" 28 29 using namespace mlir; 30 using namespace mlir::linalg; 31 32 namespace { 33 struct TestLinalgTransforms 34 : public PassWrapper<TestLinalgTransforms, OperationPass<func::FuncOp>> { 35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgTransforms) 36 37 TestLinalgTransforms() = default; 38 TestLinalgTransforms(const TestLinalgTransforms &pass) : PassWrapper(pass) {} 39 40 void getDependentDialects(DialectRegistry ®istry) const override { 41 // clang-format off 42 registry.insert<affine::AffineDialect, 43 bufferization::BufferizationDialect, 44 memref::MemRefDialect, 45 scf::SCFDialect, 46 linalg::LinalgDialect, 47 vector::VectorDialect, 48 gpu::GPUDialect>(); 49 // clang-format on 50 } 51 StringRef getArgument() const final { 52 return "test-linalg-transform-patterns"; 53 } 54 StringRef getDescription() const final { 55 return "Test Linalg transformation patterns by applying them greedily."; 56 } 57 58 void runOnOperation() override; 59 60 Option<bool> testPatterns{*this, "test-patterns", 61 llvm::cl::desc("Test a mixed set of patterns"), 62 llvm::cl::init(false)}; 63 Option<bool> testVectorTransferForwardingPatterns{ 64 *this, "test-vector-transfer-forwarding-patterns", 65 llvm::cl::desc( 66 "Test a fused pass that forwards memref.copy to vector.transfer"), 67 llvm::cl::init(false)}; 68 Option<bool> testGenericToVectorPattern{ 69 *this, "test-linalg-to-vector-patterns", 70 llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " 71 "in vector.contract form"), 72 llvm::cl::init(false)}; 73 Option<bool> testDecomposePadTensor{ 74 *this, "test-decompose-pad-tensor", 75 llvm::cl::desc("Test transform pad tensor by copying with generic ops"), 76 llvm::cl::init(false)}; 77 Option<bool> testDecomposeTensorPackOp{ 78 *this, "test-decompose-tensor-pack", 79 llvm::cl::desc("Test transform that generalizes pack ops into a sequence " 80 "of tensor and Linalg ops"), 81 llvm::cl::init(false)}; 82 Option<bool> testDecomposeTensorUnPackOp{ 83 *this, "test-decompose-tensor-unpack", 84 llvm::cl::desc( 85 "Test transform that generalizes unpack ops into a sequence " 86 "of tensor and Linalg ops"), 87 llvm::cl::init(false)}; 88 Option<bool> testSwapSubTensorPadTensor{ 89 *this, "test-swap-subtensor-padtensor", 90 llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into " 91 "tensor.pad(subtensor)"), 92 llvm::cl::init(false)}; 93 ListOption<int64_t> peeledLoops{ 94 *this, "peeled-loops", 95 llvm::cl::desc("Loops to be peeled when test-tile-pattern")}; 96 ListOption<int64_t> tileSizes{ 97 *this, "tile-sizes", 98 llvm::cl::desc("Linalg tile sizes for test-tile-pattern")}; 99 Option<bool> skipPartial{ 100 *this, "skip-partial", 101 llvm::cl::desc("Skip loops inside partial iterations during peeling"), 102 llvm::cl::init(false)}; 103 Option<std::string> loopType{ 104 *this, "loop-type", 105 llvm::cl::desc("Specify the type of loops to generate: for, parallel or " 106 "tiled_loop"), 107 llvm::cl::init("for")}; 108 Option<bool> testBubbleUpExtractSliceOpPattern{ 109 *this, "test-bubble-up-extract-slice-op-pattern", 110 llvm::cl::desc("Test rewrite of linalgOp + extract_slice into " 111 "extract_slice + linalgOp"), 112 llvm::cl::init(false)}; 113 Option<bool> testSwapExtractSliceWithFill{ 114 *this, "test-swap-extract-slice-with-fill-pattern", 115 llvm::cl::desc( 116 "Test patterns to swap tensor.extract_slice(linalg.fill())"), 117 llvm::cl::init(false)}; 118 Option<bool> testEraseUnusedOperandsAndResults{ 119 *this, "test-erase-unused-operands-and-results", 120 llvm::cl::desc("Test patterns to erase unused operands and results"), 121 llvm::cl::init(false)}; 122 Option<bool> testEraseUnnecessaryInputs{ 123 *this, "test-erase-unnecessary-inputs", 124 llvm::cl::desc("Test patterns to erase unnecessary inputs"), 125 llvm::cl::init(false)}; 126 Option<bool> testWinogradConv2D{ 127 *this, "test-winograd-conv2d", 128 llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), 129 llvm::cl::init(false)}; 130 Option<bool> testDecomposeWinogradOps{ 131 *this, "test-decompose-winograd-ops", 132 llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)}; 133 }; 134 } // namespace 135 136 static void applyPatterns(func::FuncOp funcOp) { 137 MLIRContext *ctx = funcOp.getContext(); 138 RewritePatternSet patterns(ctx); 139 140 //===--------------------------------------------------------------------===// 141 // Linalg distribution patterns. 142 //===--------------------------------------------------------------------===// 143 LinalgLoopDistributionOptions distributionOptions; 144 145 //===--------------------------------------------------------------------===// 146 // Linalg to vector contraction patterns. 147 //===--------------------------------------------------------------------===// 148 patterns.add<CopyVectorizationPattern>(ctx); 149 150 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 151 } 152 153 static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { 154 RewritePatternSet forwardPattern(funcOp.getContext()); 155 forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext()); 156 forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext()); 157 (void)applyPatternsGreedily(funcOp, std::move(forwardPattern)); 158 } 159 160 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { 161 RewritePatternSet patterns(funcOp.getContext()); 162 auto *ctx = funcOp.getContext(); 163 patterns.add<CopyVectorizationPattern>(ctx); 164 populatePadOpVectorizationPatterns(patterns); 165 populateConvolutionVectorizationPatterns(patterns); 166 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 167 } 168 169 static void applyDecomposePadPatterns(func::FuncOp funcOp) { 170 RewritePatternSet patterns(funcOp.getContext()); 171 patterns.add<DecomposePadOpPattern>(funcOp.getContext()); 172 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 173 } 174 175 static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) { 176 RewritePatternSet patterns(funcOp.getContext()); 177 patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext()); 178 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 179 } 180 181 static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) { 182 RewritePatternSet patterns(funcOp.getContext()); 183 patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext()); 184 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 185 } 186 187 static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { 188 RewritePatternSet patterns(funcOp.getContext()); 189 patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext()); 190 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 191 } 192 193 static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { 194 RewritePatternSet patterns(funcOp.getContext()); 195 populateBubbleUpExtractSliceOpPatterns(patterns); 196 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 197 } 198 199 static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { 200 RewritePatternSet patterns(funcOp.getContext()); 201 populateSwapExtractSliceWithFillPatterns(patterns); 202 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 203 } 204 205 static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { 206 RewritePatternSet patterns(funcOp.getContext()); 207 populateEraseUnusedOperandsAndResultsPatterns(patterns); 208 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 209 } 210 211 static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { 212 RewritePatternSet patterns(funcOp.getContext()); 213 populateEraseUnnecessaryInputsPatterns(patterns); 214 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 215 } 216 217 static void applyWinogradConv2D(func::FuncOp funcOp) { 218 RewritePatternSet patterns(funcOp.getContext()); 219 populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); 220 populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); 221 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 222 } 223 224 static void applyDecomposeWinogradOps(func::FuncOp funcOp) { 225 RewritePatternSet patterns(funcOp.getContext()); 226 populateDecomposeWinogradOpsPatterns(patterns); 227 (void)applyPatternsGreedily(funcOp, std::move(patterns)); 228 } 229 230 /// Apply transformations specified as patterns. 231 void TestLinalgTransforms::runOnOperation() { 232 if (testPatterns) 233 return applyPatterns(getOperation()); 234 if (testVectorTransferForwardingPatterns) 235 return applyVectorTransferForwardingPatterns(getOperation()); 236 if (testGenericToVectorPattern) 237 return applyLinalgToVectorPatterns(getOperation()); 238 if (testDecomposePadTensor) 239 return applyDecomposePadPatterns(getOperation()); 240 if (testDecomposeTensorPackOp) 241 return applyDecomposeTensorPackPatterns(getOperation()); 242 if (testDecomposeTensorUnPackOp) 243 return applyDecomposeTensorUnPackPatterns(getOperation()); 244 if (testSwapSubTensorPadTensor) 245 return applyExtractSliceOfPadTensorSwapPattern(getOperation()); 246 if (testBubbleUpExtractSliceOpPattern) 247 return applyBubbleUpExtractSliceOpPattern(getOperation()); 248 if (testSwapExtractSliceWithFill) 249 return applySwapExtractSliceWithFillPattern(getOperation()); 250 if (testEraseUnusedOperandsAndResults) 251 return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); 252 if (testEraseUnnecessaryInputs) 253 return applyEraseUnnecessaryInputs(getOperation()); 254 if (testWinogradConv2D) 255 return applyWinogradConv2D(getOperation()); 256 if (testDecomposeWinogradOps) 257 return applyDecomposeWinogradOps(getOperation()); 258 } 259 260 namespace mlir { 261 namespace test { 262 void registerTestLinalgTransforms() { 263 PassRegistration<TestLinalgTransforms>(); 264 } 265 } // namespace test 266 } // namespace mlir 267