xref: /llvm-project/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp (revision d072ca1a496cc3f4ad0adf6f7d43f76406a704d6)
1 //===- InlinerExtension.cpp - Func Inliner Extension ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/DialectInterface.h"
13 #include "mlir/Transforms/InliningUtils.h"
14 
15 using namespace mlir;
16 using namespace mlir::func;
17 
18 //===----------------------------------------------------------------------===//
19 // FuncDialect Interfaces
20 //===----------------------------------------------------------------------===//
21 namespace {
22 /// This class defines the interface for handling inlining with func operations.
23 struct FuncInlinerInterface : public DialectInlinerInterface {
24   using DialectInlinerInterface::DialectInlinerInterface;
25 
26   //===--------------------------------------------------------------------===//
27   // Analysis Hooks
28   //===--------------------------------------------------------------------===//
29 
30   /// Call operations can be inlined unless specified otherwise by attributes
31   /// on either the call or the callbale.
32   bool isLegalToInline(Operation *call, Operation *callable,
33                        bool wouldBeCloned) const final {
34     auto callOp = dyn_cast<func::CallOp>(call);
35     auto funcOp = dyn_cast<func::FuncOp>(callable);
36     return !(callOp && callOp.getNoInline()) &&
37            !(funcOp && funcOp.getNoInline());
38   }
39 
40   /// All operations can be inlined.
41   bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
42     return true;
43   }
44 
45   /// All function bodies can be inlined.
46   bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
47     return true;
48   }
49 
50   //===--------------------------------------------------------------------===//
51   // Transformation Hooks
52   //===--------------------------------------------------------------------===//
53 
54   /// Handle the given inlined terminator by replacing it with a new operation
55   /// as necessary.
56   void handleTerminator(Operation *op, Block *newDest) const final {
57     // Only return needs to be handled here.
58     auto returnOp = dyn_cast<ReturnOp>(op);
59     if (!returnOp)
60       return;
61 
62     // Replace the return with a branch to the dest.
63     OpBuilder builder(op);
64     builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
65     op->erase();
66   }
67 
68   /// Handle the given inlined terminator by replacing it with a new operation
69   /// as necessary.
70   void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
71     // Only return needs to be handled here.
72     auto returnOp = cast<ReturnOp>(op);
73 
74     // Replace the values directly with the return operands.
75     assert(returnOp.getNumOperands() == valuesToRepl.size());
76     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
77       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
78   }
79 };
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // Registration
84 //===----------------------------------------------------------------------===//
85 
86 void mlir::func::registerInlinerExtension(DialectRegistry &registry) {
87   registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
88     dialect->addInterfaces<FuncInlinerInterface>();
89 
90     // The inliner extension relies on the ControlFlow dialect.
91     ctx->getOrLoadDialect<cf::ControlFlowDialect>();
92   });
93 }
94