13fef2d26SRiver Riddle //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file implements logic for testing Linalg fusion patterns. 103fef2d26SRiver Riddle // 113fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 123fef2d26SRiver Riddle 13eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 1436550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 153fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 164a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 183fef2d26SRiver Riddle #include "mlir/Pass/PassManager.h" 193fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 203fef2d26SRiver Riddle #include "mlir/Transforms/Passes.h" 213fef2d26SRiver Riddle 223fef2d26SRiver Riddle using namespace mlir; 233fef2d26SRiver Riddle using namespace mlir::linalg; 243fef2d26SRiver Riddle 2558ceae95SRiver Riddle static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { 263fef2d26SRiver Riddle OpBuilder b(f); 273fef2d26SRiver Riddle 283fef2d26SRiver Riddle // Save original Linalg ops, we only want to make a pass over those. 293fef2d26SRiver Riddle SmallVector<LinalgOp, 8> linalgOps; 303fef2d26SRiver Riddle f.walk([&](LinalgOp op) { 313fef2d26SRiver Riddle // TODO: support multi-results. 323fef2d26SRiver Riddle if (op->getNumResults() <= 1) 333fef2d26SRiver Riddle linalgOps.push_back(op); 343fef2d26SRiver Riddle }); 353fef2d26SRiver Riddle 363fef2d26SRiver Riddle // Tile and Fuse for tensors inputs (TODO: all tensor operands). 373fef2d26SRiver Riddle bool changed = false; 383fef2d26SRiver Riddle for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 39a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : linalgOp->getOpOperands()) { 405550c821STres Popp if (isa<MemRefType>(opOperand.get().getType())) 41489fec27SNicolas Vasilache continue; 425550c821STres Popp if (isa<RankedTensorType>(opOperand.get().getType())) { 433fef2d26SRiver Riddle // Tile and Fuse tensor input. 44b4db15a9SAlexander Belyaev if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs()) 453fef2d26SRiver Riddle continue; 46a7cccb9cSAlexander Belyaev auto info = fuseProducerOfTensor(b, opOperand); 47489fec27SNicolas Vasilache if (failed(info)) 48489fec27SNicolas Vasilache continue; 493fef2d26SRiver Riddle auto *originalOp = info->originalProducer.getOperation(); 503fef2d26SRiver Riddle auto *originalOpInLinalgOpsVector = 513fef2d26SRiver Riddle std::find(linalgOps.begin(), linalgOps.end(), originalOp); 526089d612SRahul Kayaith *originalOpInLinalgOpsVector = info->fusedProducer; 533fef2d26SRiver Riddle // Don't mark for erasure in the tensor case, let DCE handle this. 543fef2d26SRiver Riddle changed = true; 553fef2d26SRiver Riddle } 563fef2d26SRiver Riddle } 573fef2d26SRiver Riddle } 583fef2d26SRiver Riddle 593fef2d26SRiver Riddle return changed ? success() : failure(); 603fef2d26SRiver Riddle } 613fef2d26SRiver Riddle 623fef2d26SRiver Riddle namespace { 633fef2d26SRiver Riddle struct TestLinalgGreedyFusion 6458ceae95SRiver Riddle : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> { 655e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion) 665e50dd04SRiver Riddle 673fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 684c48f016SMatthias Springer registry.insert<affine::AffineDialect, linalg::LinalgDialect, 694c48f016SMatthias Springer memref::MemRefDialect, scf::SCFDialect>(); 703fef2d26SRiver Riddle } 71b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 72b5e22e6dSMehdi Amini StringRef getDescription() const final { 73b5e22e6dSMehdi Amini return "Test Linalg fusion by applying a greedy test transformation."; 74b5e22e6dSMehdi Amini } 7541574554SRiver Riddle void runOnOperation() override { 763fef2d26SRiver Riddle MLIRContext *context = &getContext(); 773fef2d26SRiver Riddle RewritePatternSet patterns = 783fef2d26SRiver Riddle linalg::getLinalgTilingCanonicalizationPatterns(context); 792de2dbefSMatthias Springer patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 80d18ffd61SMatthias Springer scf::populateSCFForLoopCanonicalizationPatterns(patterns); 813fef2d26SRiver Riddle FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 8258ceae95SRiver Riddle OpPassManager pm(func::FuncOp::getOperationName()); 833fef2d26SRiver Riddle pm.addPass(createLoopInvariantCodeMotionPass()); 843fef2d26SRiver Riddle pm.addPass(createCanonicalizerPass()); 853fef2d26SRiver Riddle pm.addPass(createCSEPass()); 860f304ef0SRiver Riddle do { 87*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), frozenPatterns); 880f304ef0SRiver Riddle if (failed(runPipeline(pm, getOperation()))) 893fef2d26SRiver Riddle this->signalPassFailure(); 9041574554SRiver Riddle } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 913fef2d26SRiver Riddle } 923fef2d26SRiver Riddle }; 933fef2d26SRiver Riddle } // namespace 943fef2d26SRiver Riddle 953fef2d26SRiver Riddle namespace mlir { 963fef2d26SRiver Riddle namespace test { 973fef2d26SRiver Riddle void registerTestLinalgGreedyFusion() { 98b5e22e6dSMehdi Amini PassRegistration<TestLinalgGreedyFusion>(); 993fef2d26SRiver Riddle } 1003fef2d26SRiver Riddle 1013fef2d26SRiver Riddle } // namespace test 1023fef2d26SRiver Riddle } // namespace mlir 103