xref: /llvm-project/mlir/test/lib/Transforms/TestInlining.cpp (revision 69d757c0e8ffc5b49fda10df38e470a56d616ef4)
10ba00878SRiver Riddle //===- TestInlining.cpp - Pass to inline calls in the test dialect --------===//
20ba00878SRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80ba00878SRiver Riddle //
90ba00878SRiver Riddle // TODO(riverriddle) This pass is only necessary because the main inlining pass
100ba00878SRiver Riddle // has not abstracted away the call+callee relationship. When the inlining
110ba00878SRiver Riddle // interface has this support, this pass should be removed.
120ba00878SRiver Riddle //
130ba00878SRiver Riddle //===----------------------------------------------------------------------===//
140ba00878SRiver Riddle 
150ba00878SRiver Riddle #include "TestDialect.h"
16*69d757c0SRob Suderman #include "mlir/Dialect/StandardOps/IR/Ops.h"
170ba00878SRiver Riddle #include "mlir/IR/Function.h"
180ba00878SRiver Riddle #include "mlir/Pass/Pass.h"
190ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
200ba00878SRiver Riddle #include "mlir/Transforms/Passes.h"
210ba00878SRiver Riddle #include "llvm/ADT/StringSet.h"
220ba00878SRiver Riddle 
230ba00878SRiver Riddle using namespace mlir;
240ba00878SRiver Riddle 
250ba00878SRiver Riddle namespace {
260ba00878SRiver Riddle struct Inliner : public FunctionPass<Inliner> {
270ba00878SRiver Riddle   void runOnFunction() override {
280ba00878SRiver Riddle     auto function = getFunction();
290ba00878SRiver Riddle 
300ba00878SRiver Riddle     // Collect each of the direct function calls within the module.
310ba00878SRiver Riddle     SmallVector<CallIndirectOp, 16> callers;
320ba00878SRiver Riddle     function.walk([&](CallIndirectOp caller) { callers.push_back(caller); });
330ba00878SRiver Riddle 
340ba00878SRiver Riddle     // Build the inliner interface.
350ba00878SRiver Riddle     InlinerInterface interface(&getContext());
360ba00878SRiver Riddle 
370ba00878SRiver Riddle     // Try to inline each of the call operations.
380ba00878SRiver Riddle     for (auto caller : callers) {
390ba00878SRiver Riddle       auto callee = dyn_cast_or_null<FunctionalRegionOp>(
402bdf33ccSRiver Riddle           caller.getCallee().getDefiningOp());
410ba00878SRiver Riddle       if (!callee)
420ba00878SRiver Riddle         continue;
430ba00878SRiver Riddle 
440ba00878SRiver Riddle       // Inline the functional region operation, but only clone the internal
450ba00878SRiver Riddle       // region if there is more than one use.
460ba00878SRiver Riddle       if (failed(inlineRegion(
470ba00878SRiver Riddle               interface, &callee.body(), caller,
480ba00878SRiver Riddle               llvm::to_vector<8>(caller.getArgOperands()),
49fd01d862SRiver Riddle               SmallVector<Value, 8>(caller.getResults()), caller.getLoc(),
502bdf33ccSRiver Riddle               /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
510ba00878SRiver Riddle         continue;
520ba00878SRiver Riddle 
530ba00878SRiver Riddle       // If the inlining was successful then erase the call and callee if
540ba00878SRiver Riddle       // possible.
550ba00878SRiver Riddle       caller.erase();
560ba00878SRiver Riddle       if (callee.use_empty())
570ba00878SRiver Riddle         callee.erase();
580ba00878SRiver Riddle     }
590ba00878SRiver Riddle   }
600ba00878SRiver Riddle };
610ba00878SRiver Riddle } // end anonymous namespace
620ba00878SRiver Riddle 
63c6477050SMehdi Amini namespace mlir {
64c6477050SMehdi Amini void registerInliner() {
65c6477050SMehdi Amini   PassRegistration<Inliner>("test-inline", "Test inlining region calls");
66c6477050SMehdi Amini }
67c6477050SMehdi Amini } // namespace mlir
68