xref: /llvm-project/llvm/unittests/Analysis/PluginInlineAdvisorAnalysisTest.cpp (revision c84a99dfd391eb4d89aff8d6453016045098b444)
1 #include "llvm/Analysis/CallGraph.h"
2 #include "llvm/AsmParser/Parser.h"
3 #include "llvm/Config/config.h"
4 #include "llvm/IR/Module.h"
5 #include "llvm/Passes/PassBuilder.h"
6 #include "llvm/Passes/PassPlugin.h"
7 #include "llvm/Support/CommandLine.h"
8 #include "llvm/Support/raw_ostream.h"
9 #include "llvm/Testing/Support/Error.h"
10 #include "gtest/gtest.h"
11 
12 namespace llvm {
13 
14 namespace {
15 
// Intentionally empty. Only the address of this function is used: libPath()
// passes it to sys::fs::getMainExecutable() alongside argv[0].
void anchor() { /* no body needed */ }
17 
18 static std::string libPath(const std::string Name = "InlineAdvisorPlugin") {
19   const auto &Argvs = testing::internal::GetArgvs();
20   const char *Argv0 =
21       Argvs.size() > 0 ? Argvs[0].c_str() : "PluginInlineAdvisorAnalysisTest";
22   void *Ptr = (void *)(intptr_t)anchor;
23   std::string Path = sys::fs::getMainExecutable(Argv0, Ptr);
24   llvm::SmallString<256> Buf{sys::path::parent_path(Path)};
25   sys::path::append(Buf, (Name + LLVM_PLUGIN_EXT).c_str());
26   return std::string(Buf.str());
27 }
28 
29 // Example of a custom InlineAdvisor that only inlines calls to functions called
30 // "foo".
31 class FooOnlyInlineAdvisor : public InlineAdvisor {
32 public:
33   FooOnlyInlineAdvisor(Module &M, FunctionAnalysisManager &FAM,
34                        InlineParams Params, InlineContext IC)
35       : InlineAdvisor(M, FAM, IC) {}
36 
37   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override {
38     if (CB.getCalledFunction()->getName() == "foo")
39       return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), true);
40     return std::make_unique<InlineAdvice>(this, CB, getCallerORE(CB), false);
41   }
42 };
43 
44 static InlineAdvisor *fooOnlyFactory(Module &M, FunctionAnalysisManager &FAM,
45                                      InlineParams Params, InlineContext IC) {
46   return new FooOnlyInlineAdvisor(M, FAM, Params, IC);
47 }
48 
49 struct CompilerInstance {
50   LLVMContext Ctx;
51   ModulePassManager MPM;
52   InlineParams IP;
53 
54   PassBuilder PB;
55   LoopAnalysisManager LAM;
56   FunctionAnalysisManager FAM;
57   CGSCCAnalysisManager CGAM;
58   ModuleAnalysisManager MAM;
59 
60   SMDiagnostic Error;
61 
62   // connect the plugin to our compiler instance
63   void setupPlugin() {
64     auto PluginPath = libPath();
65     ASSERT_NE("", PluginPath);
66     Expected<PassPlugin> Plugin = PassPlugin::Load(PluginPath);
67     ASSERT_TRUE(!!Plugin) << "Plugin path: " << PluginPath;
68     Plugin->registerPassBuilderCallbacks(PB);
69   }
70 
71   // connect the FooOnlyInlineAdvisor to our compiler instance
72   void setupFooOnly() {
73     MAM.registerPass(
74         [&] { return PluginInlineAdvisorAnalysis(fooOnlyFactory); });
75   }
76 
77   CompilerInstance() {
78     IP = getInlineParams(3, 0);
79     PB.registerModuleAnalyses(MAM);
80     PB.registerCGSCCAnalyses(CGAM);
81     PB.registerFunctionAnalyses(FAM);
82     PB.registerLoopAnalyses(LAM);
83     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
84     MPM.addPass(ModuleInlinerPass(IP, InliningAdvisorMode::Default,
85                                   ThinOrFullLTOPhase::None));
86   }
87 
88   std::string output;
89   std::unique_ptr<Module> outputM;
90 
91   auto run(StringRef IR) {
92     outputM = parseAssemblyString(IR, Error, Ctx);
93     MPM.run(*outputM, MAM);
94     ASSERT_TRUE(outputM);
95     output.clear();
96     raw_string_ostream o_stream{output};
97     outputM->print(o_stream, nullptr);
98     ASSERT_TRUE(true);
99   }
100 };
101 
102 StringRef TestIRS[] = {
103     // Simple 3 function inline case
104     R"(
105 define void @f1() {
106   call void @foo()
107   ret void
108 }
109 define void @foo() {
110   call void @f3()
111   ret void
112 }
113 define void @f3() {
114   ret void
115 }
116   )",
117     // Test that has 5 functions of which 2 are recursive
118     R"(
119 define void @f1() {
120   call void @foo()
121   ret void
122 }
123 define void @f2() {
124   call void @foo()
125   ret void
126 }
127 define void @foo() {
128   call void @f4()
129   call void @f5()
130   ret void
131 }
132 define void @f4() {
133   ret void
134 }
135 define void @f5() {
136   call void @foo()
137   ret void
138 }
139   )",
140     // test with 2 mutually recursive functions and 1 function with a loop
141     R"(
142 define void @f1() {
143   call void @f2()
144   ret void
145 }
146 define void @f2() {
147   call void @f3()
148   ret void
149 }
150 define void @f3() {
151   call void @f1()
152   ret void
153 }
154 define void @f4() {
155   br label %loop
156 loop:
157   call void @f5()
158   br label %loop
159 }
160 define void @f5() {
161   ret void
162 }
163   )",
164     // test that has a function that computes fibonacci in a loop, one in a
165     // recurisve manner, and one that calls both and compares them
166     R"(
167 define i32 @fib_loop(i32 %n){
168     %curr = alloca i32
169     %last = alloca i32
170     %i = alloca i32
171     store i32 1, i32* %curr
172     store i32 1, i32* %last
173     store i32 2, i32* %i
174     br label %loop_cond
175   loop_cond:
176     %i_val = load i32, i32* %i
177     %cmp = icmp slt i32 %i_val, %n
178     br i1 %cmp, label %loop_body, label %loop_end
179   loop_body:
180     %curr_val = load i32, i32* %curr
181     %last_val = load i32, i32* %last
182     %add = add i32 %curr_val, %last_val
183     store i32 %add, i32* %last
184     store i32 %curr_val, i32* %curr
185     %i_val2 = load i32, i32* %i
186     %add2 = add i32 %i_val2, 1
187     store i32 %add2, i32* %i
188     br label %loop_cond
189   loop_end:
190     %curr_val3 = load i32, i32* %curr
191     ret i32 %curr_val3
192 }
193 
194 define i32 @fib_rec(i32 %n){
195     %cmp = icmp eq i32 %n, 0
196     %cmp2 = icmp eq i32 %n, 1
197     %or = or i1 %cmp, %cmp2
198     br i1 %or, label %if_true, label %if_false
199   if_true:
200     ret i32 1
201   if_false:
202     %sub = sub i32 %n, 1
203     %call = call i32 @fib_rec(i32 %sub)
204     %sub2 = sub i32 %n, 2
205     %call2 = call i32 @fib_rec(i32 %sub2)
206     %add = add i32 %call, %call2
207     ret i32 %add
208 }
209 
210 define i32 @fib_check(){
211     %correct = alloca i32
212     %i = alloca i32
213     store i32 1, i32* %correct
214     store i32 0, i32* %i
215     br label %loop_cond
216   loop_cond:
217     %i_val = load i32, i32* %i
218     %cmp = icmp slt i32 %i_val, 10
219     br i1 %cmp, label %loop_body, label %loop_end
220   loop_body:
221     %i_val2 = load i32, i32* %i
222     %call = call i32 @fib_loop(i32 %i_val2)
223     %i_val3 = load i32, i32* %i
224     %call2 = call i32 @fib_rec(i32 %i_val3)
225     %cmp2 = icmp ne i32 %call, %call2
226     br i1 %cmp2, label %if_true, label %if_false
227   if_true:
228     store i32 0, i32* %correct
229     br label %if_end
230   if_false:
231     br label %if_end
232   if_end:
233     %i_val4 = load i32, i32* %i
234     %add = add i32 %i_val4, 1
235     store i32 %add, i32* %i
236     br label %loop_cond
237   loop_end:
238     %correct_val = load i32, i32* %correct
239     ret i32 %correct_val
240 }
241   )"};
242 
243 } // namespace
244 
245 // check that loading a plugin works
246 // the plugin being loaded acts identically to the default inliner
247 TEST(PluginInlineAdvisorTest, PluginLoad) {
248 #if !defined(LLVM_ENABLE_PLUGINS)
249   // Skip the test if plugins are disabled.
250   GTEST_SKIP();
251 #endif
252   CompilerInstance DefaultCI{};
253 
254   CompilerInstance PluginCI{};
255   PluginCI.setupPlugin();
256 
257   for (StringRef IR : TestIRS) {
258     DefaultCI.run(IR);
259     std::string default_output = DefaultCI.output;
260     PluginCI.run(IR);
261     std::string dynamic_output = PluginCI.output;
262     ASSERT_EQ(default_output, dynamic_output);
263   }
264 }
265 
266 // check that the behaviour of a custom inliner is correct
267 // the custom inliner inlines all functions that are not named "foo"
268 // this testdoes not require plugins to be enabled
269 TEST(PluginInlineAdvisorTest, CustomAdvisor) {
270   CompilerInstance CI{};
271   CI.setupFooOnly();
272 
273   for (StringRef IR : TestIRS) {
274     CI.run(IR);
275     CallGraph CGraph = CallGraph(*CI.outputM);
276     for (auto &node : CGraph) {
277       for (auto &edge : *node.second) {
278         if (!edge.first)
279           continue;
280         ASSERT_NE(edge.second->getFunction()->getName(), "foo");
281       }
282     }
283   }
284 }
285 
286 } // namespace llvm
287