1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/IR/Module.h" 5 #include "llvm/Passes/PassBuilder.h" 6 #include "llvm/Passes/PassPlugin.h" 7 #include "llvm/Support/CommandLine.h" 8 #include "llvm/Support/raw_ostream.h" 9 #include "llvm/Testing/Support/Error.h" 10 #include "gtest/gtest.h" 11 12 namespace llvm { 13 14 namespace { 15 16 void anchor() {} 17 18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 19 const auto &Argvs = testing::internal::GetArgvs(); 20 const char *Argv0 = 21 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 22 void *Ptr = (void *)(intptr_t)anchor; 23 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 24 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 25 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 26 return std::string(Buf.str()); 27 } 28 29 // Example of a custom InlineAdvisor that only inlines calls to functions called 30 // "foo". 31 class FooOnlyInlineAdvisor : public InlineAdvisor { 32 public: 33 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 34 InlineParams Params, InlineContext IC) 35 : InlineAdvisor(M, FAM, IC) {} 36 37 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 38 if (CB.getCalledFunction()->getName() == "foo") 39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 40 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 41 } 42 }; 43 44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 45 InlineParams Params, InlineContext IC) { 46 return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 47 } 48 49 struct CompilerInstance { 50 LLVMContext Ctx; 51 ModulePassManager MPM; 52 InlineParams IP; 53 54 PassBuilder PB; 55 LoopAnalysisManager LAM; 56 FunctionAnalysisManager FAM; 57 CGSCCAnalysisManager CGAM; 58 ModuleAnalysisManager MAM; 59 60 SMDiagnostic Error; 61 62 // connect the plugin to our compiler instance 63 void setupPlugin() { 64 auto PluginPath = libPath(); 65 ASSERT_NE("", PluginPath); 66 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 67 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 68 Plugin->registerPassBuilderCallbacks(PB); 69 ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"), 70 Succeeded()); 71 } 72 73 // connect the FooOnlyInlineAdvisor to our compiler instance 74 void setupFooOnly() { 75 MAM.registerPass( 76 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 77 } 78 79 CompilerInstance() { 80 IP = getInlineParams(3, 0); 81 PB.registerModuleAnalyses(MAM); 82 PB.registerCGSCCAnalyses(CGAM); 83 PB.registerFunctionAnalyses(FAM); 84 PB.registerLoopAnalyses(LAM); 85 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 86 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 87 ThinOrFullLTOPhase::None)); 88 } 89 90 ~CompilerInstance() { 91 // Reset the static variable that tracks if the plugin has been registered. 92 // This is needed to allow the test to run multiple times. 93 PluginInlineAdvisorAnalysis::HasBeenRegistered = false; 94 } 95 96 std::string output; 97 std::unique_ptr<Module> outputM; 98 99 // run with the default inliner 100 auto run_default(StringRef IR) { 101 PluginInlineAdvisorAnalysis::HasBeenRegistered = false; 102 outputM = parseAssemblyString(IR, Error, Ctx); 103 MPM.run(*outputM, MAM); 104 ASSERT_TRUE(outputM); 105 output.clear(); 106 raw_string_ostream o_stream{output}; 107 outputM->print(o_stream, nullptr); 108 ASSERT_TRUE(true); 109 } 110 111 // run with the dnamic inliner 112 auto run_dynamic(StringRef IR) { 113 // note typically the constructor for the DynamicInlineAdvisorAnalysis 114 // will automatically set this to true, we controll it here only to 115 // altenate between the default and dynamic inliner in our test 116 PluginInlineAdvisorAnalysis::HasBeenRegistered = true; 117 outputM = parseAssemblyString(IR, Error, Ctx); 118 MPM.run(*outputM, MAM); 119 ASSERT_TRUE(outputM); 120 output.clear(); 121 raw_string_ostream o_stream{output}; 122 outputM->print(o_stream, nullptr); 123 ASSERT_TRUE(true); 124 } 125 }; 126 127 StringRef TestIRS[] = { 128 // Simple 3 function inline case 129 R"( 130 define void @f1() { 131 call void @foo() 132 ret void 133 } 134 define void @foo() { 135 call void @f3() 136 ret void 137 } 138 define void @f3() { 139 ret void 140 } 141 )", 142 // Test that has 5 functions of which 2 are recursive 143 R"( 144 define void @f1() { 145 call void @foo() 146 ret void 147 } 148 define void @f2() { 149 call void @foo() 150 ret void 151 } 152 define void @foo() { 153 call void @f4() 154 call void @f5() 155 ret void 156 } 157 define void @f4() { 158 ret void 159 } 160 define void @f5() { 161 call void @foo() 162 ret void 163 } 164 )", 165 // test with 2 mutually recursive functions and 1 function with a loop 166 R"( 167 define void @f1() { 168 call void @f2() 169 ret void 170 } 171 define void @f2() { 172 call void @f3() 173 ret void 174 } 175 define void @f3() { 176 call void @f1() 177 ret void 178 } 179 define void @f4() { 180 br label %loop 181 loop: 182 call void @f5() 183 br label %loop 184 } 185 define void @f5() { 186 ret void 187 } 188 )", 189 // test that has a function that computes fibonacci in a loop, one in a 190 // recurisve manner, and one that calls both and compares them 191 R"( 192 define i32 @fib_loop(i32 %n){ 193 %curr = alloca i32 194 %last = alloca i32 195 %i = alloca i32 196 store i32 1, i32* %curr 197 store i32 1, i32* %last 198 store i32 2, i32* %i 199 br label %loop_cond 200 loop_cond: 201 %i_val = load i32, i32* %i 202 %cmp = icmp slt i32 %i_val, %n 203 br i1 %cmp, label %loop_body, label %loop_end 204 loop_body: 205 %curr_val = load i32, i32* %curr 206 %last_val = load i32, i32* %last 207 %add = add i32 %curr_val, %last_val 208 store i32 %add, i32* %last 209 store i32 %curr_val, i32* %curr 210 %i_val2 = load i32, i32* %i 211 %add2 = add i32 %i_val2, 1 212 store i32 %add2, i32* %i 213 br label %loop_cond 214 loop_end: 215 %curr_val3 = load i32, i32* %curr 216 ret i32 %curr_val3 217 } 218 219 define i32 @fib_rec(i32 %n){ 220 %cmp = icmp eq i32 %n, 0 221 %cmp2 = icmp eq i32 %n, 1 222 %or = or i1 %cmp, %cmp2 223 br i1 %or, label %if_true, label %if_false 224 if_true: 225 ret i32 1 226 if_false: 227 %sub = sub i32 %n, 1 228 %call = call i32 @fib_rec(i32 %sub) 229 %sub2 = sub i32 %n, 2 230 %call2 = call i32 @fib_rec(i32 %sub2) 231 %add = add i32 %call, %call2 232 ret i32 %add 233 } 234 235 define i32 @fib_check(){ 236 %correct = alloca i32 237 %i = alloca i32 238 store i32 1, i32* %correct 239 store i32 0, i32* %i 240 br label %loop_cond 241 loop_cond: 242 %i_val = load i32, i32* %i 243 %cmp = icmp slt i32 %i_val, 10 244 br i1 %cmp, label %loop_body, label %loop_end 245 loop_body: 246 %i_val2 = load i32, i32* %i 247 %call = call i32 @fib_loop(i32 %i_val2) 248 %i_val3 = load i32, i32* %i 249 %call2 = call i32 @fib_rec(i32 %i_val3) 250 %cmp2 = icmp ne i32 %call, %call2 251 br i1 %cmp2, label %if_true, label %if_false 252 if_true: 253 store i32 0, i32* %correct 254 br label %if_end 255 if_false: 256 br label %if_end 257 if_end: 258 %i_val4 = load i32, i32* %i 259 %add = add i32 %i_val4, 1 260 store i32 %add, i32* %i 261 br label %loop_cond 262 loop_end: 263 %correct_val = load i32, i32* %correct 264 ret i32 %correct_val 265 } 266 )"}; 267 268 } // namespace 269 270 // check that loading a plugin works 271 // the plugin being loaded acts identically to the default inliner 272 TEST(PluginInlineAdvisorTest, PluginLoad) { 273 #if !defined(LLVM_ENABLE_PLUGINS) 274 // Skip the test if plugins are disabled. 275 GTEST_SKIP(); 276 #endif 277 CompilerInstance CI{}; 278 CI.setupPlugin(); 279 280 for (StringRef IR : TestIRS) { 281 CI.run_default(IR); 282 std::string default_output = CI.output; 283 CI.run_dynamic(IR); 284 std::string dynamic_output = CI.output; 285 ASSERT_EQ(default_output, dynamic_output); 286 } 287 } 288 289 // check that the behaviour of a custom inliner is correct 290 // the custom inliner inlines all functions that are not named "foo" 291 // this testdoes not require plugins to be enabled 292 TEST(PluginInlineAdvisorTest, CustomAdvisor) { 293 CompilerInstance CI{}; 294 CI.setupFooOnly(); 295 296 for (StringRef IR : TestIRS) { 297 CI.run_dynamic(IR); 298 CallGraph CGraph = CallGraph(*CI.outputM); 299 for (auto &node : CGraph) { 300 for (auto &edge : *node.second) { 301 if (!edge.first) 302 continue; 303 ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 304 } 305 } 306 } 307 } 308 309 } // namespace llvm 310