xref: /llvm-project/llvm/unittests/Analysis/PluginInlineAdvisorAnalysisTest.cpp (revision 65f7ebe72e4ca1b788e13dfbd2f71b5beeffba7d)
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/Passes/PassBuilder.h"
5 #include "llvm/Passes/PassPlugin.h"
6 #include "llvm/Support/CommandLine.h"
7 #include "llvm/Support/raw_ostream.h"
8 #include "llvm/Testing/Support/Error.h"
9 #include "gtest/gtest.h"
10 
11 namespace llvm {
12 
13 namespace {
14 
15 void anchor() {}
16 
17 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
18   const auto &Argvs = testing::internal::GetArgvs();
19   const char *Argv0 =
20       Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
21   void *Ptr = (void *)(intptr_t)anchor;
22   std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
23   llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
24   sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
25   return std::string(Buf.str());
26 }
27 
28 // Example of a custom InlineAdvisor that only inlines calls to functions called
29 // "foo".
30 class FooOnlyInlineAdvisor : public InlineAdvisor {
31 public:
32   FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
33                        InlineParams Params, InlineContext IC)
34       : InlineAdvisor(M, FAM, IC) {}
35 
36   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
37     if (CB.getCalledFunction()->getName() == "foo")
38       return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
39     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
40   }
41 };
42 
43 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
44                                      InlineParams Params, InlineContext IC) {
45   return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
46 }
47 
48 struct CompilerInstance {
49   LLVMContext Ctx;
50   ModulePassManager MPM;
51   InlineParams IP;
52 
53   PassBuilder PB;
54   LoopAnalysisManager LAM;
55   FunctionAnalysisManager FAM;
56   CGSCCAnalysisManager CGAM;
57   ModuleAnalysisManager MAM;
58 
59   SMDiagnostic Error;
60 
61   // connect the plugin to our compiler instance
62   void setupPlugin() {
63     auto PluginPath = libPath();
64     ASSERT_NE("", PluginPath);
65     Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
66     ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
67     Plugin->registerPassBuilderCallbacks(PB);
68     ASSERT_THAT_ERROR(PB.parsePassPipeline(MPM, "dynamic-inline-advisor"),
69                       Succeeded());
70   }
71 
72   // connect the FooOnlyInlineAdvisor to our compiler instance
73   void setupFooOnly() {
74     MAM.registerPass(
75         [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
76   }
77 
78   CompilerInstance() {
79     IP = getInlineParams(3, 0);
80     PB.registerModuleAnalyses(MAM);
81     PB.registerCGSCCAnalyses(CGAM);
82     PB.registerFunctionAnalyses(FAM);
83     PB.registerLoopAnalyses(LAM);
84     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
85     MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
86                                   ThinOrFullLTOPhase::None));
87   }
88 
89   ~CompilerInstance() {
90     // Reset the static variable that tracks if the plugin has been registered.
91     // This is needed to allow the test to run multiple times.
92     PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
93   }
94 
95   std::string output;
96   std::unique_ptr<Module> outputM;
97 
98   // run with the default inliner
99   auto run_default(StringRef IR) {
100     PluginInlineAdvisorAnalysis::HasBeenRegistered = false;
101     outputM = parseAssemblyString(IR, Error, Ctx);
102     MPM.run(*outputM, MAM);
103     ASSERT_TRUE(outputM);
104     output.clear();
105     raw_string_ostream o_stream{output};
106     outputM->print(o_stream, nullptr);
107     ASSERT_TRUE(true);
108   }
109 
110   // run with the dnamic inliner
111   auto run_dynamic(StringRef IR) {
112     // note typically the constructor for the DynamicInlineAdvisorAnalysis
113     // will automatically set this to true, we controll it here only to
114     // altenate between the default and dynamic inliner in our test
115     PluginInlineAdvisorAnalysis::HasBeenRegistered = true;
116     outputM = parseAssemblyString(IR, Error, Ctx);
117     MPM.run(*outputM, MAM);
118     ASSERT_TRUE(outputM);
119     output.clear();
120     raw_string_ostream o_stream{output};
121     outputM->print(o_stream, nullptr);
122     ASSERT_TRUE(true);
123   }
124 };
125 
// IR modules that both tests run through the inliner pipeline, one module per
// entry.
StringRef TestIRS[] = {
    // Simple 3 function inline case: f1 -> foo -> f3.
    R"(
define void @f1() {
  call void @foo()
  ret void
}
define void @foo() {
  call void @f3()
  ret void
}
define void @f3() {
  ret void
}
  )",
    // Test that has 5 functions, of which 2 (foo and f5) are mutually
    // recursive.
    R"(
define void @f1() {
  call void @foo()
  ret void
}
define void @f2() {
  call void @foo()
  ret void
}
define void @foo() {
  call void @f4()
  call void @f5()
  ret void
}
define void @f4() {
  ret void
}
define void @f5() {
  call void @foo()
  ret void
}
  )",
    // Test with 3 mutually recursive functions (f1 -> f2 -> f3 -> f1) and one
    // function (f4) whose infinite loop calls f5.
    R"(
define void @f1() {
  call void @f2()
  ret void
}
define void @f2() {
  call void @f3()
  ret void
}
define void @f3() {
  call void @f1()
  ret void
}
define void @f4() {
  br label %loop
loop:
  call void @f5()
  br label %loop
}
define void @f5() {
  ret void
}
  )",
    // Test that has a function that computes fibonacci in a loop, one that
    // computes it recursively, and one that calls both and compares them.
    R"(
define i32 @fib_loop(i32 %n){
    %curr = alloca i32
    %last = alloca i32
    %i = alloca i32
    store i32 1, i32* %curr
    store i32 1, i32* %last
    store i32 2, i32* %i
    br label %loop_cond
  loop_cond:
    %i_val = load i32, i32* %i
    %cmp = icmp slt i32 %i_val, %n
    br i1 %cmp, label %loop_body, label %loop_end
  loop_body:
    %curr_val = load i32, i32* %curr
    %last_val = load i32, i32* %last
    %add = add i32 %curr_val, %last_val
    store i32 %add, i32* %last
    store i32 %curr_val, i32* %curr
    %i_val2 = load i32, i32* %i
    %add2 = add i32 %i_val2, 1
    store i32 %add2, i32* %i
    br label %loop_cond
  loop_end:
    %curr_val3 = load i32, i32* %curr
    ret i32 %curr_val3
}

define i32 @fib_rec(i32 %n){
    %cmp = icmp eq i32 %n, 0
    %cmp2 = icmp eq i32 %n, 1
    %or = or i1 %cmp, %cmp2
    br i1 %or, label %if_true, label %if_false
  if_true:
    ret i32 1
  if_false:
    %sub = sub i32 %n, 1
    %call = call i32 @fib_rec(i32 %sub)
    %sub2 = sub i32 %n, 2
    %call2 = call i32 @fib_rec(i32 %sub2)
    %add = add i32 %call, %call2
    ret i32 %add
}

define i32 @fib_check(){
    %correct = alloca i32
    %i = alloca i32
    store i32 1, i32* %correct
    store i32 0, i32* %i
    br label %loop_cond
  loop_cond:
    %i_val = load i32, i32* %i
    %cmp = icmp slt i32 %i_val, 10
    br i1 %cmp, label %loop_body, label %loop_end
  loop_body:
    %i_val2 = load i32, i32* %i
    %call = call i32 @fib_loop(i32 %i_val2)
    %i_val3 = load i32, i32* %i
    %call2 = call i32 @fib_rec(i32 %i_val3)
    %cmp2 = icmp ne i32 %call, %call2
    br i1 %cmp2, label %if_true, label %if_false
  if_true:
    store i32 0, i32* %correct
    br label %if_end
  if_false:
    br label %if_end
  if_end:
    %i_val4 = load i32, i32* %i
    %add = add i32 %i_val4, 1
    store i32 %add, i32* %i
    br label %loop_cond
  loop_end:
    %correct_val = load i32, i32* %correct
    ret i32 %correct_val
}
  )"};
266 
267 } // namespace
268 
// Check that loading the plugin works and that the inline advisor it provides
// produces output identical to the default inliner.
271 TEST(PluginInlineAdvisorTest, PluginLoad) {
272 #if !defined(LLVM_ENABLE_PLUGINS)
273   // Skip the test if plugins are disabled.
274   GTEST_SKIP();
275 #endif
276   CompilerInstance CI{};
277   CI.setupPlugin();
278 
279   for (StringRef IR : TestIRS) {
280     CI.run_default(IR);
281     std::string default_output = CI.output;
282     CI.run_dynamic(IR);
283     std::string dynamic_output = CI.output;
284     ASSERT_EQ(default_output, dynamic_output);
285   }
286 }
287 
// Check that a custom inline advisor behaves as intended. FooOnlyInlineAdvisor
// inlines only calls to functions named "foo", so after the pipeline runs no
// call edge to "foo" should remain.
// This test does not require plugins to be enabled.
291 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
292   CompilerInstance CI{};
293   CI.setupFooOnly();
294 
295   for (StringRef IR : TestIRS) {
296     CI.run_dynamic(IR);
297     CallGraph CGraph = CallGraph(*CI.outputM);
298     for (auto &node : CGraph) {
299       for (auto &edge : *node.second) {
300         if (!edge.first)
301           continue;
302         ASSERT_NE(edge.second->getFunction()->getName(), "foo");
303       }
304     }
305   }
306 }
307 
308 } // namespace llvm
309