107af0e2dSibricchi #include "llvm/Analysis/CallGraph.h" 207af0e2dSibricchi #include "llvm/AsmParser/Parser.h" 307af0e2dSibricchi #include "llvm/Config/config.h" 474deadf1SNikita Popov #include "llvm/IR/Module.h" 507af0e2dSibricchi #include "llvm/Passes/PassBuilder.h" 607af0e2dSibricchi #include "llvm/Passes/PassPlugin.h" 707af0e2dSibricchi #include "llvm/Support/CommandLine.h" 807af0e2dSibricchi #include "llvm/Support/raw_ostream.h" 907af0e2dSibricchi #include "llvm/Testing/Support/Error.h" 1007af0e2dSibricchi #include "gtest/gtest.h" 1107af0e2dSibricchi 1207af0e2dSibricchi namespace llvm { 1307af0e2dSibricchi 1465f7ebe7Sibricchi namespace { 1565f7ebe7Sibricchi 1607af0e2dSibricchi void anchor() {} 1707af0e2dSibricchi 1807af0e2dSibricchi static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 1907af0e2dSibricchi const auto &Argvs = testing::internal::GetArgvs(); 2007af0e2dSibricchi const char *Argv0 = 2107af0e2dSibricchi Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 2207af0e2dSibricchi void *Ptr = (void *)(intptr_t)anchor; 2307af0e2dSibricchi std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 2407af0e2dSibricchi llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 2507af0e2dSibricchi sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 2607af0e2dSibricchi return std::string(Buf.str()); 2707af0e2dSibricchi } 2807af0e2dSibricchi 2907af0e2dSibricchi // Example of a custom InlineAdvisor that only inlines calls to functions called 3007af0e2dSibricchi // "foo". 3107af0e2dSibricchi class FooOnlyInlineAdvisor : public InlineAdvisor { 3207af0e2dSibricchi public: 3307af0e2dSibricchi FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 3407af0e2dSibricchi InlineParams Params, InlineContext IC) 3507af0e2dSibricchi : InlineAdvisor(M, FAM, IC) {} 3607af0e2dSibricchi 3707af0e2dSibricchi std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 3807af0e2dSibricchi if (CB.getCalledFunction()->getName() == "foo") 3907af0e2dSibricchi return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 4007af0e2dSibricchi return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 4107af0e2dSibricchi } 4207af0e2dSibricchi }; 4307af0e2dSibricchi 4407af0e2dSibricchi static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 4507af0e2dSibricchi InlineParams Params, InlineContext IC) { 4607af0e2dSibricchi return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 4707af0e2dSibricchi } 4807af0e2dSibricchi 4907af0e2dSibricchi struct CompilerInstance { 5007af0e2dSibricchi LLVMContext Ctx; 5107af0e2dSibricchi ModulePassManager MPM; 5207af0e2dSibricchi InlineParams IP; 5307af0e2dSibricchi 5407af0e2dSibricchi PassBuilder PB; 5507af0e2dSibricchi LoopAnalysisManager LAM; 5607af0e2dSibricchi FunctionAnalysisManager FAM; 5707af0e2dSibricchi CGSCCAnalysisManager CGAM; 5807af0e2dSibricchi ModuleAnalysisManager MAM; 5907af0e2dSibricchi 6007af0e2dSibricchi SMDiagnostic Error; 6107af0e2dSibricchi 6207af0e2dSibricchi // connect the plugin to our compiler instance 6307af0e2dSibricchi void setupPlugin() { 6407af0e2dSibricchi auto PluginPath = libPath(); 6507af0e2dSibricchi ASSERT_NE("", PluginPath); 6607af0e2dSibricchi Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 6707af0e2dSibricchi ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 6807af0e2dSibricchi Plugin->registerPassBuilderCallbacks(PB); 6907af0e2dSibricchi } 7007af0e2dSibricchi 7107af0e2dSibricchi // connect the FooOnlyInlineAdvisor to our compiler instance 7207af0e2dSibricchi void setupFooOnly() { 7307af0e2dSibricchi MAM.registerPass( 7407af0e2dSibricchi [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 7507af0e2dSibricchi } 7607af0e2dSibricchi 7707af0e2dSibricchi CompilerInstance() { 7807af0e2dSibricchi IP = getInlineParams(3, 0); 7907af0e2dSibricchi PB.registerModuleAnalyses(MAM); 8007af0e2dSibricchi PB.registerCGSCCAnalyses(CGAM); 8107af0e2dSibricchi PB.registerFunctionAnalyses(FAM); 8207af0e2dSibricchi PB.registerLoopAnalyses(LAM); 8307af0e2dSibricchi PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 8407af0e2dSibricchi MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 8507af0e2dSibricchi ThinOrFullLTOPhase::None)); 8607af0e2dSibricchi } 8707af0e2dSibricchi 8807af0e2dSibricchi std::string output; 8907af0e2dSibricchi std::unique_ptr<Module> outputM; 9007af0e2dSibricchi 91*ab4253f6SMichele Scandale auto run(StringRef IR) { 9207af0e2dSibricchi outputM = parseAssemblyString(IR, Error, Ctx); 9307af0e2dSibricchi MPM.run(*outputM, MAM); 9407af0e2dSibricchi ASSERT_TRUE(outputM); 9507af0e2dSibricchi output.clear(); 9607af0e2dSibricchi raw_string_ostream o_stream{output}; 9707af0e2dSibricchi outputM->print(o_stream, nullptr); 9807af0e2dSibricchi ASSERT_TRUE(true); 9907af0e2dSibricchi } 10007af0e2dSibricchi }; 10107af0e2dSibricchi 10207af0e2dSibricchi StringRef TestIRS[] = { 10307af0e2dSibricchi // Simple 3 function inline case 10407af0e2dSibricchi R"( 10507af0e2dSibricchi define void @f1() { 10607af0e2dSibricchi call void @foo() 10707af0e2dSibricchi ret void 10807af0e2dSibricchi } 10907af0e2dSibricchi define void @foo() { 11007af0e2dSibricchi call void @f3() 11107af0e2dSibricchi ret void 11207af0e2dSibricchi } 11307af0e2dSibricchi define void @f3() { 11407af0e2dSibricchi ret void 11507af0e2dSibricchi } 11607af0e2dSibricchi )", 11707af0e2dSibricchi // Test that has 5 functions of which 2 are recursive 11807af0e2dSibricchi R"( 11907af0e2dSibricchi define void @f1() { 12007af0e2dSibricchi call void @foo() 12107af0e2dSibricchi ret void 12207af0e2dSibricchi } 12307af0e2dSibricchi define void @f2() { 12407af0e2dSibricchi call void @foo() 12507af0e2dSibricchi ret void 12607af0e2dSibricchi } 12707af0e2dSibricchi define void @foo() { 12807af0e2dSibricchi call void @f4() 12907af0e2dSibricchi call void @f5() 13007af0e2dSibricchi ret void 13107af0e2dSibricchi } 13207af0e2dSibricchi define void @f4() { 13307af0e2dSibricchi ret void 13407af0e2dSibricchi } 13507af0e2dSibricchi define void @f5() { 13607af0e2dSibricchi call void @foo() 13707af0e2dSibricchi ret void 13807af0e2dSibricchi } 13907af0e2dSibricchi )", 14007af0e2dSibricchi // test with 2 mutually recursive functions and 1 function with a loop 14107af0e2dSibricchi R"( 14207af0e2dSibricchi define void @f1() { 14307af0e2dSibricchi call void @f2() 14407af0e2dSibricchi ret void 14507af0e2dSibricchi } 14607af0e2dSibricchi define void @f2() { 14707af0e2dSibricchi call void @f3() 14807af0e2dSibricchi ret void 14907af0e2dSibricchi } 15007af0e2dSibricchi define void @f3() { 15107af0e2dSibricchi call void @f1() 15207af0e2dSibricchi ret void 15307af0e2dSibricchi } 15407af0e2dSibricchi define void @f4() { 15507af0e2dSibricchi br label %loop 15607af0e2dSibricchi loop: 15707af0e2dSibricchi call void @f5() 15807af0e2dSibricchi br label %loop 15907af0e2dSibricchi } 16007af0e2dSibricchi define void @f5() { 16107af0e2dSibricchi ret void 16207af0e2dSibricchi } 16307af0e2dSibricchi )", 16407af0e2dSibricchi // test that has a function that computes fibonacci in a loop, one in a 16507af0e2dSibricchi // recurisve manner, and one that calls both and compares them 16607af0e2dSibricchi R"( 16707af0e2dSibricchi define i32 @fib_loop(i32 %n){ 16807af0e2dSibricchi %curr = alloca i32 16907af0e2dSibricchi %last = alloca i32 17007af0e2dSibricchi %i = alloca i32 17107af0e2dSibricchi store i32 1, i32* %curr 17207af0e2dSibricchi store i32 1, i32* %last 17307af0e2dSibricchi store i32 2, i32* %i 17407af0e2dSibricchi br label %loop_cond 17507af0e2dSibricchi loop_cond: 17607af0e2dSibricchi %i_val = load i32, i32* %i 17707af0e2dSibricchi %cmp = icmp slt i32 %i_val, %n 17807af0e2dSibricchi br i1 %cmp, label %loop_body, label %loop_end 17907af0e2dSibricchi loop_body: 18007af0e2dSibricchi %curr_val = load i32, i32* %curr 18107af0e2dSibricchi %last_val = load i32, i32* %last 18207af0e2dSibricchi %add = add i32 %curr_val, %last_val 18307af0e2dSibricchi store i32 %add, i32* %last 18407af0e2dSibricchi store i32 %curr_val, i32* %curr 18507af0e2dSibricchi %i_val2 = load i32, i32* %i 18607af0e2dSibricchi %add2 = add i32 %i_val2, 1 18707af0e2dSibricchi store i32 %add2, i32* %i 18807af0e2dSibricchi br label %loop_cond 18907af0e2dSibricchi loop_end: 19007af0e2dSibricchi %curr_val3 = load i32, i32* %curr 19107af0e2dSibricchi ret i32 %curr_val3 19207af0e2dSibricchi } 19307af0e2dSibricchi 19407af0e2dSibricchi define i32 @fib_rec(i32 %n){ 19507af0e2dSibricchi %cmp = icmp eq i32 %n, 0 19607af0e2dSibricchi %cmp2 = icmp eq i32 %n, 1 19707af0e2dSibricchi %or = or i1 %cmp, %cmp2 19807af0e2dSibricchi br i1 %or, label %if_true, label %if_false 19907af0e2dSibricchi if_true: 20007af0e2dSibricchi ret i32 1 20107af0e2dSibricchi if_false: 20207af0e2dSibricchi %sub = sub i32 %n, 1 20307af0e2dSibricchi %call = call i32 @fib_rec(i32 %sub) 20407af0e2dSibricchi %sub2 = sub i32 %n, 2 20507af0e2dSibricchi %call2 = call i32 @fib_rec(i32 %sub2) 20607af0e2dSibricchi %add = add i32 %call, %call2 20707af0e2dSibricchi ret i32 %add 20807af0e2dSibricchi } 20907af0e2dSibricchi 21007af0e2dSibricchi define i32 @fib_check(){ 21107af0e2dSibricchi %correct = alloca i32 21207af0e2dSibricchi %i = alloca i32 21307af0e2dSibricchi store i32 1, i32* %correct 21407af0e2dSibricchi store i32 0, i32* %i 21507af0e2dSibricchi br label %loop_cond 21607af0e2dSibricchi loop_cond: 21707af0e2dSibricchi %i_val = load i32, i32* %i 21807af0e2dSibricchi %cmp = icmp slt i32 %i_val, 10 21907af0e2dSibricchi br i1 %cmp, label %loop_body, label %loop_end 22007af0e2dSibricchi loop_body: 22107af0e2dSibricchi %i_val2 = load i32, i32* %i 22207af0e2dSibricchi %call = call i32 @fib_loop(i32 %i_val2) 22307af0e2dSibricchi %i_val3 = load i32, i32* %i 22407af0e2dSibricchi %call2 = call i32 @fib_rec(i32 %i_val3) 22507af0e2dSibricchi %cmp2 = icmp ne i32 %call, %call2 22607af0e2dSibricchi br i1 %cmp2, label %if_true, label %if_false 22707af0e2dSibricchi if_true: 22807af0e2dSibricchi store i32 0, i32* %correct 22907af0e2dSibricchi br label %if_end 23007af0e2dSibricchi if_false: 23107af0e2dSibricchi br label %if_end 23207af0e2dSibricchi if_end: 23307af0e2dSibricchi %i_val4 = load i32, i32* %i 23407af0e2dSibricchi %add = add i32 %i_val4, 1 23507af0e2dSibricchi store i32 %add, i32* %i 23607af0e2dSibricchi br label %loop_cond 23707af0e2dSibricchi loop_end: 23807af0e2dSibricchi %correct_val = load i32, i32* %correct 23907af0e2dSibricchi ret i32 %correct_val 24007af0e2dSibricchi } 24107af0e2dSibricchi )"}; 24207af0e2dSibricchi 24365f7ebe7Sibricchi } // namespace 24465f7ebe7Sibricchi 24507af0e2dSibricchi // check that loading a plugin works 24607af0e2dSibricchi // the plugin being loaded acts identically to the default inliner 24707af0e2dSibricchi TEST(PluginInlineAdvisorTest, PluginLoad) { 24807af0e2dSibricchi #if !defined(LLVM_ENABLE_PLUGINS) 2497fc87159SPaul Robinson // Skip the test if plugins are disabled. 2507fc87159SPaul Robinson GTEST_SKIP(); 25107af0e2dSibricchi #endif 252*ab4253f6SMichele Scandale CompilerInstance DefaultCI{}; 253*ab4253f6SMichele Scandale 254*ab4253f6SMichele Scandale CompilerInstance PluginCI{}; 255*ab4253f6SMichele Scandale PluginCI.setupPlugin(); 25607af0e2dSibricchi 25707af0e2dSibricchi for (StringRef IR : TestIRS) { 258*ab4253f6SMichele Scandale DefaultCI.run(IR); 259*ab4253f6SMichele Scandale std::string default_output = DefaultCI.output; 260*ab4253f6SMichele Scandale PluginCI.run(IR); 261*ab4253f6SMichele Scandale std::string dynamic_output = PluginCI.output; 26207af0e2dSibricchi ASSERT_EQ(default_output, dynamic_output); 26307af0e2dSibricchi } 26407af0e2dSibricchi } 26507af0e2dSibricchi 26607af0e2dSibricchi // check that the behaviour of a custom inliner is correct 26707af0e2dSibricchi // the custom inliner inlines all functions that are not named "foo" 26807af0e2dSibricchi // this testdoes not require plugins to be enabled 26907af0e2dSibricchi TEST(PluginInlineAdvisorTest, CustomAdvisor) { 27007af0e2dSibricchi CompilerInstance CI{}; 27107af0e2dSibricchi CI.setupFooOnly(); 27207af0e2dSibricchi 27307af0e2dSibricchi for (StringRef IR : TestIRS) { 274*ab4253f6SMichele Scandale CI.run(IR); 27507af0e2dSibricchi CallGraph CGraph = CallGraph(*CI.outputM); 27607af0e2dSibricchi for (auto &node : CGraph) { 27707af0e2dSibricchi for (auto &edge : *node.second) { 27807af0e2dSibricchi if (!edge.first) 27907af0e2dSibricchi continue; 28007af0e2dSibricchi ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 28107af0e2dSibricchi } 28207af0e2dSibricchi } 28307af0e2dSibricchi } 28407af0e2dSibricchi } 28507af0e2dSibricchi 28607af0e2dSibricchi } // namespace llvm 287