xref: /llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg fusion patterns.
//
//===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
13eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
1436550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
153fef2d26SRiver Riddle #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
164a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
183fef2d26SRiver Riddle #include "mlir/Pass/PassManager.h"
193fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
203fef2d26SRiver Riddle #include "mlir/Transforms/Passes.h"
213fef2d26SRiver Riddle 
223fef2d26SRiver Riddle using namespace mlir;
233fef2d26SRiver Riddle using namespace mlir::linalg;
243fef2d26SRiver Riddle 
2558ceae95SRiver Riddle static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
263fef2d26SRiver Riddle   OpBuilder b(f);
273fef2d26SRiver Riddle 
283fef2d26SRiver Riddle   // Save original Linalg ops, we only want to make a pass over those.
293fef2d26SRiver Riddle   SmallVector<LinalgOp, 8> linalgOps;
303fef2d26SRiver Riddle   f.walk([&](LinalgOp op) {
313fef2d26SRiver Riddle     // TODO: support multi-results.
323fef2d26SRiver Riddle     if (op->getNumResults() <= 1)
333fef2d26SRiver Riddle       linalgOps.push_back(op);
343fef2d26SRiver Riddle   });
353fef2d26SRiver Riddle 
363fef2d26SRiver Riddle   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
373fef2d26SRiver Riddle   bool changed = false;
383fef2d26SRiver Riddle   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
39a7cccb9cSAlexander Belyaev     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
405550c821STres Popp       if (isa<MemRefType>(opOperand.get().getType()))
41489fec27SNicolas Vasilache         continue;
425550c821STres Popp       if (isa<RankedTensorType>(opOperand.get().getType())) {
433fef2d26SRiver Riddle         // Tile and Fuse tensor input.
44b4db15a9SAlexander Belyaev         if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
453fef2d26SRiver Riddle           continue;
46a7cccb9cSAlexander Belyaev         auto info = fuseProducerOfTensor(b, opOperand);
47489fec27SNicolas Vasilache         if (failed(info))
48489fec27SNicolas Vasilache           continue;
493fef2d26SRiver Riddle         auto *originalOp = info->originalProducer.getOperation();
503fef2d26SRiver Riddle         auto *originalOpInLinalgOpsVector =
513fef2d26SRiver Riddle             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
526089d612SRahul Kayaith         *originalOpInLinalgOpsVector = info->fusedProducer;
533fef2d26SRiver Riddle         // Don't mark for erasure in the tensor case, let DCE handle this.
543fef2d26SRiver Riddle         changed = true;
553fef2d26SRiver Riddle       }
563fef2d26SRiver Riddle     }
573fef2d26SRiver Riddle   }
583fef2d26SRiver Riddle 
593fef2d26SRiver Riddle   return changed ? success() : failure();
603fef2d26SRiver Riddle }
613fef2d26SRiver Riddle 
623fef2d26SRiver Riddle namespace {
633fef2d26SRiver Riddle struct TestLinalgGreedyFusion
6458ceae95SRiver Riddle     : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> {
655e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion)
665e50dd04SRiver Riddle 
673fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
684c48f016SMatthias Springer     registry.insert<affine::AffineDialect, linalg::LinalgDialect,
694c48f016SMatthias Springer                     memref::MemRefDialect, scf::SCFDialect>();
703fef2d26SRiver Riddle   }
71b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-linalg-greedy-fusion"; }
72b5e22e6dSMehdi Amini   StringRef getDescription() const final {
73b5e22e6dSMehdi Amini     return "Test Linalg fusion by applying a greedy test transformation.";
74b5e22e6dSMehdi Amini   }
7541574554SRiver Riddle   void runOnOperation() override {
763fef2d26SRiver Riddle     MLIRContext *context = &getContext();
773fef2d26SRiver Riddle     RewritePatternSet patterns =
783fef2d26SRiver Riddle         linalg::getLinalgTilingCanonicalizationPatterns(context);
792de2dbefSMatthias Springer     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
80d18ffd61SMatthias Springer     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
813fef2d26SRiver Riddle     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
8258ceae95SRiver Riddle     OpPassManager pm(func::FuncOp::getOperationName());
833fef2d26SRiver Riddle     pm.addPass(createLoopInvariantCodeMotionPass());
843fef2d26SRiver Riddle     pm.addPass(createCanonicalizerPass());
853fef2d26SRiver Riddle     pm.addPass(createCSEPass());
860f304ef0SRiver Riddle     do {
87*09dfc571SJacques Pienaar       (void)applyPatternsGreedily(getOperation(), frozenPatterns);
880f304ef0SRiver Riddle       if (failed(runPipeline(pm, getOperation())))
893fef2d26SRiver Riddle         this->signalPassFailure();
9041574554SRiver Riddle     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
913fef2d26SRiver Riddle   }
923fef2d26SRiver Riddle };
933fef2d26SRiver Riddle } // namespace
943fef2d26SRiver Riddle 
953fef2d26SRiver Riddle namespace mlir {
963fef2d26SRiver Riddle namespace test {
973fef2d26SRiver Riddle void registerTestLinalgGreedyFusion() {
98b5e22e6dSMehdi Amini   PassRegistration<TestLinalgGreedyFusion>();
993fef2d26SRiver Riddle }
1003fef2d26SRiver Riddle 
1013fef2d26SRiver Riddle } // namespace test
1023fef2d26SRiver Riddle } // namespace mlir
103