1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/IR/Module.h" 5 #include "llvm/Passes/PassBuilder.h" 6 #include "llvm/Passes/PassPlugin.h" 7 #include "llvm/Support/CommandLine.h" 8 #include "llvm/Support/raw_ostream.h" 9 #include "llvm/Testing/Support/Error.h" 10 #include "gtest/gtest.h" 11 12 namespace llvm { 13 14 namespace { 15 16 void anchor() {} 17 18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 19 const auto &Argvs = testing::internal::GetArgvs(); 20 const char *Argv0 = 21 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 22 void *Ptr = (void *)(intptr_t)anchor; 23 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 24 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 25 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 26 return std::string(Buf.str()); 27 } 28 29 // Example of a custom InlineAdvisor that only inlines calls to functions called 30 // "foo". 31 class FooOnlyInlineAdvisor : public InlineAdvisor { 32 public: 33 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 34 InlineParams Params, InlineContext IC) 35 : InlineAdvisor(M, FAM, IC) {} 36 37 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 38 if (CB.getCalledFunction()->getName() == "foo") 39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 40 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 41 } 42 }; 43 44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 45 InlineParams Params, InlineContext IC) { 46 return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 47 } 48 49 struct CompilerInstance { 50 LLVMContext Ctx; 51 ModulePassManager MPM; 52 InlineParams IP; 53 54 PassBuilder PB; 55 LoopAnalysisManager LAM; 56 FunctionAnalysisManager FAM; 57 CGSCCAnalysisManager CGAM; 58 ModuleAnalysisManager MAM; 59 60 SMDiagnostic Error; 61 62 // connect the plugin to our compiler instance 63 void setupPlugin() { 64 auto PluginPath = libPath(); 65 ASSERT_NE("", PluginPath); 66 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 67 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 68 Plugin->registerPassBuilderCallbacks(PB); 69 ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"), 70 Succeeded()); 71 } 72 73 // connect the FooOnlyInlineAdvisor to our compiler instance 74 void setupFooOnly() { 75 MAM.registerPass( 76 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 77 } 78 79 CompilerInstance() { 80 IP = getInlineParams(3, 0); 81 PB.registerModuleAnalyses(MAM); 82 PB.registerCGSCCAnalyses(CGAM); 83 PB.registerFunctionAnalyses(FAM); 84 PB.registerLoopAnalyses(LAM); 85 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 86 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 87 ThinOrFullLTOPhase::None)); 88 } 89 90 std::string output; 91 std::unique_ptr<Module> outputM; 92 93 auto run(StringRef IR) { 94 outputM = parseAssemblyString(IR, Error, Ctx); 95 MPM.run(*outputM, MAM); 96 ASSERT_TRUE(outputM); 97 output.clear(); 98 raw_string_ostream o_stream{output}; 99 outputM->print(o_stream, nullptr); 100 ASSERT_TRUE(true); 101 } 102 }; 103 104 StringRef TestIRS[] = { 105 // Simple 3 function inline case 106 R"( 107 define void @f1() { 108 call void @foo() 109 ret void 110 } 111 define void @foo() { 112 call void @f3() 113 ret void 114 } 115 define void @f3() { 116 ret void 117 } 118 )", 119 // Test that has 5 functions of which 2 are recursive 120 R"( 121 define void @f1() { 122 call void @foo() 123 ret void 124 } 125 define void @f2() { 126 call void @foo() 127 ret void 128 } 129 define void @foo() { 130 call void @f4() 131 call void @f5() 132 ret void 133 } 134 define void @f4() { 135 ret void 136 } 137 define void @f5() { 138 call void @foo() 139 ret void 140 } 141 )", 142 // test with 2 mutually recursive functions and 1 function with a loop 143 R"( 144 define void @f1() { 145 call void @f2() 146 ret void 147 } 148 define void @f2() { 149 call void @f3() 150 ret void 151 } 152 define void @f3() { 153 call void @f1() 154 ret void 155 } 156 define void @f4() { 157 br label %loop 158 loop: 159 call void @f5() 160 br label %loop 161 } 162 define void @f5() { 163 ret void 164 } 165 )", 166 // test that has a function that computes fibonacci in a loop, one in a 167 // recurisve manner, and one that calls both and compares them 168 R"( 169 define i32 @fib_loop(i32 %n){ 170 %curr = alloca i32 171 %last = alloca i32 172 %i = alloca i32 173 store i32 1, i32* %curr 174 store i32 1, i32* %last 175 store i32 2, i32* %i 176 br label %loop_cond 177 loop_cond: 178 %i_val = load i32, i32* %i 179 %cmp = icmp slt i32 %i_val, %n 180 br i1 %cmp, label %loop_body, label %loop_end 181 loop_body: 182 %curr_val = load i32, i32* %curr 183 %last_val = load i32, i32* %last 184 %add = add i32 %curr_val, %last_val 185 store i32 %add, i32* %last 186 store i32 %curr_val, i32* %curr 187 %i_val2 = load i32, i32* %i 188 %add2 = add i32 %i_val2, 1 189 store i32 %add2, i32* %i 190 br label %loop_cond 191 loop_end: 192 %curr_val3 = load i32, i32* %curr 193 ret i32 %curr_val3 194 } 195 196 define i32 @fib_rec(i32 %n){ 197 %cmp = icmp eq i32 %n, 0 198 %cmp2 = icmp eq i32 %n, 1 199 %or = or i1 %cmp, %cmp2 200 br i1 %or, label %if_true, label %if_false 201 if_true: 202 ret i32 1 203 if_false: 204 %sub = sub i32 %n, 1 205 %call = call i32 @fib_rec(i32 %sub) 206 %sub2 = sub i32 %n, 2 207 %call2 = call i32 @fib_rec(i32 %sub2) 208 %add = add i32 %call, %call2 209 ret i32 %add 210 } 211 212 define i32 @fib_check(){ 213 %correct = alloca i32 214 %i = alloca i32 215 store i32 1, i32* %correct 216 store i32 0, i32* %i 217 br label %loop_cond 218 loop_cond: 219 %i_val = load i32, i32* %i 220 %cmp = icmp slt i32 %i_val, 10 221 br i1 %cmp, label %loop_body, label %loop_end 222 loop_body: 223 %i_val2 = load i32, i32* %i 224 %call = call i32 @fib_loop(i32 %i_val2) 225 %i_val3 = load i32, i32* %i 226 %call2 = call i32 @fib_rec(i32 %i_val3) 227 %cmp2 = icmp ne i32 %call, %call2 228 br i1 %cmp2, label %if_true, label %if_false 229 if_true: 230 store i32 0, i32* %correct 231 br label %if_end 232 if_false: 233 br label %if_end 234 if_end: 235 %i_val4 = load i32, i32* %i 236 %add = add i32 %i_val4, 1 237 store i32 %add, i32* %i 238 br label %loop_cond 239 loop_end: 240 %correct_val = load i32, i32* %correct 241 ret i32 %correct_val 242 } 243 )"}; 244 245 } // namespace 246 247 // check that loading a plugin works 248 // the plugin being loaded acts identically to the default inliner 249 TEST(PluginInlineAdvisorTest, PluginLoad) { 250 #if !defined(LLVM_ENABLE_PLUGINS) 251 // Skip the test if plugins are disabled. 252 GTEST_SKIP(); 253 #endif 254 CompilerInstance DefaultCI{}; 255 256 CompilerInstance PluginCI{}; 257 PluginCI.setupPlugin(); 258 259 for (StringRef IR : TestIRS) { 260 DefaultCI.run(IR); 261 std::string default_output = DefaultCI.output; 262 PluginCI.run(IR); 263 std::string dynamic_output = PluginCI.output; 264 ASSERT_EQ(default_output, dynamic_output); 265 } 266 } 267 268 // check that the behaviour of a custom inliner is correct 269 // the custom inliner inlines all functions that are not named "foo" 270 // this testdoes not require plugins to be enabled 271 TEST(PluginInlineAdvisorTest, CustomAdvisor) { 272 CompilerInstance CI{}; 273 CI.setupFooOnly(); 274 275 for (StringRef IR : TestIRS) { 276 CI.run(IR); 277 CallGraph CGraph = CallGraph(*CI.outputM); 278 for (auto &node : CGraph) { 279 for (auto &edge : *node.second) { 280 if (!edge.first) 281 continue; 282 ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 283 } 284 } 285 } 286 } 287 288 } // namespace llvm 289