1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/Passes/PassBuilder.h" 5 #include "llvm/Passes/PassPlugin.h" 6 #include "llvm/Support/CommandLine.h" 7 #include "llvm/Support/raw_ostream.h" 8 #include "llvm/Testing/Support/Error.h" 9 #include "gtest/gtest.h" 10 11 namespace llvm { 12 13 namespace { 14 15 void anchor() {} 16 17 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 18 const auto &Argvs = testing::internal::GetArgvs(); 19 const char *Argv0 = 20 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 21 void *Ptr = (void *)(intptr_t)anchor; 22 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 23 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 24 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 25 return std::string(Buf.str()); 26 } 27 28 // Example of a custom InlineAdvisor that only inlines calls to functions called 29 // "foo". 30 class FooOnlyInlineAdvisor : public InlineAdvisor { 31 public: 32 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 33 InlineParams Params, InlineContext IC) 34 : InlineAdvisor(M, FAM, IC) {} 35 36 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 37 if (CB.getCalledFunction()->getName() == "foo") 38 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 40 } 41 }; 42 43 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 44 InlineParams Params, InlineContext IC) { 45 return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 46 } 47 48 struct CompilerInstance { 49 LLVMContext Ctx; 50 ModulePassManager MPM; 51 InlineParams IP; 52 53 PassBuilder PB; 54 LoopAnalysisManager LAM; 55 FunctionAnalysisManager FAM; 56 CGSCCAnalysisManager CGAM; 57 ModuleAnalysisManager MAM; 58 59 SMDiagnostic Error; 60 61 // connect the plugin to our compiler instance 62 void setupPlugin() { 63 auto PluginPath = libPath(); 64 ASSERT_NE("", PluginPath); 65 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 66 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 67 Plugin->registerPassBuilderCallbacks(PB); 68 ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"), 69 Succeeded()); 70 } 71 72 // connect the FooOnlyInlineAdvisor to our compiler instance 73 void setupFooOnly() { 74 MAM.registerPass( 75 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 76 } 77 78 CompilerInstance() { 79 IP = getInlineParams(3, 0); 80 PB.registerModuleAnalyses(MAM); 81 PB.registerCGSCCAnalyses(CGAM); 82 PB.registerFunctionAnalyses(FAM); 83 PB.registerLoopAnalyses(LAM); 84 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 85 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 86 ThinOrFullLTOPhase::None)); 87 } 88 89 ~CompilerInstance() { 90 // Reset the static variable that tracks if the plugin has been registered. 91 // This is needed to allow the test to run multiple times. 92 PluginInlineAdvisorAnalysis::HasBeenRegistered = false; 93 } 94 95 std::string output; 96 std::unique_ptr<Module> outputM; 97 98 // run with the default inliner 99 auto run_default(StringRef IR) { 100 PluginInlineAdvisorAnalysis::HasBeenRegistered = false; 101 outputM = parseAssemblyString(IR, Error, Ctx); 102 MPM.run(*outputM, MAM); 103 ASSERT_TRUE(outputM); 104 output.clear(); 105 raw_string_ostream o_stream{output}; 106 outputM->print(o_stream, nullptr); 107 ASSERT_TRUE(true); 108 } 109 110 // run with the dnamic inliner 111 auto run_dynamic(StringRef IR) { 112 // note typically the constructor for the DynamicInlineAdvisorAnalysis 113 // will automatically set this to true, we controll it here only to 114 // altenate between the default and dynamic inliner in our test 115 PluginInlineAdvisorAnalysis::HasBeenRegistered = true; 116 outputM = parseAssemblyString(IR, Error, Ctx); 117 MPM.run(*outputM, MAM); 118 ASSERT_TRUE(outputM); 119 output.clear(); 120 raw_string_ostream o_stream{output}; 121 outputM->print(o_stream, nullptr); 122 ASSERT_TRUE(true); 123 } 124 }; 125 126 StringRef TestIRS[] = { 127 // Simple 3 function inline case 128 R"( 129 define void @f1() { 130 call void @foo() 131 ret void 132 } 133 define void @foo() { 134 call void @f3() 135 ret void 136 } 137 define void @f3() { 138 ret void 139 } 140 )", 141 // Test that has 5 functions of which 2 are recursive 142 R"( 143 define void @f1() { 144 call void @foo() 145 ret void 146 } 147 define void @f2() { 148 call void @foo() 149 ret void 150 } 151 define void @foo() { 152 call void @f4() 153 call void @f5() 154 ret void 155 } 156 define void @f4() { 157 ret void 158 } 159 define void @f5() { 160 call void @foo() 161 ret void 162 } 163 )", 164 // test with 2 mutually recursive functions and 1 function with a loop 165 R"( 166 define void @f1() { 167 call void @f2() 168 ret void 169 } 170 define void @f2() { 171 call void @f3() 172 ret void 173 } 174 define void @f3() { 175 call void @f1() 176 ret void 177 } 178 define void @f4() { 179 br label %loop 180 loop: 181 call void @f5() 182 br label %loop 183 } 184 define void @f5() { 185 ret void 186 } 187 )", 188 // test that has a function that computes fibonacci in a loop, one in a 189 // recurisve manner, and one that calls both and compares them 190 R"( 191 define i32 @fib_loop(i32 %n){ 192 %curr = alloca i32 193 %last = alloca i32 194 %i = alloca i32 195 store i32 1, i32* %curr 196 store i32 1, i32* %last 197 store i32 2, i32* %i 198 br label %loop_cond 199 loop_cond: 200 %i_val = load i32, i32* %i 201 %cmp = icmp slt i32 %i_val, %n 202 br i1 %cmp, label %loop_body, label %loop_end 203 loop_body: 204 %curr_val = load i32, i32* %curr 205 %last_val = load i32, i32* %last 206 %add = add i32 %curr_val, %last_val 207 store i32 %add, i32* %last 208 store i32 %curr_val, i32* %curr 209 %i_val2 = load i32, i32* %i 210 %add2 = add i32 %i_val2, 1 211 store i32 %add2, i32* %i 212 br label %loop_cond 213 loop_end: 214 %curr_val3 = load i32, i32* %curr 215 ret i32 %curr_val3 216 } 217 218 define i32 @fib_rec(i32 %n){ 219 %cmp = icmp eq i32 %n, 0 220 %cmp2 = icmp eq i32 %n, 1 221 %or = or i1 %cmp, %cmp2 222 br i1 %or, label %if_true, label %if_false 223 if_true: 224 ret i32 1 225 if_false: 226 %sub = sub i32 %n, 1 227 %call = call i32 @fib_rec(i32 %sub) 228 %sub2 = sub i32 %n, 2 229 %call2 = call i32 @fib_rec(i32 %sub2) 230 %add = add i32 %call, %call2 231 ret i32 %add 232 } 233 234 define i32 @fib_check(){ 235 %correct = alloca i32 236 %i = alloca i32 237 store i32 1, i32* %correct 238 store i32 0, i32* %i 239 br label %loop_cond 240 loop_cond: 241 %i_val = load i32, i32* %i 242 %cmp = icmp slt i32 %i_val, 10 243 br i1 %cmp, label %loop_body, label %loop_end 244 loop_body: 245 %i_val2 = load i32, i32* %i 246 %call = call i32 @fib_loop(i32 %i_val2) 247 %i_val3 = load i32, i32* %i 248 %call2 = call i32 @fib_rec(i32 %i_val3) 249 %cmp2 = icmp ne i32 %call, %call2 250 br i1 %cmp2, label %if_true, label %if_false 251 if_true: 252 store i32 0, i32* %correct 253 br label %if_end 254 if_false: 255 br label %if_end 256 if_end: 257 %i_val4 = load i32, i32* %i 258 %add = add i32 %i_val4, 1 259 store i32 %add, i32* %i 260 br label %loop_cond 261 loop_end: 262 %correct_val = load i32, i32* %correct 263 ret i32 %correct_val 264 } 265 )"}; 266 267 } // namespace 268 269 // check that loading a plugin works 270 // the plugin being loaded acts identically to the default inliner 271 TEST(PluginInlineAdvisorTest, PluginLoad) { 272 #if !defined(LLVM_ENABLE_PLUGINS) 273 // Skip the test if plugins are disabled. 274 GTEST_SKIP(); 275 #endif 276 CompilerInstance CI{}; 277 CI.setupPlugin(); 278 279 for (StringRef IR : TestIRS) { 280 CI.run_default(IR); 281 std::string default_output = CI.output; 282 CI.run_dynamic(IR); 283 std::string dynamic_output = CI.output; 284 ASSERT_EQ(default_output, dynamic_output); 285 } 286 } 287 288 // check that the behaviour of a custom inliner is correct 289 // the custom inliner inlines all functions that are not named "foo" 290 // this testdoes not require plugins to be enabled 291 TEST(PluginInlineAdvisorTest, CustomAdvisor) { 292 CompilerInstance CI{}; 293 CI.setupFooOnly(); 294 295 for (StringRef IR : TestIRS) { 296 CI.run_dynamic(IR); 297 CallGraph CGraph = CallGraph(*CI.outputM); 298 for (auto &node : CGraph) { 299 for (auto &edge : *node.second) { 300 if (!edge.first) 301 continue; 302 ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 303 } 304 } 305 } 306 } 307 308 } // namespace llvm 309