xref: /llvm-project/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp (revision d072ca1a496cc3f4ad0adf6f7d43f76406a704d6)
//===- InlinerExtension.cpp - Func Inliner Extension ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8a5ef51d7SRiver Riddle 
9a5ef51d7SRiver Riddle #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
10a5ef51d7SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11a5ef51d7SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
12a5ef51d7SRiver Riddle #include "mlir/IR/DialectInterface.h"
13a5ef51d7SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
14a5ef51d7SRiver Riddle 
15a5ef51d7SRiver Riddle using namespace mlir;
16a5ef51d7SRiver Riddle using namespace mlir::func;
17a5ef51d7SRiver Riddle 
18a5ef51d7SRiver Riddle //===----------------------------------------------------------------------===//
19a5ef51d7SRiver Riddle // FuncDialect Interfaces
20a5ef51d7SRiver Riddle //===----------------------------------------------------------------------===//
21a5ef51d7SRiver Riddle namespace {
22a5ef51d7SRiver Riddle /// This class defines the interface for handling inlining with func operations.
23a5ef51d7SRiver Riddle struct FuncInlinerInterface : public DialectInlinerInterface {
24a5ef51d7SRiver Riddle   using DialectInlinerInterface::DialectInlinerInterface;
25a5ef51d7SRiver Riddle 
26a5ef51d7SRiver Riddle   //===--------------------------------------------------------------------===//
27a5ef51d7SRiver Riddle   // Analysis Hooks
28a5ef51d7SRiver Riddle   //===--------------------------------------------------------------------===//
29a5ef51d7SRiver Riddle 
30*d072ca1aSOleksandr "Alex" Zinenko   /// Call operations can be inlined unless specified otherwise by attributes
31*d072ca1aSOleksandr "Alex" Zinenko   /// on either the call or the callbale.
32a5ef51d7SRiver Riddle   bool isLegalToInline(Operation *call, Operation *callable,
33a5ef51d7SRiver Riddle                        bool wouldBeCloned) const final {
34*d072ca1aSOleksandr "Alex" Zinenko     auto callOp = dyn_cast<func::CallOp>(call);
35*d072ca1aSOleksandr "Alex" Zinenko     auto funcOp = dyn_cast<func::FuncOp>(callable);
36*d072ca1aSOleksandr "Alex" Zinenko     return !(callOp && callOp.getNoInline()) &&
37*d072ca1aSOleksandr "Alex" Zinenko            !(funcOp && funcOp.getNoInline());
38a5ef51d7SRiver Riddle   }
39a5ef51d7SRiver Riddle 
40a5ef51d7SRiver Riddle   /// All operations can be inlined.
41a5ef51d7SRiver Riddle   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
42a5ef51d7SRiver Riddle     return true;
43a5ef51d7SRiver Riddle   }
44a5ef51d7SRiver Riddle 
45*d072ca1aSOleksandr "Alex" Zinenko   /// All function bodies can be inlined.
46a5ef51d7SRiver Riddle   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
47a5ef51d7SRiver Riddle     return true;
48a5ef51d7SRiver Riddle   }
49a5ef51d7SRiver Riddle 
50a5ef51d7SRiver Riddle   //===--------------------------------------------------------------------===//
51a5ef51d7SRiver Riddle   // Transformation Hooks
52a5ef51d7SRiver Riddle   //===--------------------------------------------------------------------===//
53a5ef51d7SRiver Riddle 
54a5ef51d7SRiver Riddle   /// Handle the given inlined terminator by replacing it with a new operation
55a5ef51d7SRiver Riddle   /// as necessary.
56a5ef51d7SRiver Riddle   void handleTerminator(Operation *op, Block *newDest) const final {
57a5ef51d7SRiver Riddle     // Only return needs to be handled here.
58a5ef51d7SRiver Riddle     auto returnOp = dyn_cast<ReturnOp>(op);
59a5ef51d7SRiver Riddle     if (!returnOp)
60a5ef51d7SRiver Riddle       return;
61a5ef51d7SRiver Riddle 
62a5ef51d7SRiver Riddle     // Replace the return with a branch to the dest.
63a5ef51d7SRiver Riddle     OpBuilder builder(op);
64a5ef51d7SRiver Riddle     builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
65a5ef51d7SRiver Riddle     op->erase();
66a5ef51d7SRiver Riddle   }
67a5ef51d7SRiver Riddle 
68a5ef51d7SRiver Riddle   /// Handle the given inlined terminator by replacing it with a new operation
69a5ef51d7SRiver Riddle   /// as necessary.
7026a0b277SMehdi Amini   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
71a5ef51d7SRiver Riddle     // Only return needs to be handled here.
72a5ef51d7SRiver Riddle     auto returnOp = cast<ReturnOp>(op);
73a5ef51d7SRiver Riddle 
74a5ef51d7SRiver Riddle     // Replace the values directly with the return operands.
75a5ef51d7SRiver Riddle     assert(returnOp.getNumOperands() == valuesToRepl.size());
76a5ef51d7SRiver Riddle     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
77a5ef51d7SRiver Riddle       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
78a5ef51d7SRiver Riddle   }
79a5ef51d7SRiver Riddle };
80a5ef51d7SRiver Riddle } // namespace
81a5ef51d7SRiver Riddle 
82a5ef51d7SRiver Riddle //===----------------------------------------------------------------------===//
83a5ef51d7SRiver Riddle // Registration
84a5ef51d7SRiver Riddle //===----------------------------------------------------------------------===//
85a5ef51d7SRiver Riddle 
86a5ef51d7SRiver Riddle void mlir::func::registerInlinerExtension(DialectRegistry &registry) {
87a5ef51d7SRiver Riddle   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
88a5ef51d7SRiver Riddle     dialect->addInterfaces<FuncInlinerInterface>();
89a5ef51d7SRiver Riddle 
90a5ef51d7SRiver Riddle     // The inliner extension relies on the ControlFlow dialect.
91a5ef51d7SRiver Riddle     ctx->getOrLoadDialect<cf::ControlFlowDialect>();
92a5ef51d7SRiver Riddle   });
93a5ef51d7SRiver Riddle }
94