xref: /llvm-project/llvm/unittests/Analysis/PluginInlineAdvisorAnalysisTest.cpp (revision c84a99dfd391eb4d89aff8d6453016045098b444)
107af0e2dSibricchi #include "llvm/Analysis/CallGraph.h"
207af0e2dSibricchi #include "llvm/AsmParser/Parser.h"
307af0e2dSibricchi #include "llvm/Config/config.h"
474deadf1SNikita Popov #include "llvm/IR/Module.h"
507af0e2dSibricchi #include "llvm/Passes/PassBuilder.h"
607af0e2dSibricchi #include "llvm/Passes/PassPlugin.h"
707af0e2dSibricchi #include "llvm/Support/CommandLine.h"
807af0e2dSibricchi #include "llvm/Support/raw_ostream.h"
907af0e2dSibricchi #include "llvm/Testing/Support/Error.h"
1007af0e2dSibricchi #include "gtest/gtest.h"
1107af0e2dSibricchi 
1207af0e2dSibricchi namespace llvm {
1307af0e2dSibricchi 
1465f7ebe7Sibricchi namespace {
1565f7ebe7Sibricchi 
// Empty function whose address libPath() passes to
// sys::fs::getMainExecutable() alongside argv[0].
void anchor() {}
1707af0e2dSibricchi 
1807af0e2dSibricchi static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
1907af0e2dSibricchi   const auto &Argvs = testing::internal::GetArgvs();
2007af0e2dSibricchi   const char *Argv0 =
2107af0e2dSibricchi       Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
2207af0e2dSibricchi   void *Ptr = (void *)(intptr_t)anchor;
2307af0e2dSibricchi   std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
2407af0e2dSibricchi   llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
2507af0e2dSibricchi   sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
2607af0e2dSibricchi   return std::string(Buf.str());
2707af0e2dSibricchi }
2807af0e2dSibricchi 
2907af0e2dSibricchi // Example of a custom InlineAdvisor that only inlines calls to functions called
3007af0e2dSibricchi // "foo".
3107af0e2dSibricchi class FooOnlyInlineAdvisor : public InlineAdvisor {
3207af0e2dSibricchi public:
3307af0e2dSibricchi   FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
3407af0e2dSibricchi                        InlineParams Params, InlineContext IC)
3507af0e2dSibricchi       : InlineAdvisor(M, FAM, IC) {}
3607af0e2dSibricchi 
3707af0e2dSibricchi   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
3807af0e2dSibricchi     if (CB.getCalledFunction()->getName() == "foo")
3907af0e2dSibricchi       return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
4007af0e2dSibricchi     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
4107af0e2dSibricchi   }
4207af0e2dSibricchi };
4307af0e2dSibricchi 
4407af0e2dSibricchi static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
4507af0e2dSibricchi                                      InlineParams Params, InlineContext IC) {
4607af0e2dSibricchi   return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
4707af0e2dSibricchi }
4807af0e2dSibricchi 
4907af0e2dSibricchi struct CompilerInstance {
5007af0e2dSibricchi   LLVMContext Ctx;
5107af0e2dSibricchi   ModulePassManager MPM;
5207af0e2dSibricchi   InlineParams IP;
5307af0e2dSibricchi 
5407af0e2dSibricchi   PassBuilder PB;
5507af0e2dSibricchi   LoopAnalysisManager LAM;
5607af0e2dSibricchi   FunctionAnalysisManager FAM;
5707af0e2dSibricchi   CGSCCAnalysisManager CGAM;
5807af0e2dSibricchi   ModuleAnalysisManager MAM;
5907af0e2dSibricchi 
6007af0e2dSibricchi   SMDiagnostic Error;
6107af0e2dSibricchi 
6207af0e2dSibricchi   // connect the plugin to our compiler instance
6307af0e2dSibricchi   void setupPlugin() {
6407af0e2dSibricchi     auto PluginPath = libPath();
6507af0e2dSibricchi     ASSERT_NE("", PluginPath);
6607af0e2dSibricchi     Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
6707af0e2dSibricchi     ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
6807af0e2dSibricchi     Plugin->registerPassBuilderCallbacks(PB);
6907af0e2dSibricchi   }
7007af0e2dSibricchi 
7107af0e2dSibricchi   // connect the FooOnlyInlineAdvisor to our compiler instance
7207af0e2dSibricchi   void setupFooOnly() {
7307af0e2dSibricchi     MAM.registerPass(
7407af0e2dSibricchi         [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
7507af0e2dSibricchi   }
7607af0e2dSibricchi 
7707af0e2dSibricchi   CompilerInstance() {
7807af0e2dSibricchi     IP = getInlineParams(3, 0);
7907af0e2dSibricchi     PB.registerModuleAnalyses(MAM);
8007af0e2dSibricchi     PB.registerCGSCCAnalyses(CGAM);
8107af0e2dSibricchi     PB.registerFunctionAnalyses(FAM);
8207af0e2dSibricchi     PB.registerLoopAnalyses(LAM);
8307af0e2dSibricchi     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
8407af0e2dSibricchi     MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
8507af0e2dSibricchi                                   ThinOrFullLTOPhase::None));
8607af0e2dSibricchi   }
8707af0e2dSibricchi 
8807af0e2dSibricchi   std::string output;
8907af0e2dSibricchi   std::unique_ptr<Module> outputM;
9007af0e2dSibricchi 
91*ab4253f6SMichele Scandale   auto run(StringRef IR) {
9207af0e2dSibricchi     outputM = parseAssemblyString(IR, Error, Ctx);
9307af0e2dSibricchi     MPM.run(*outputM, MAM);
9407af0e2dSibricchi     ASSERT_TRUE(outputM);
9507af0e2dSibricchi     output.clear();
9607af0e2dSibricchi     raw_string_ostream o_stream{output};
9707af0e2dSibricchi     outputM->print(o_stream, nullptr);
9807af0e2dSibricchi     ASSERT_TRUE(true);
9907af0e2dSibricchi   }
10007af0e2dSibricchi };
10107af0e2dSibricchi 
// IR modules that the tests in this file feed, one at a time, to
// CompilerInstance::run().
StringRef TestIRS[] = {
    // Simple 3 function inline case: f1 calls foo, foo calls f3.
    R"(
define void @f1() {
  call void @foo()
  ret void
}
define void @foo() {
  call void @f3()
  ret void
}
define void @f3() {
  ret void
}
  )",
    // Test that has 5 functions of which 2 (foo and f5) are mutually
    // recursive.
    R"(
define void @f1() {
  call void @foo()
  ret void
}
define void @f2() {
  call void @foo()
  ret void
}
define void @foo() {
  call void @f4()
  call void @f5()
  ret void
}
define void @f4() {
  ret void
}
define void @f5() {
  call void @foo()
  ret void
}
  )",
    // Test with 3 mutually recursive functions (f1 -> f2 -> f3 -> f1) and 1
    // function (f4) that calls f5 inside an infinite loop.
    R"(
define void @f1() {
  call void @f2()
  ret void
}
define void @f2() {
  call void @f3()
  ret void
}
define void @f3() {
  call void @f1()
  ret void
}
define void @f4() {
  br label %loop
loop:
  call void @f5()
  br label %loop
}
define void @f5() {
  ret void
}
  )",
    // Test that has a function that computes fibonacci in a loop, one in a
    // recursive manner, and one that calls both and compares them.
    R"(
define i32 @fib_loop(i32 %n){
    %curr = alloca i32
    %last = alloca i32
    %i = alloca i32
    store i32 1, i32* %curr
    store i32 1, i32* %last
    store i32 2, i32* %i
    br label %loop_cond
  loop_cond:
    %i_val = load i32, i32* %i
    %cmp = icmp slt i32 %i_val, %n
    br i1 %cmp, label %loop_body, label %loop_end
  loop_body:
    %curr_val = load i32, i32* %curr
    %last_val = load i32, i32* %last
    %add = add i32 %curr_val, %last_val
    store i32 %add, i32* %last
    store i32 %curr_val, i32* %curr
    %i_val2 = load i32, i32* %i
    %add2 = add i32 %i_val2, 1
    store i32 %add2, i32* %i
    br label %loop_cond
  loop_end:
    %curr_val3 = load i32, i32* %curr
    ret i32 %curr_val3
}

define i32 @fib_rec(i32 %n){
    %cmp = icmp eq i32 %n, 0
    %cmp2 = icmp eq i32 %n, 1
    %or = or i1 %cmp, %cmp2
    br i1 %or, label %if_true, label %if_false
  if_true:
    ret i32 1
  if_false:
    %sub = sub i32 %n, 1
    %call = call i32 @fib_rec(i32 %sub)
    %sub2 = sub i32 %n, 2
    %call2 = call i32 @fib_rec(i32 %sub2)
    %add = add i32 %call, %call2
    ret i32 %add
}

define i32 @fib_check(){
    %correct = alloca i32
    %i = alloca i32
    store i32 1, i32* %correct
    store i32 0, i32* %i
    br label %loop_cond
  loop_cond:
    %i_val = load i32, i32* %i
    %cmp = icmp slt i32 %i_val, 10
    br i1 %cmp, label %loop_body, label %loop_end
  loop_body:
    %i_val2 = load i32, i32* %i
    %call = call i32 @fib_loop(i32 %i_val2)
    %i_val3 = load i32, i32* %i
    %call2 = call i32 @fib_rec(i32 %i_val3)
    %cmp2 = icmp ne i32 %call, %call2
    br i1 %cmp2, label %if_true, label %if_false
  if_true:
    store i32 0, i32* %correct
    br label %if_end
  if_false:
    br label %if_end
  if_end:
    %i_val4 = load i32, i32* %i
    %add = add i32 %i_val4, 1
    store i32 %add, i32* %i
    br label %loop_cond
  loop_end:
    %correct_val = load i32, i32* %correct
    ret i32 %correct_val
}
  )"};
24207af0e2dSibricchi 
24365f7ebe7Sibricchi } // namespace
24465f7ebe7Sibricchi 
24507af0e2dSibricchi // check that loading a plugin works
24607af0e2dSibricchi // the plugin being loaded acts identically to the default inliner
24707af0e2dSibricchi TEST(PluginInlineAdvisorTest, PluginLoad) {
24807af0e2dSibricchi #if !defined(LLVM_ENABLE_PLUGINS)
2497fc87159SPaul Robinson   // Skip the test if plugins are disabled.
2507fc87159SPaul Robinson   GTEST_SKIP();
25107af0e2dSibricchi #endif
252*ab4253f6SMichele Scandale   CompilerInstance DefaultCI{};
253*ab4253f6SMichele Scandale 
254*ab4253f6SMichele Scandale   CompilerInstance PluginCI{};
255*ab4253f6SMichele Scandale   PluginCI.setupPlugin();
25607af0e2dSibricchi 
25707af0e2dSibricchi   for (StringRef IR : TestIRS) {
258*ab4253f6SMichele Scandale     DefaultCI.run(IR);
259*ab4253f6SMichele Scandale     std::string default_output = DefaultCI.output;
260*ab4253f6SMichele Scandale     PluginCI.run(IR);
261*ab4253f6SMichele Scandale     std::string dynamic_output = PluginCI.output;
26207af0e2dSibricchi     ASSERT_EQ(default_output, dynamic_output);
26307af0e2dSibricchi   }
26407af0e2dSibricchi }
26507af0e2dSibricchi 
26607af0e2dSibricchi // check that the behaviour of a custom inliner is correct
26707af0e2dSibricchi // the custom inliner inlines all functions that are not named "foo"
26807af0e2dSibricchi // this testdoes not require plugins to be enabled
26907af0e2dSibricchi TEST(PluginInlineAdvisorTest, CustomAdvisor) {
27007af0e2dSibricchi   CompilerInstance CI{};
27107af0e2dSibricchi   CI.setupFooOnly();
27207af0e2dSibricchi 
27307af0e2dSibricchi   for (StringRef IR : TestIRS) {
274*ab4253f6SMichele Scandale     CI.run(IR);
27507af0e2dSibricchi     CallGraph CGraph = CallGraph(*CI.outputM);
27607af0e2dSibricchi     for (auto &node : CGraph) {
27707af0e2dSibricchi       for (auto &edge : *node.second) {
27807af0e2dSibricchi         if (!edge.first)
27907af0e2dSibricchi           continue;
28007af0e2dSibricchi         ASSERT_NE(edge.second->getFunction()->getName(), "foo");
28107af0e2dSibricchi       }
28207af0e2dSibricchi     }
28307af0e2dSibricchi   }
28407af0e2dSibricchi }
28507af0e2dSibricchi 
28607af0e2dSibricchi } // namespace llvm
287