xref: /llvm-project/mlir/test/lib/Transforms/TestInlining.cpp (revision fd01d8626cdcce9f34caab060f8d3fd35f6661cc)
10ba00878SRiver Riddle //===- TestInlining.cpp - Pass to inline calls in the test dialect --------===//
20ba00878SRiver Riddle //
356222a06SMehdi Amini // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ba00878SRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
80ba00878SRiver Riddle //
// TODO(riverriddle) This pass is only necessary because the main inlining pass
// has not abstracted away the call+callee relationship. When the inlining
// interface has this support, this pass should be removed.
120ba00878SRiver Riddle //
130ba00878SRiver Riddle //===----------------------------------------------------------------------===//
140ba00878SRiver Riddle 
150ba00878SRiver Riddle #include "TestDialect.h"
160ba00878SRiver Riddle #include "mlir/Dialect/StandardOps/Ops.h"
170ba00878SRiver Riddle #include "mlir/IR/Function.h"
180ba00878SRiver Riddle #include "mlir/Pass/Pass.h"
190ba00878SRiver Riddle #include "mlir/Transforms/InliningUtils.h"
200ba00878SRiver Riddle #include "mlir/Transforms/Passes.h"
210ba00878SRiver Riddle #include "llvm/ADT/StringSet.h"
220ba00878SRiver Riddle 
230ba00878SRiver Riddle using namespace mlir;
240ba00878SRiver Riddle 
250ba00878SRiver Riddle namespace {
260ba00878SRiver Riddle struct Inliner : public FunctionPass<Inliner> {
270ba00878SRiver Riddle   void runOnFunction() override {
280ba00878SRiver Riddle     auto function = getFunction();
290ba00878SRiver Riddle 
300ba00878SRiver Riddle     // Collect each of the direct function calls within the module.
310ba00878SRiver Riddle     SmallVector<CallIndirectOp, 16> callers;
320ba00878SRiver Riddle     function.walk([&](CallIndirectOp caller) { callers.push_back(caller); });
330ba00878SRiver Riddle 
340ba00878SRiver Riddle     // Build the inliner interface.
350ba00878SRiver Riddle     InlinerInterface interface(&getContext());
360ba00878SRiver Riddle 
370ba00878SRiver Riddle     // Try to inline each of the call operations.
380ba00878SRiver Riddle     for (auto caller : callers) {
390ba00878SRiver Riddle       auto callee = dyn_cast_or_null<FunctionalRegionOp>(
400ba00878SRiver Riddle           caller.getCallee()->getDefiningOp());
410ba00878SRiver Riddle       if (!callee)
420ba00878SRiver Riddle         continue;
430ba00878SRiver Riddle 
440ba00878SRiver Riddle       // Inline the functional region operation, but only clone the internal
450ba00878SRiver Riddle       // region if there is more than one use.
460ba00878SRiver Riddle       if (failed(inlineRegion(
470ba00878SRiver Riddle               interface, &callee.body(), caller,
480ba00878SRiver Riddle               llvm::to_vector<8>(caller.getArgOperands()),
49*fd01d862SRiver Riddle               SmallVector<Value, 8>(caller.getResults()), caller.getLoc(),
500ba00878SRiver Riddle               /*shouldCloneInlinedRegion=*/!callee.getResult()->hasOneUse())))
510ba00878SRiver Riddle         continue;
520ba00878SRiver Riddle 
530ba00878SRiver Riddle       // If the inlining was successful then erase the call and callee if
540ba00878SRiver Riddle       // possible.
550ba00878SRiver Riddle       caller.erase();
560ba00878SRiver Riddle       if (callee.use_empty())
570ba00878SRiver Riddle         callee.erase();
580ba00878SRiver Riddle     }
590ba00878SRiver Riddle   }
600ba00878SRiver Riddle };
610ba00878SRiver Riddle } // end anonymous namespace
620ba00878SRiver Riddle 
630ba00878SRiver Riddle static PassRegistration<Inliner> pass("test-inline",
640ba00878SRiver Riddle                                       "Test inlining region calls");
65