1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/IR/Module.h" 5 #include "llvm/Passes/PassBuilder.h" 6 #include "llvm/Passes/PassPlugin.h" 7 #include "llvm/Support/CommandLine.h" 8 #include "llvm/Support/raw_ostream.h" 9 #include "llvm/Testing/Support/Error.h" 10 #include "gtest/gtest.h" 11 12 #include "llvm/Analysis/InlineOrder.h" 13 14 namespace llvm { 15 16 namespace { 17 18 void anchor() {} 19 20 std::string libPath(const std::string Name = "InlineOrderPlugin") { 21 const auto &Argvs = testing::internal::GetArgvs(); 22 const char *Argv0 = 23 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineOrderAnalysisTest"; 24 void *Ptr = (void *)(intptr_t)anchor; 25 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 26 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 27 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 28 return std::string(Buf.str()); 29 } 30 31 struct CompilerInstance { 32 LLVMContext Ctx; 33 ModulePassManager MPM; 34 InlineParams IP; 35 36 PassBuilder PB; 37 LoopAnalysisManager LAM; 38 FunctionAnalysisManager FAM; 39 CGSCCAnalysisManager CGAM; 40 ModuleAnalysisManager MAM; 41 42 SMDiagnostic Error; 43 44 // Connect the plugin to our compiler instance. 45 void setupPlugin() { 46 auto PluginPath = libPath(); 47 ASSERT_NE("", PluginPath); 48 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 49 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 50 Plugin->registerPassBuilderCallbacks(PB); 51 } 52 53 CompilerInstance() { 54 IP = getInlineParams(3, 0); 55 PB.registerModuleAnalyses(MAM); 56 PB.registerCGSCCAnalyses(CGAM); 57 PB.registerFunctionAnalyses(FAM); 58 PB.registerLoopAnalyses(LAM); 59 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 60 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 61 ThinOrFullLTOPhase::None)); 62 } 63 64 std::string Output; 65 std::unique_ptr<Module> OutputM; 66 67 // Run with the dynamic inline order. 68 auto run(StringRef IR) { 69 OutputM = parseAssemblyString(IR, Error, Ctx); 70 MPM.run(*OutputM, MAM); 71 ASSERT_TRUE(OutputM); 72 Output.clear(); 73 raw_string_ostream OStream{Output}; 74 OutputM->print(OStream, nullptr); 75 ASSERT_TRUE(true); 76 } 77 }; 78 79 StringRef TestIRS[] = { 80 // Simple 3 function inline case. 81 R"( 82 define void @f1() { 83 call void @foo() 84 ret void 85 } 86 define void @foo() { 87 call void @f3() 88 ret void 89 } 90 define void @f3() { 91 ret void 92 } 93 )", 94 // Test that has 5 functions of which 2 are recursive. 95 R"( 96 define void @f1() { 97 call void @foo() 98 ret void 99 } 100 define void @f2() { 101 call void @foo() 102 ret void 103 } 104 define void @foo() { 105 call void @f4() 106 call void @f5() 107 ret void 108 } 109 define void @f4() { 110 ret void 111 } 112 define void @f5() { 113 call void @foo() 114 ret void 115 } 116 )", 117 // Test with 2 mutually recursive functions and 1 function with a loop. 118 R"( 119 define void @f1() { 120 call void @f2() 121 ret void 122 } 123 define void @f2() { 124 call void @foo() 125 ret void 126 } 127 define void @foo() { 128 call void @f1() 129 ret void 130 } 131 define void @f4() { 132 br label %loop 133 loop: 134 call void @f5() 135 br label %loop 136 } 137 define void @f5() { 138 ret void 139 } 140 )", 141 // Test that has a function that computes fibonacci in a loop, one in a 142 // recursive manner, and one that calls both and compares them. 143 R"( 144 define i32 @fib_loop(i32 %n){ 145 %curr = alloca i32 146 %last = alloca i32 147 %i = alloca i32 148 store i32 1, i32* %curr 149 store i32 1, i32* %last 150 store i32 2, i32* %i 151 br label %loop_cond 152 loop_cond: 153 %i_val = load i32, i32* %i 154 %cmp = icmp slt i32 %i_val, %n 155 br i1 %cmp, label %loop_body, label %loop_end 156 loop_body: 157 %curr_val = load i32, i32* %curr 158 %last_val = load i32, i32* %last 159 %add = add i32 %curr_val, %last_val 160 store i32 %add, i32* %last 161 store i32 %curr_val, i32* %curr 162 %i_val2 = load i32, i32* %i 163 %add2 = add i32 %i_val2, 1 164 store i32 %add2, i32* %i 165 br label %loop_cond 166 loop_end: 167 %curr_val3 = load i32, i32* %curr 168 ret i32 %curr_val3 169 } 170 171 define i32 @foo(i32 %n){ 172 %cmp = icmp eq i32 %n, 0 173 %cmp2 = icmp eq i32 %n, 1 174 %or = or i1 %cmp, %cmp2 175 br i1 %or, label %if_true, label %if_false 176 if_true: 177 ret i32 1 178 if_false: 179 %sub = sub i32 %n, 1 180 %call = call i32 @foo(i32 %sub) 181 %sub2 = sub i32 %n, 2 182 %call2 = call i32 @foo(i32 %sub2) 183 %add = add i32 %call, %call2 184 ret i32 %add 185 } 186 187 define i32 @fib_check(){ 188 %correct = alloca i32 189 %i = alloca i32 190 store i32 1, i32* %correct 191 store i32 0, i32* %i 192 br label %loop_cond 193 loop_cond: 194 %i_val = load i32, i32* %i 195 %cmp = icmp slt i32 %i_val, 10 196 br i1 %cmp, label %loop_body, label %loop_end 197 loop_body: 198 %i_val2 = load i32, i32* %i 199 %call = call i32 @fib_loop(i32 %i_val2) 200 %i_val3 = load i32, i32* %i 201 %call2 = call i32 @foo(i32 %i_val3) 202 %cmp2 = icmp ne i32 %call, %call2 203 br i1 %cmp2, label %if_true, label %if_false 204 if_true: 205 store i32 0, i32* %correct 206 br label %if_end 207 if_false: 208 br label %if_end 209 if_end: 210 %i_val4 = load i32, i32* %i 211 %add = add i32 %i_val4, 1 212 store i32 %add, i32* %i 213 br label %loop_cond 214 loop_end: 215 %correct_val = load i32, i32* %correct 216 ret i32 %correct_val 217 } 218 )"}; 219 220 } // namespace 221 222 // Check that the behaviour of a custom inline order is correct. 223 // The custom order drops any functions named "foo" so all tests 224 // should contain at least one function named foo. 225 TEST(PluginInlineOrderTest, NoInlineFoo) { 226 #if !defined(LLVM_ENABLE_PLUGINS) 227 // Skip the test if plugins are disabled. 228 GTEST_SKIP(); 229 #endif 230 CompilerInstance CI{}; 231 CI.setupPlugin(); 232 233 for (StringRef IR : TestIRS) { 234 bool FoundFoo = false; 235 CI.run(IR); 236 CallGraph CGraph = CallGraph(*CI.OutputM); 237 for (auto &Node : CGraph) { 238 for (auto &Edge : *Node.second) { 239 FoundFoo |= Edge.second->getFunction()->getName() == "foo"; 240 } 241 } 242 ASSERT_TRUE(FoundFoo); 243 } 244 } 245 246 } // namespace llvm 247