xref: /llvm-project/llvm/unittests/Analysis/PluginInlineAdvisorAnalysisTest.cpp (revision 74deadf19650f6f3b6392ba09caa20dd38ae41e0)
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/IR/Module.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/Passes/PassPlugin.h"
7 #include "llvm/Support/CommandLine.h"
8 #include "llvm/Support/raw_ostream.h"
9 #include "llvm/Testing/Support/Error.h"
10 #include "gtest/gtest.h"
11 
12 namespace llvm {
13 
14 namespace {
15 
16 void anchor() {}
17 
18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
19   const auto &Argvs = testing::internal::GetArgvs();
20   const char *Argv0 =
21       Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
22   void *Ptr = (void *)(intptr_t)anchor;
23   std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
24   llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
25   sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
26   return std::string(Buf.str());
27 }
28 
29 // Example of a custom InlineAdvisor that only inlines calls to functions called
30 // "foo".
31 class FooOnlyInlineAdvisor : public InlineAdvisor {
32 public:
33   FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
34                        InlineParams Params, InlineContext IC)
35       : InlineAdvisor(M, FAM, IC) {}
36 
37   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
38     if (CB.getCalledFunction()->getName() == "foo")
39       return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
40     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
41   }
42 };
43 
44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
45                                      InlineParams Params, InlineContext IC) {
46   return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
47 }
48 
49 struct CompilerInstance {
50   LLVMContext Ctx;
51   ModulePassManager MPM;
52   InlineParams IP;
53 
54   PassBuilder PB;
55   LoopAnalysisManager LAM;
56   FunctionAnalysisManager FAM;
57   CGSCCAnalysisManager CGAM;
58   ModuleAnalysisManager MAM;
59 
60   SMDiagnostic Error;
61 
62   // connect the plugin to our compiler instance
63   void setupPlugin() {
64     auto PluginPath = libPath();
65     ASSERT_NE("", PluginPath);
66     Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
67     ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
68     Plugin->registerPassBuilderCallbacks(PB);
69     ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"),
70                       Succeeded());
71   }
72 
73   // connect the FooOnlyInlineAdvisor to our compiler instance
74   void setupFooOnly() {
75     MAM.registerPass(
76         [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
77   }
78 
79   CompilerInstance() {
80     IP = getInlineParams(3, 0);
81     PB.registerModuleAnalyses(MAM);
82     PB.registerCGSCCAnalyses(CGAM);
83     PB.registerFunctionAnalyses(FAM);
84     PB.registerLoopAnalyses(LAM);
85     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
86     MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
87                                   ThinOrFullLTOPhase::None));
88   }
89 
90   ~CompilerInstance() {
91     // Reset the static variable that tracks if the plugin has been registered.
92     // This is needed to allow the test to run multiple times.
93     PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
94   }
95 
96   std::string output;
97   std::unique_ptr<Module> outputM;
98 
99   // run with the default inliner
100   auto run_default(StringRef IR) {
101     PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
102     outputM = parseAssemblyString(IR, Error, Ctx);
103     MPM.run(*outputM, MAM);
104     ASSERT_TRUE(outputM);
105     output.clear();
106     raw_string_ostream o_stream{output};
107     outputM->print(o_stream, nullptr);
108     ASSERT_TRUE(true);
109   }
110 
111   // run with the dnamic inliner
112   auto run_dynamic(StringRef IR) {
113     // note typically the constructor for the DynamicInlineAdvisorAnalysis
114     // will automatically set this to true, we controll it here only to
115     // altenate between the default and dynamic inliner in our test
116     PluginInlineAdvisorAnalysis::HasBeenRegistered = true;
117     outputM = parseAssemblyString(IR, Error, Ctx);
118     MPM.run(*outputM, MAM);
119     ASSERT_TRUE(outputM);
120     output.clear();
121     raw_string_ostream o_stream{output};
122     outputM->print(o_stream, nullptr);
123     ASSERT_TRUE(true);
124   }
125 };
126 
127 StringRef TestIRS[] = {
128     // Simple 3 function inline case
129     R"(
130 define void @f1() {
131   call void @foo()
132   ret void
133 }
134 define void @foo() {
135   call void @f3()
136   ret void
137 }
138 define void @f3() {
139   ret void
140 }
141   )",
142     // Test that has 5 functions of which 2 are recursive
143     R"(
144 define void @f1() {
145   call void @foo()
146   ret void
147 }
148 define void @f2() {
149   call void @foo()
150   ret void
151 }
152 define void @foo() {
153   call void @f4()
154   call void @f5()
155   ret void
156 }
157 define void @f4() {
158   ret void
159 }
160 define void @f5() {
161   call void @foo()
162   ret void
163 }
164   )",
165     // test with 2 mutually recursive functions and 1 function with a loop
166     R"(
167 define void @f1() {
168   call void @f2()
169   ret void
170 }
171 define void @f2() {
172   call void @f3()
173   ret void
174 }
175 define void @f3() {
176   call void @f1()
177   ret void
178 }
179 define void @f4() {
180   br label %loop
181 loop:
182   call void @f5()
183   br label %loop
184 }
185 define void @f5() {
186   ret void
187 }
188   )",
189     // test that has a function that computes fibonacci in a loop, one in a
190     // recurisve manner, and one that calls both and compares them
191     R"(
192 define i32 @fib_loop(i32 %n){
193     %curr = alloca i32
194     %last = alloca i32
195     %i = alloca i32
196     store i32 1, i32* %curr
197     store i32 1, i32* %last
198     store i32 2, i32* %i
199     br label %loop_cond
200   loop_cond:
201     %i_val = load i32, i32* %i
202     %cmp = icmp slt i32 %i_val, %n
203     br i1 %cmp, label %loop_body, label %loop_end
204   loop_body:
205     %curr_val = load i32, i32* %curr
206     %last_val = load i32, i32* %last
207     %add = add i32 %curr_val, %last_val
208     store i32 %add, i32* %last
209     store i32 %curr_val, i32* %curr
210     %i_val2 = load i32, i32* %i
211     %add2 = add i32 %i_val2, 1
212     store i32 %add2, i32* %i
213     br label %loop_cond
214   loop_end:
215     %curr_val3 = load i32, i32* %curr
216     ret i32 %curr_val3
217 }
218 
219 define i32 @fib_rec(i32 %n){
220     %cmp = icmp eq i32 %n, 0
221     %cmp2 = icmp eq i32 %n, 1
222     %or = or i1 %cmp, %cmp2
223     br i1 %or, label %if_true, label %if_false
224   if_true:
225     ret i32 1
226   if_false:
227     %sub = sub i32 %n, 1
228     %call = call i32 @fib_rec(i32 %sub)
229     %sub2 = sub i32 %n, 2
230     %call2 = call i32 @fib_rec(i32 %sub2)
231     %add = add i32 %call, %call2
232     ret i32 %add
233 }
234 
235 define i32 @fib_check(){
236     %correct = alloca i32
237     %i = alloca i32
238     store i32 1, i32* %correct
239     store i32 0, i32* %i
240     br label %loop_cond
241   loop_cond:
242     %i_val = load i32, i32* %i
243     %cmp = icmp slt i32 %i_val, 10
244     br i1 %cmp, label %loop_body, label %loop_end
245   loop_body:
246     %i_val2 = load i32, i32* %i
247     %call = call i32 @fib_loop(i32 %i_val2)
248     %i_val3 = load i32, i32* %i
249     %call2 = call i32 @fib_rec(i32 %i_val3)
250     %cmp2 = icmp ne i32 %call, %call2
251     br i1 %cmp2, label %if_true, label %if_false
252   if_true:
253     store i32 0, i32* %correct
254     br label %if_end
255   if_false:
256     br label %if_end
257   if_end:
258     %i_val4 = load i32, i32* %i
259     %add = add i32 %i_val4, 1
260     store i32 %add, i32* %i
261     br label %loop_cond
262   loop_end:
263     %correct_val = load i32, i32* %correct
264     ret i32 %correct_val
265 }
266   )"};
267 
268 } // namespace
269 
270 // check that loading a plugin works
271 // the plugin being loaded acts identically to the default inliner
272 TEST(PluginInlineAdvisorTest, PluginLoad) {
273 #if !defined(LLVM_ENABLE_PLUGINS)
274   // Skip the test if plugins are disabled.
275   GTEST_SKIP();
276 #endif
277   CompilerInstance CI{};
278   CI.setupPlugin();
279 
280   for (StringRef IR : TestIRS) {
281     CI.run_default(IR);
282     std::string default_output = CI.output;
283     CI.run_dynamic(IR);
284     std::string dynamic_output = CI.output;
285     ASSERT_EQ(default_output, dynamic_output);
286   }
287 }
288 
289 // check that the behaviour of a custom inliner is correct
290 // the custom inliner inlines all functions that are not named "foo"
291 // this testdoes not require plugins to be enabled
292 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
293   CompilerInstance CI{};
294   CI.setupFooOnly();
295 
296   for (StringRef IR : TestIRS) {
297     CI.run_dynamic(IR);
298     CallGraph CGraph = CallGraph(*CI.outputM);
299     for (auto &node : CGraph) {
300       for (auto &edge : *node.second) {
301         if (!edge.first)
302           continue;
303         ASSERT_NE(edge.second->getFunction()->getName(), "foo");
304       }
305     }
306   }
307 }
308 
309 } // namespace llvm
310