1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/Passes/PassBuilder.h" 5 #include "llvm/Passes/PassPlugin.h" 6 #include "llvm/Support/CommandLine.h" 7 #include "llvm/Support/raw_ostream.h" 8 #include "llvm/Testing/Support/Error.h" 9 #include "gtest/gtest.h" 10 11 namespace llvm { 12 13 void anchor() {} 14 15 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 16 const auto &Argvs = testing::internal::GetArgvs(); 17 const char *Argv0 = 18 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 19 void *Ptr = (void *)(intptr_t)anchor; 20 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 21 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 22 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 23 return std::string(Buf.str()); 24 } 25 26 // Example of a custom InlineAdvisor that only inlines calls to functions called 27 // "foo". 28 class FooOnlyInlineAdvisor : public InlineAdvisor { 29 public: 30 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 31 InlineParams Params, InlineContext IC) 32 : InlineAdvisor(M, FAM, IC) {} 33 34 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 35 if (CB.getCalledFunction()->getName() == "foo") 36 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 37 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 38 } 39 }; 40 41 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 42 InlineParams Params, InlineContext IC) { 43 return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 44 } 45 46 struct CompilerInstance { 47 LLVMContext Ctx; 48 ModulePassManager MPM; 49 InlineParams IP; 50 51 PassBuilder PB; 52 LoopAnalysisManager LAM; 53 FunctionAnalysisManager FAM; 54 CGSCCAnalysisManager CGAM; 55 ModuleAnalysisManager MAM; 56 57 SMDiagnostic Error; 58 59 // connect the plugin to our compiler instance 60 void setupPlugin() { 61 auto PluginPath = libPath(); 62 ASSERT_NE("", PluginPath); 63 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 64 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 65 Plugin->registerPassBuilderCallbacks(PB); 66 ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"), 67 Succeeded()); 68 } 69 70 // connect the FooOnlyInlineAdvisor to our compiler instance 71 void setupFooOnly() { 72 MAM.registerPass( 73 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 74 } 75 76 CompilerInstance() { 77 IP = getInlineParams(3, 0); 78 PB.registerModuleAnalyses(MAM); 79 PB.registerCGSCCAnalyses(CGAM); 80 PB.registerFunctionAnalyses(FAM); 81 PB.registerLoopAnalyses(LAM); 82 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 83 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 84 ThinOrFullLTOPhase::None)); 85 } 86 87 std::string output; 88 std::unique_ptr<Module> outputM; 89 90 // run with the default inliner 91 auto run_default(StringRef IR) { 92 PluginInlineAdvisorAnalysis::HasBeenRegistered = false; 93 outputM = parseAssemblyString(IR, Error, Ctx); 94 MPM.run(*outputM, MAM); 95 ASSERT_TRUE(outputM); 96 output.clear(); 97 raw_string_ostream o_stream{output}; 98 outputM->print(o_stream, nullptr); 99 ASSERT_TRUE(true); 100 } 101 102 // run with the dnamic inliner 103 auto run_dynamic(StringRef IR) { 104 // note typically the constructor for the DynamicInlineAdvisorAnalysis 105 // will automatically set this to true, we controll it here only to 106 // altenate between the default and dynamic inliner in our test 107 PluginInlineAdvisorAnalysis::HasBeenRegistered = true; 108 outputM = parseAssemblyString(IR, Error, Ctx); 109 MPM.run(*outputM, MAM); 110 ASSERT_TRUE(outputM); 111 output.clear(); 112 raw_string_ostream o_stream{output}; 113 outputM->print(o_stream, nullptr); 114 ASSERT_TRUE(true); 115 } 116 }; 117 118 StringRef TestIRS[] = { 119 // Simple 3 function inline case 120 R"( 121 define void @f1() { 122 call void @foo() 123 ret void 124 } 125 define void @foo() { 126 call void @f3() 127 ret void 128 } 129 define void @f3() { 130 ret void 131 } 132 )", 133 // Test that has 5 functions of which 2 are recursive 134 R"( 135 define void @f1() { 136 call void @foo() 137 ret void 138 } 139 define void @f2() { 140 call void @foo() 141 ret void 142 } 143 define void @foo() { 144 call void @f4() 145 call void @f5() 146 ret void 147 } 148 define void @f4() { 149 ret void 150 } 151 define void @f5() { 152 call void @foo() 153 ret void 154 } 155 )", 156 // test with 2 mutually recursive functions and 1 function with a loop 157 R"( 158 define void @f1() { 159 call void @f2() 160 ret void 161 } 162 define void @f2() { 163 call void @f3() 164 ret void 165 } 166 define void @f3() { 167 call void @f1() 168 ret void 169 } 170 define void @f4() { 171 br label %loop 172 loop: 173 call void @f5() 174 br label %loop 175 } 176 define void @f5() { 177 ret void 178 } 179 )", 180 // test that has a function that computes fibonacci in a loop, one in a 181 // recurisve manner, and one that calls both and compares them 182 R"( 183 define i32 @fib_loop(i32 %n){ 184 %curr = alloca i32 185 %last = alloca i32 186 %i = alloca i32 187 store i32 1, i32* %curr 188 store i32 1, i32* %last 189 store i32 2, i32* %i 190 br label %loop_cond 191 loop_cond: 192 %i_val = load i32, i32* %i 193 %cmp = icmp slt i32 %i_val, %n 194 br i1 %cmp, label %loop_body, label %loop_end 195 loop_body: 196 %curr_val = load i32, i32* %curr 197 %last_val = load i32, i32* %last 198 %add = add i32 %curr_val, %last_val 199 store i32 %add, i32* %last 200 store i32 %curr_val, i32* %curr 201 %i_val2 = load i32, i32* %i 202 %add2 = add i32 %i_val2, 1 203 store i32 %add2, i32* %i 204 br label %loop_cond 205 loop_end: 206 %curr_val3 = load i32, i32* %curr 207 ret i32 %curr_val3 208 } 209 210 define i32 @fib_rec(i32 %n){ 211 %cmp = icmp eq i32 %n, 0 212 %cmp2 = icmp eq i32 %n, 1 213 %or = or i1 %cmp, %cmp2 214 br i1 %or, label %if_true, label %if_false 215 if_true: 216 ret i32 1 217 if_false: 218 %sub = sub i32 %n, 1 219 %call = call i32 @fib_rec(i32 %sub) 220 %sub2 = sub i32 %n, 2 221 %call2 = call i32 @fib_rec(i32 %sub2) 222 %add = add i32 %call, %call2 223 ret i32 %add 224 } 225 226 define i32 @fib_check(){ 227 %correct = alloca i32 228 %i = alloca i32 229 store i32 1, i32* %correct 230 store i32 0, i32* %i 231 br label %loop_cond 232 loop_cond: 233 %i_val = load i32, i32* %i 234 %cmp = icmp slt i32 %i_val, 10 235 br i1 %cmp, label %loop_body, label %loop_end 236 loop_body: 237 %i_val2 = load i32, i32* %i 238 %call = call i32 @fib_loop(i32 %i_val2) 239 %i_val3 = load i32, i32* %i 240 %call2 = call i32 @fib_rec(i32 %i_val3) 241 %cmp2 = icmp ne i32 %call, %call2 242 br i1 %cmp2, label %if_true, label %if_false 243 if_true: 244 store i32 0, i32* %correct 245 br label %if_end 246 if_false: 247 br label %if_end 248 if_end: 249 %i_val4 = load i32, i32* %i 250 %add = add i32 %i_val4, 1 251 store i32 %add, i32* %i 252 br label %loop_cond 253 loop_end: 254 %correct_val = load i32, i32* %correct 255 ret i32 %correct_val 256 } 257 )"}; 258 259 // check that loading a plugin works 260 // the plugin being loaded acts identically to the default inliner 261 TEST(PluginInlineAdvisorTest, PluginLoad) { 262 #if !defined(LLVM_ENABLE_PLUGINS) 263 // Disable the test if plugins are disabled. 264 return; 265 #endif 266 CompilerInstance CI{}; 267 CI.setupPlugin(); 268 269 for (StringRef IR : TestIRS) { 270 CI.run_default(IR); 271 std::string default_output = CI.output; 272 CI.run_dynamic(IR); 273 std::string dynamic_output = CI.output; 274 ASSERT_EQ(default_output, dynamic_output); 275 } 276 } 277 278 // check that the behaviour of a custom inliner is correct 279 // the custom inliner inlines all functions that are not named "foo" 280 // this testdoes not require plugins to be enabled 281 TEST(PluginInlineAdvisorTest, CustomAdvisor) { 282 CompilerInstance CI{}; 283 CI.setupFooOnly(); 284 285 for (StringRef IR : TestIRS) { 286 CI.run_dynamic(IR); 287 CallGraph CGraph = CallGraph(*CI.outputM); 288 for (auto &node : CGraph) { 289 for (auto &edge : *node.second) { 290 if (!edge.first) 291 continue; 292 ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 293 } 294 } 295 } 296 } 297 298 } // namespace llvm 299