1 //===- InlinerExtension.cpp - Func Inliner Extension ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" 10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 11 #include "mlir/Dialect/Func/IR/FuncOps.h" 12 #include "mlir/IR/DialectInterface.h" 13 #include "mlir/Transforms/InliningUtils.h" 14 15 using namespace mlir; 16 using namespace mlir::func; 17 18 //===----------------------------------------------------------------------===// 19 // FuncDialect Interfaces 20 //===----------------------------------------------------------------------===// 21 namespace { 22 /// This class defines the interface for handling inlining with func operations. 23 struct FuncInlinerInterface : public DialectInlinerInterface { 24 using DialectInlinerInterface::DialectInlinerInterface; 25 26 //===--------------------------------------------------------------------===// 27 // Analysis Hooks 28 //===--------------------------------------------------------------------===// 29 30 /// Call operations can be inlined unless specified otherwise by attributes 31 /// on either the call or the callbale. 32 bool isLegalToInline(Operation *call, Operation *callable, 33 bool wouldBeCloned) const final { 34 auto callOp = dyn_cast<func::CallOp>(call); 35 auto funcOp = dyn_cast<func::FuncOp>(callable); 36 return !(callOp && callOp.getNoInline()) && 37 !(funcOp && funcOp.getNoInline()); 38 } 39 40 /// All operations can be inlined. 41 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { 42 return true; 43 } 44 45 /// All function bodies can be inlined. 46 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { 47 return true; 48 } 49 50 //===--------------------------------------------------------------------===// 51 // Transformation Hooks 52 //===--------------------------------------------------------------------===// 53 54 /// Handle the given inlined terminator by replacing it with a new operation 55 /// as necessary. 56 void handleTerminator(Operation *op, Block *newDest) const final { 57 // Only return needs to be handled here. 58 auto returnOp = dyn_cast<ReturnOp>(op); 59 if (!returnOp) 60 return; 61 62 // Replace the return with a branch to the dest. 63 OpBuilder builder(op); 64 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); 65 op->erase(); 66 } 67 68 /// Handle the given inlined terminator by replacing it with a new operation 69 /// as necessary. 70 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { 71 // Only return needs to be handled here. 72 auto returnOp = cast<ReturnOp>(op); 73 74 // Replace the values directly with the return operands. 75 assert(returnOp.getNumOperands() == valuesToRepl.size()); 76 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 77 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 78 } 79 }; 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // Registration 84 //===----------------------------------------------------------------------===// 85 86 void mlir::func::registerInlinerExtension(DialectRegistry ®istry) { 87 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { 88 dialect->addInterfaces<FuncInlinerInterface>(); 89 90 // The inliner extension relies on the ControlFlow dialect. 91 ctx->getOrLoadDialect<cf::ControlFlowDialect>(); 92 }); 93 } 94