1 #include "llvm/Analysis/CallGraph.h" 2 #include "llvm/AsmParser/Parser.h" 3 #include "llvm/Config/config.h" 4 #include "llvm/IR/Module.h" 5 #include "llvm/Passes/PassBuilder.h" 6 #include "llvm/Passes/PassPlugin.h" 7 #include "llvm/Support/CommandLine.h" 8 #include "llvm/Support/raw_ostream.h" 9 #include "llvm/Testing/Support/Error.h" 10 #include "gtest/gtest.h" 11 12 namespace llvm { 13 14 namespace { 15 16 void anchor() {} 17 18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") { 19 const auto &Argvs = testing::internal::GetArgvs(); 20 const char *Argv0 = 21 Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest"; 22 void *Ptr = (void *)(intptr_t)anchor; 23 std::string Path = sys::fs::getMainExecutable(Argv0, Ptr); 24 llvm::SmallString<256> Buf{sys::path::parent_path(Path)}; 25 sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str()); 26 return std::string(Buf.str()); 27 } 28 29 // Example of a custom InlineAdvisor that only inlines calls to functions called 30 // "foo". 31 class FooOnlyInlineAdvisor : public InlineAdvisor { 32 public: 33 FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM, 34 InlineParams Params, InlineContext IC) 35 : InlineAdvisor(M, FAM, IC) {} 36 37 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override { 38 if (CB.getCalledFunction()->getName() == "foo") 39 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true); 40 return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false); 41 } 42 }; 43 44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM, 45 InlineParams Params, InlineContext IC) { 46 return new FooOnlyInlineAdvisor(M, FAM, Params, IC); 47 } 48 49 struct CompilerInstance { 50 LLVMContext Ctx; 51 ModulePassManager MPM; 52 InlineParams IP; 53 54 PassBuilder PB; 55 LoopAnalysisManager LAM; 56 FunctionAnalysisManager FAM; 57 CGSCCAnalysisManager CGAM; 58 ModuleAnalysisManager MAM; 59 60 SMDiagnostic Error; 61 62 // connect the plugin to our compiler instance 63 void setupPlugin() { 64 auto PluginPath = libPath(); 65 ASSERT_NE("", PluginPath); 66 Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath); 67 ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath; 68 Plugin->registerPassBuilderCallbacks(PB); 69 } 70 71 // connect the FooOnlyInlineAdvisor to our compiler instance 72 void setupFooOnly() { 73 MAM.registerPass( 74 [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); }); 75 } 76 77 CompilerInstance() { 78 IP = getInlineParams(3, 0); 79 PB.registerModuleAnalyses(MAM); 80 PB.registerCGSCCAnalyses(CGAM); 81 PB.registerFunctionAnalyses(FAM); 82 PB.registerLoopAnalyses(LAM); 83 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 84 MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default, 85 ThinOrFullLTOPhase::None)); 86 } 87 88 std::string output; 89 std::unique_ptr<Module> outputM; 90 91 auto run(StringRef IR) { 92 outputM = parseAssemblyString(IR, Error, Ctx); 93 MPM.run(*outputM, MAM); 94 ASSERT_TRUE(outputM); 95 output.clear(); 96 raw_string_ostream o_stream{output}; 97 outputM->print(o_stream, nullptr); 98 ASSERT_TRUE(true); 99 } 100 }; 101 102 StringRef TestIRS[] = { 103 // Simple 3 function inline case 104 R"( 105 define void @f1() { 106 call void @foo() 107 ret void 108 } 109 define void @foo() { 110 call void @f3() 111 ret void 112 } 113 define void @f3() { 114 ret void 115 } 116 )", 117 // Test that has 5 functions of which 2 are recursive 118 R"( 119 define void @f1() { 120 call void @foo() 121 ret void 122 } 123 define void @f2() { 124 call void @foo() 125 ret void 126 } 127 define void @foo() { 128 call void @f4() 129 call void @f5() 130 ret void 131 } 132 define void @f4() { 133 ret void 134 } 135 define void @f5() { 136 call void @foo() 137 ret void 138 } 139 )", 140 // test with 2 mutually recursive functions and 1 function with a loop 141 R"( 142 define void @f1() { 143 call void @f2() 144 ret void 145 } 146 define void @f2() { 147 call void @f3() 148 ret void 149 } 150 define void @f3() { 151 call void @f1() 152 ret void 153 } 154 define void @f4() { 155 br label %loop 156 loop: 157 call void @f5() 158 br label %loop 159 } 160 define void @f5() { 161 ret void 162 } 163 )", 164 // test that has a function that computes fibonacci in a loop, one in a 165 // recurisve manner, and one that calls both and compares them 166 R"( 167 define i32 @fib_loop(i32 %n){ 168 %curr = alloca i32 169 %last = alloca i32 170 %i = alloca i32 171 store i32 1, i32* %curr 172 store i32 1, i32* %last 173 store i32 2, i32* %i 174 br label %loop_cond 175 loop_cond: 176 %i_val = load i32, i32* %i 177 %cmp = icmp slt i32 %i_val, %n 178 br i1 %cmp, label %loop_body, label %loop_end 179 loop_body: 180 %curr_val = load i32, i32* %curr 181 %last_val = load i32, i32* %last 182 %add = add i32 %curr_val, %last_val 183 store i32 %add, i32* %last 184 store i32 %curr_val, i32* %curr 185 %i_val2 = load i32, i32* %i 186 %add2 = add i32 %i_val2, 1 187 store i32 %add2, i32* %i 188 br label %loop_cond 189 loop_end: 190 %curr_val3 = load i32, i32* %curr 191 ret i32 %curr_val3 192 } 193 194 define i32 @fib_rec(i32 %n){ 195 %cmp = icmp eq i32 %n, 0 196 %cmp2 = icmp eq i32 %n, 1 197 %or = or i1 %cmp, %cmp2 198 br i1 %or, label %if_true, label %if_false 199 if_true: 200 ret i32 1 201 if_false: 202 %sub = sub i32 %n, 1 203 %call = call i32 @fib_rec(i32 %sub) 204 %sub2 = sub i32 %n, 2 205 %call2 = call i32 @fib_rec(i32 %sub2) 206 %add = add i32 %call, %call2 207 ret i32 %add 208 } 209 210 define i32 @fib_check(){ 211 %correct = alloca i32 212 %i = alloca i32 213 store i32 1, i32* %correct 214 store i32 0, i32* %i 215 br label %loop_cond 216 loop_cond: 217 %i_val = load i32, i32* %i 218 %cmp = icmp slt i32 %i_val, 10 219 br i1 %cmp, label %loop_body, label %loop_end 220 loop_body: 221 %i_val2 = load i32, i32* %i 222 %call = call i32 @fib_loop(i32 %i_val2) 223 %i_val3 = load i32, i32* %i 224 %call2 = call i32 @fib_rec(i32 %i_val3) 225 %cmp2 = icmp ne i32 %call, %call2 226 br i1 %cmp2, label %if_true, label %if_false 227 if_true: 228 store i32 0, i32* %correct 229 br label %if_end 230 if_false: 231 br label %if_end 232 if_end: 233 %i_val4 = load i32, i32* %i 234 %add = add i32 %i_val4, 1 235 store i32 %add, i32* %i 236 br label %loop_cond 237 loop_end: 238 %correct_val = load i32, i32* %correct 239 ret i32 %correct_val 240 } 241 )"}; 242 243 } // namespace 244 245 // check that loading a plugin works 246 // the plugin being loaded acts identically to the default inliner 247 TEST(PluginInlineAdvisorTest, PluginLoad) { 248 #if !defined(LLVM_ENABLE_PLUGINS) 249 // Skip the test if plugins are disabled. 250 GTEST_SKIP(); 251 #endif 252 CompilerInstance DefaultCI{}; 253 254 CompilerInstance PluginCI{}; 255 PluginCI.setupPlugin(); 256 257 for (StringRef IR : TestIRS) { 258 DefaultCI.run(IR); 259 std::string default_output = DefaultCI.output; 260 PluginCI.run(IR); 261 std::string dynamic_output = PluginCI.output; 262 ASSERT_EQ(default_output, dynamic_output); 263 } 264 } 265 266 // check that the behaviour of a custom inliner is correct 267 // the custom inliner inlines all functions that are not named "foo" 268 // this testdoes not require plugins to be enabled 269 TEST(PluginInlineAdvisorTest, CustomAdvisor) { 270 CompilerInstance CI{}; 271 CI.setupFooOnly(); 272 273 for (StringRef IR : TestIRS) { 274 CI.run(IR); 275 CallGraph CGraph = CallGraph(*CI.outputM); 276 for (auto &node : CGraph) { 277 for (auto &edge : *node.second) { 278 if (!edge.first) 279 continue; 280 ASSERT_NE(edge.second->getFunction()->getName(), "foo"); 281 } 282 } 283 } 284 } 285 286 } // namespace llvm 287