1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/Passes/PassBuilder.h" 5 #include "llvm/Passes/PassPlugin.h" 6 #include "llvm/Support/CommandLine.h" 7 #include "llvm/Support/raw_ostream.h" 8 #include "llvm/Testing/Support/Error.h" 9 #include "gtest/gtest.h" 10 11 #include "llvm/Analysis/InlineOrder.h" 12 13 namespace llvm { 14 15 namespace { 16 17 void anchor() {} 18 19 std::string libPath(const std::string Name = "InlineOrderPlugin") { 20 const auto &Argvs = testing::internal::GetArgvs(); 21 const char *Argv0 = 22 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineOrderAnalysisTest"; 23 void *Ptr = (void *)(intptr_t)anchor; 24 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 25 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 26 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 27 return std::string(Buf.str()); 28 } 29 30 struct CompilerInstance { 31 LLVMContext Ctx; 32 ModulePassManager MPM; 33 InlineParams IP; 34 35 PassBuilder PB; 36 LoopAnalysisManager LAM; 37 FunctionAnalysisManager FAM; 38 CGSCCAnalysisManager CGAM; 39 ModuleAnalysisManager MAM; 40 41 SMDiagnostic Error; 42 43 // Connect the plugin to our compiler instance. 44 void setupPlugin() { 45 auto PluginPath = libPath(); 46 ASSERT_NE("", PluginPath); 47 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 48 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 49 Plugin->registerPassBuilderCallbacks(PB); 50 } 51 52 CompilerInstance() { 53 IP = getInlineParams(3, 0); 54 PB.registerModuleAnalyses(MAM); 55 PB.registerCGSCCAnalyses(CGAM); 56 PB.registerFunctionAnalyses(FAM); 57 PB.registerLoopAnalyses(LAM); 58 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 59 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 60 ThinOrFullLTOPhase::None)); 61 } 62 63 ~CompilerInstance() { 64 // Reset the static variable that tracks if the plugin has been registered. 65 // This is needed to allow the test to run multiple times. 66 PluginInlineOrderAnalysis::unregister(); 67 } 68 69 std::string Output; 70 std::unique_ptr<Module> OutputM; 71 72 // Run with the dynamic inline order. 73 auto run(StringRef IR) { 74 OutputM = parseAssemblyString(IR, Error, Ctx); 75 MPM.run(*OutputM, MAM); 76 ASSERT_TRUE(OutputM); 77 Output.clear(); 78 raw_string_ostream OStream{Output}; 79 OutputM->print(OStream, nullptr); 80 ASSERT_TRUE(true); 81 } 82 }; 83 84 StringRef TestIRS[] = { 85 // Simple 3 function inline case. 86 R"( 87 define void @f1() { 88 call void @foo() 89 ret void 90 } 91 define void @foo() { 92 call void @f3() 93 ret void 94 } 95 define void @f3() { 96 ret void 97 } 98 )", 99 // Test that has 5 functions of which 2 are recursive. 100 R"( 101 define void @f1() { 102 call void @foo() 103 ret void 104 } 105 define void @f2() { 106 call void @foo() 107 ret void 108 } 109 define void @foo() { 110 call void @f4() 111 call void @f5() 112 ret void 113 } 114 define void @f4() { 115 ret void 116 } 117 define void @f5() { 118 call void @foo() 119 ret void 120 } 121 )", 122 // Test with 2 mutually recursive functions and 1 function with a loop. 123 R"( 124 define void @f1() { 125 call void @f2() 126 ret void 127 } 128 define void @f2() { 129 call void @foo() 130 ret void 131 } 132 define void @foo() { 133 call void @f1() 134 ret void 135 } 136 define void @f4() { 137 br label %loop 138 loop: 139 call void @f5() 140 br label %loop 141 } 142 define void @f5() { 143 ret void 144 } 145 )", 146 // Test that has a function that computes fibonacci in a loop, one in a 147 // recursive manner, and one that calls both and compares them. 148 R"( 149 define i32 @fib_loop(i32 %n){ 150 %curr = alloca i32 151 %last = alloca i32 152 %i = alloca i32 153 store i32 1, i32* %curr 154 store i32 1, i32* %last 155 store i32 2, i32* %i 156 br label %loop_cond 157 loop_cond: 158 %i_val = load i32, i32* %i 159 %cmp = icmp slt i32 %i_val, %n 160 br i1 %cmp, label %loop_body, label %loop_end 161 loop_body: 162 %curr_val = load i32, i32* %curr 163 %last_val = load i32, i32* %last 164 %add = add i32 %curr_val, %last_val 165 store i32 %add, i32* %last 166 store i32 %curr_val, i32* %curr 167 %i_val2 = load i32, i32* %i 168 %add2 = add i32 %i_val2, 1 169 store i32 %add2, i32* %i 170 br label %loop_cond 171 loop_end: 172 %curr_val3 = load i32, i32* %curr 173 ret i32 %curr_val3 174 } 175 176 define i32 @foo(i32 %n){ 177 %cmp = icmp eq i32 %n, 0 178 %cmp2 = icmp eq i32 %n, 1 179 %or = or i1 %cmp, %cmp2 180 br i1 %or, label %if_true, label %if_false 181 if_true: 182 ret i32 1 183 if_false: 184 %sub = sub i32 %n, 1 185 %call = call i32 @foo(i32 %sub) 186 %sub2 = sub i32 %n, 2 187 %call2 = call i32 @foo(i32 %sub2) 188 %add = add i32 %call, %call2 189 ret i32 %add 190 } 191 192 define i32 @fib_check(){ 193 %correct = alloca i32 194 %i = alloca i32 195 store i32 1, i32* %correct 196 store i32 0, i32* %i 197 br label %loop_cond 198 loop_cond: 199 %i_val = load i32, i32* %i 200 %cmp = icmp slt i32 %i_val, 10 201 br i1 %cmp, label %loop_body, label %loop_end 202 loop_body: 203 %i_val2 = load i32, i32* %i 204 %call = call i32 @fib_loop(i32 %i_val2) 205 %i_val3 = load i32, i32* %i 206 %call2 = call i32 @foo(i32 %i_val3) 207 %cmp2 = icmp ne i32 %call, %call2 208 br i1 %cmp2, label %if_true, label %if_false 209 if_true: 210 store i32 0, i32* %correct 211 br label %if_end 212 if_false: 213 br label %if_end 214 if_end: 215 %i_val4 = load i32, i32* %i 216 %add = add i32 %i_val4, 1 217 store i32 %add, i32* %i 218 br label %loop_cond 219 loop_end: 220 %correct_val = load i32, i32* %correct 221 ret i32 %correct_val 222 } 223 )"}; 224 225 } // namespace 226 227 // Check that the behaviour of a custom inline order is correct. 228 // The custom order drops any functions named "foo" so all tests 229 // should contain at least one function named foo. 230 TEST(PluginInlineOrderTest, NoInlineFoo) { 231 #if !defined(LLVM_ENABLE_PLUGINS) 232 // Skip the test if plugins are disabled. 233 GTEST_SKIP(); 234 #endif 235 CompilerInstance CI{}; 236 CI.setupPlugin(); 237 238 for (StringRef IR : TestIRS) { 239 bool FoundFoo = false; 240 CI.run(IR); 241 CallGraph CGraph = CallGraph(*CI.OutputM); 242 for (auto &Node : CGraph) { 243 for (auto &Edge : *Node.second) { 244 FoundFoo |= Edge.second->getFunction()->getName() == "foo"; 245 } 246 } 247 ASSERT_TRUE(FoundFoo); 248 } 249 } 250 251 } // namespace llvm 252