xref: /llvm-project/llvm/unittests/Analysis/PluginInlineAdvisorAnalysisTest.cpp (revision 07af0e2d3e8485ad6f84da1ad9851538b62d2942)
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/Passes/PassBuilder.h"
5 #include "llvm/Passes/PassPlugin.h"
6 #include "llvm/Support/CommandLine.h"
7 #include "llvm/Support/raw_ostream.h"
8 #include "llvm/Testing/Support/Error.h"
9 #include "gtest/gtest.h"
10 
11 namespace llvm {
12 
// Intentionally empty: libPath() uses the address of this function as the
// "address inside this binary" argument to sys::fs::getMainExecutable.
void anchor() {
  // Nothing to do; only the function's address matters.
}
14 
15 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
16   const auto &Argvs = testing::internal::GetArgvs();
17   const char *Argv0 =
18       Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
19   void *Ptr = (void *)(intptr_t)anchor;
20   std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
21   llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
22   sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
23   return std::string(Buf.str());
24 }
25 
26 // Example of a custom InlineAdvisor that only inlines calls to functions called
27 // "foo".
28 class FooOnlyInlineAdvisor : public InlineAdvisor {
29 public:
30   FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
31                        InlineParams Params, InlineContext IC)
32       : InlineAdvisor(M, FAM, IC) {}
33 
34   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
35     if (CB.getCalledFunction()->getName() == "foo")
36       return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
37     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
38   }
39 };
40 
41 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
42                                      InlineParams Params, InlineContext IC) {
43   return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
44 }
45 
46 struct CompilerInstance {
47   LLVMContext Ctx;
48   ModulePassManager MPM;
49   InlineParams IP;
50 
51   PassBuilder PB;
52   LoopAnalysisManager LAM;
53   FunctionAnalysisManager FAM;
54   CGSCCAnalysisManager CGAM;
55   ModuleAnalysisManager MAM;
56 
57   SMDiagnostic Error;
58 
59   // connect the plugin to our compiler instance
60   void setupPlugin() {
61     auto PluginPath = libPath();
62     ASSERT_NE("", PluginPath);
63     Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
64     ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
65     Plugin->registerPassBuilderCallbacks(PB);
66     ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"),
67                       Succeeded());
68   }
69 
70   // connect the FooOnlyInlineAdvisor to our compiler instance
71   void setupFooOnly() {
72     MAM.registerPass(
73         [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
74   }
75 
76   CompilerInstance() {
77     IP = getInlineParams(3, 0);
78     PB.registerModuleAnalyses(MAM);
79     PB.registerCGSCCAnalyses(CGAM);
80     PB.registerFunctionAnalyses(FAM);
81     PB.registerLoopAnalyses(LAM);
82     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
83     MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
84                                   ThinOrFullLTOPhase::None));
85   }
86 
87   std::string output;
88   std::unique_ptr<Module> outputM;
89 
90   // run with the default inliner
91   auto run_default(StringRef IR) {
92     PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
93     outputM = parseAssemblyString(IR, Error, Ctx);
94     MPM.run(*outputM, MAM);
95     ASSERT_TRUE(outputM);
96     output.clear();
97     raw_string_ostream o_stream{output};
98     outputM->print(o_stream, nullptr);
99     ASSERT_TRUE(true);
100   }
101 
102   // run with the dnamic inliner
103   auto run_dynamic(StringRef IR) {
104     // note typically the constructor for the DynamicInlineAdvisorAnalysis
105     // will automatically set this to true, we controll it here only to
106     // altenate between the default and dynamic inliner in our test
107     PluginInlineAdvisorAnalysis::HasBeenRegistered = true;
108     outputM = parseAssemblyString(IR, Error, Ctx);
109     MPM.run(*outputM, MAM);
110     ASSERT_TRUE(outputM);
111     output.clear();
112     raw_string_ostream o_stream{output};
113     outputM->print(o_stream, nullptr);
114     ASSERT_TRUE(true);
115   }
116 };
117 
118 StringRef TestIRS[] = {
119     // Simple 3 function inline case
120     R"(
121 define void @f1() {
122   call void @foo()
123   ret void
124 }
125 define void @foo() {
126   call void @f3()
127   ret void
128 }
129 define void @f3() {
130   ret void
131 }
132   )",
133     // Test that has 5 functions of which 2 are recursive
134     R"(
135 define void @f1() {
136   call void @foo()
137   ret void
138 }
139 define void @f2() {
140   call void @foo()
141   ret void
142 }
143 define void @foo() {
144   call void @f4()
145   call void @f5()
146   ret void
147 }
148 define void @f4() {
149   ret void
150 }
151 define void @f5() {
152   call void @foo()
153   ret void
154 }
155   )",
156     // test with 2 mutually recursive functions and 1 function with a loop
157     R"(
158 define void @f1() {
159   call void @f2()
160   ret void
161 }
162 define void @f2() {
163   call void @f3()
164   ret void
165 }
166 define void @f3() {
167   call void @f1()
168   ret void
169 }
170 define void @f4() {
171   br label %loop
172 loop:
173   call void @f5()
174   br label %loop
175 }
176 define void @f5() {
177   ret void
178 }
179   )",
180     // test that has a function that computes fibonacci in a loop, one in a
181     // recurisve manner, and one that calls both and compares them
182     R"(
183 define i32 @fib_loop(i32 %n){
184     %curr = alloca i32
185     %last = alloca i32
186     %i = alloca i32
187     store i32 1, i32* %curr
188     store i32 1, i32* %last
189     store i32 2, i32* %i
190     br label %loop_cond
191   loop_cond:
192     %i_val = load i32, i32* %i
193     %cmp = icmp slt i32 %i_val, %n
194     br i1 %cmp, label %loop_body, label %loop_end
195   loop_body:
196     %curr_val = load i32, i32* %curr
197     %last_val = load i32, i32* %last
198     %add = add i32 %curr_val, %last_val
199     store i32 %add, i32* %last
200     store i32 %curr_val, i32* %curr
201     %i_val2 = load i32, i32* %i
202     %add2 = add i32 %i_val2, 1
203     store i32 %add2, i32* %i
204     br label %loop_cond
205   loop_end:
206     %curr_val3 = load i32, i32* %curr
207     ret i32 %curr_val3
208 }
209 
210 define i32 @fib_rec(i32 %n){
211     %cmp = icmp eq i32 %n, 0
212     %cmp2 = icmp eq i32 %n, 1
213     %or = or i1 %cmp, %cmp2
214     br i1 %or, label %if_true, label %if_false
215   if_true:
216     ret i32 1
217   if_false:
218     %sub = sub i32 %n, 1
219     %call = call i32 @fib_rec(i32 %sub)
220     %sub2 = sub i32 %n, 2
221     %call2 = call i32 @fib_rec(i32 %sub2)
222     %add = add i32 %call, %call2
223     ret i32 %add
224 }
225 
226 define i32 @fib_check(){
227     %correct = alloca i32
228     %i = alloca i32
229     store i32 1, i32* %correct
230     store i32 0, i32* %i
231     br label %loop_cond
232   loop_cond:
233     %i_val = load i32, i32* %i
234     %cmp = icmp slt i32 %i_val, 10
235     br i1 %cmp, label %loop_body, label %loop_end
236   loop_body:
237     %i_val2 = load i32, i32* %i
238     %call = call i32 @fib_loop(i32 %i_val2)
239     %i_val3 = load i32, i32* %i
240     %call2 = call i32 @fib_rec(i32 %i_val3)
241     %cmp2 = icmp ne i32 %call, %call2
242     br i1 %cmp2, label %if_true, label %if_false
243   if_true:
244     store i32 0, i32* %correct
245     br label %if_end
246   if_false:
247     br label %if_end
248   if_end:
249     %i_val4 = load i32, i32* %i
250     %add = add i32 %i_val4, 1
251     store i32 %add, i32* %i
252     br label %loop_cond
253   loop_end:
254     %correct_val = load i32, i32* %correct
255     ret i32 %correct_val
256 }
257   )"};
258 
259 // check that loading a plugin works
260 // the plugin being loaded acts identically to the default inliner
261 TEST(PluginInlineAdvisorTest, PluginLoad) {
262 #if !defined(LLVM_ENABLE_PLUGINS)
263   // Disable the test if plugins are disabled.
264   return;
265 #endif
266   CompilerInstance CI{};
267   CI.setupPlugin();
268 
269   for (StringRef IR : TestIRS) {
270     CI.run_default(IR);
271     std::string default_output = CI.output;
272     CI.run_dynamic(IR);
273     std::string dynamic_output = CI.output;
274     ASSERT_EQ(default_output, dynamic_output);
275   }
276 }
277 
278 // check that the behaviour of a custom inliner is correct
279 // the custom inliner inlines all functions that are not named "foo"
280 // this testdoes not require plugins to be enabled
281 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
282   CompilerInstance CI{};
283   CI.setupFooOnly();
284 
285   for (StringRef IR : TestIRS) {
286     CI.run_dynamic(IR);
287     CallGraph CGraph = CallGraph(*CI.outputM);
288     for (auto &node : CGraph) {
289       for (auto &edge : *node.second) {
290         if (!edge.first)
291           continue;
292         ASSERT_NE(edge.second->getFunction()->getName(), "foo");
293       }
294     }
295   }
296 }
297 
298 } // namespace llvm
299