xref: /llvm-project/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (revision 41e06ae7ba911661962c9f190ae88d1f29c076da)
1 //===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Instructions.h"
12 #include "llvm/IR/LLVMContext.h"
13 #include "llvm/IR/Module.h"
14 #include "llvm/Support/SourceMgr.h"
15 #include "gtest/gtest.h"
16 
17 using namespace llvm;
18 
19 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
20   SMDiagnostic Err;
21   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
22   if (!Mod)
23     Err.print("UtilsTests", errs());
24   return Mod;
25 }
26 
27 TEST(CallPromotionUtilsTest, TryPromoteCall) {
28   LLVMContext C;
29   std::unique_ptr<Module> M = parseIR(C,
30                                       R"IR(
31 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
32 %class.Interface = type { i32 (...)** }
33 
34 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
35 
36 define void @f() {
37 entry:
38   %o = alloca %class.Impl
39   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
40   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
41   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
42   store i32 3, i32* %f
43   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
44   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
45   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
46   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
47   call void %fp(%class.Interface* nonnull %base.i)
48   ret void
49 }
50 
51 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
52 )IR");
53 
54   auto *GV = M->getNamedValue("f");
55   ASSERT_TRUE(GV);
56   auto *F = dyn_cast<Function>(GV);
57   ASSERT_TRUE(F);
58   Instruction *Inst = &F->front().front();
59   auto *AI = dyn_cast<AllocaInst>(Inst);
60   ASSERT_TRUE(AI);
61   Inst = &*++F->front().rbegin();
62   auto *CI = dyn_cast<CallInst>(Inst);
63   ASSERT_TRUE(CI);
64   CallSite CS(CI);
65   ASSERT_FALSE(CS.getCalledFunction());
66   bool IsPromoted = tryPromoteCall(CS);
67   EXPECT_TRUE(IsPromoted);
68   GV = M->getNamedValue("_ZN4Impl3RunEv");
69   ASSERT_TRUE(GV);
70   auto *F1 = dyn_cast<Function>(GV);
71   EXPECT_EQ(F1, CS.getCalledFunction());
72 }
73 
74 TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
75   LLVMContext C;
76   std::unique_ptr<Module> M = parseIR(C,
77                                       R"IR(
78 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
79 %class.Interface = type { i32 (...)** }
80 
81 define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i) {
82 entry:
83   call void %fp(%class.Interface* nonnull %base.i)
84   ret void
85 }
86 )IR");
87 
88   auto *GV = M->getNamedValue("f");
89   ASSERT_TRUE(GV);
90   auto *F = dyn_cast<Function>(GV);
91   ASSERT_TRUE(F);
92   Instruction *Inst = &F->front().front();
93   auto *CI = dyn_cast<CallInst>(Inst);
94   ASSERT_TRUE(CI);
95   CallSite CS(CI);
96   ASSERT_FALSE(CS.getCalledFunction());
97   bool IsPromoted = tryPromoteCall(CS);
98   EXPECT_FALSE(IsPromoted);
99 }
100 
101 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTablePtrLoad) {
102   LLVMContext C;
103   std::unique_ptr<Module> M = parseIR(C,
104                                       R"IR(
105 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
106 %class.Interface = type { i32 (...)** }
107 
108 define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %base.i) {
109 entry:
110   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
111   call void %fp(%class.Interface* nonnull %base.i)
112   ret void
113 }
114 )IR");
115 
116   auto *GV = M->getNamedValue("f");
117   ASSERT_TRUE(GV);
118   auto *F = dyn_cast<Function>(GV);
119   ASSERT_TRUE(F);
120   Instruction *Inst = &*++F->front().rbegin();
121   auto *CI = dyn_cast<CallInst>(Inst);
122   ASSERT_TRUE(CI);
123   CallSite CS(CI);
124   ASSERT_FALSE(CS.getCalledFunction());
125   bool IsPromoted = tryPromoteCall(CS);
126   EXPECT_FALSE(IsPromoted);
127 }
128 
129 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTableInitFound) {
130   LLVMContext C;
131   std::unique_ptr<Module> M = parseIR(C,
132                                       R"IR(
133 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
134 %class.Interface = type { i32 (...)** }
135 
136 define void @f() {
137 entry:
138   %o = alloca %class.Impl
139   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
140   store i32 3, i32* %f
141   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
142   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
143   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
144   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
145   call void %fp(%class.Interface* nonnull %base.i)
146   ret void
147 }
148 
149 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
150 )IR");
151 
152   auto *GV = M->getNamedValue("f");
153   ASSERT_TRUE(GV);
154   auto *F = dyn_cast<Function>(GV);
155   ASSERT_TRUE(F);
156   Instruction *Inst = &*++F->front().rbegin();
157   auto *CI = dyn_cast<CallInst>(Inst);
158   ASSERT_TRUE(CI);
159   CallSite CS(CI);
160   ASSERT_FALSE(CS.getCalledFunction());
161   bool IsPromoted = tryPromoteCall(CS);
162   EXPECT_FALSE(IsPromoted);
163 }
164 
165 TEST(CallPromotionUtilsTest, TryPromoteCall_EmptyVTable) {
166   LLVMContext C;
167   std::unique_ptr<Module> M = parseIR(C,
168                                       R"IR(
169 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
170 %class.Interface = type { i32 (...)** }
171 
172 @_ZTV4Impl = external global { [3 x i8*] }
173 
174 define void @f() {
175 entry:
176   %o = alloca %class.Impl
177   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
178   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
179   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
180   store i32 3, i32* %f
181   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
182   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
183   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
184   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
185   call void %fp(%class.Interface* nonnull %base.i)
186   ret void
187 }
188 
189 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
190 )IR");
191 
192   auto *GV = M->getNamedValue("f");
193   ASSERT_TRUE(GV);
194   auto *F = dyn_cast<Function>(GV);
195   ASSERT_TRUE(F);
196   Instruction *Inst = &F->front().front();
197   auto *AI = dyn_cast<AllocaInst>(Inst);
198   ASSERT_TRUE(AI);
199   Inst = &*++F->front().rbegin();
200   auto *CI = dyn_cast<CallInst>(Inst);
201   ASSERT_TRUE(CI);
202   CallSite CS(CI);
203   ASSERT_FALSE(CS.getCalledFunction());
204   bool IsPromoted = tryPromoteCall(CS);
205   EXPECT_FALSE(IsPromoted);
206 }
207 
208 TEST(CallPromotionUtilsTest, TryPromoteCall_NullFP) {
209   LLVMContext C;
210   std::unique_ptr<Module> M = parseIR(C,
211                                       R"IR(
212 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
213 %class.Interface = type { i32 (...)** }
214 
215 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* null] }
216 
217 define void @f() {
218 entry:
219   %o = alloca %class.Impl
220   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
221   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
222   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
223   store i32 3, i32* %f
224   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
225   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
226   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
227   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
228   call void %fp(%class.Interface* nonnull %base.i)
229   ret void
230 }
231 
232 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
233 )IR");
234 
235   auto *GV = M->getNamedValue("f");
236   ASSERT_TRUE(GV);
237   auto *F = dyn_cast<Function>(GV);
238   ASSERT_TRUE(F);
239   Instruction *Inst = &F->front().front();
240   auto *AI = dyn_cast<AllocaInst>(Inst);
241   ASSERT_TRUE(AI);
242   Inst = &*++F->front().rbegin();
243   auto *CI = dyn_cast<CallInst>(Inst);
244   ASSERT_TRUE(CI);
245   CallSite CS(CI);
246   ASSERT_FALSE(CS.getCalledFunction());
247   bool IsPromoted = tryPromoteCall(CS);
248   EXPECT_FALSE(IsPromoted);
249 }
250 
251 // Based on clang/test/CodeGenCXX/member-function-pointer-calls.cpp
252 TEST(CallPromotionUtilsTest, TryPromoteCall_MemberFunctionCalls) {
253   LLVMContext C;
254   std::unique_ptr<Module> M = parseIR(C,
255                                       R"IR(
256 %struct.A = type { i32 (...)** }
257 
258 @_ZTV1A = linkonce_odr unnamed_addr constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* null, i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf1Ev to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf2Ev to i8*)] }, align 8
259 
260 define i32 @_Z2g1v() {
261 entry:
262   %a = alloca %struct.A, align 8
263   %0 = bitcast %struct.A* %a to i8*
264   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
265   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
266   %2 = bitcast %struct.A* %a to i8*
267   %3 = bitcast i8* %2 to i8**
268   %vtable.i = load i8*, i8** %3, align 8
269   %4 = bitcast i8* %vtable.i to i32 (%struct.A*)**
270   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %4, align 8
271   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
272   ret i32 %call.i
273 }
274 
275 define i32 @_Z2g2v() {
276 entry:
277   %a = alloca %struct.A, align 8
278   %0 = bitcast %struct.A* %a to i8*
279   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
280   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
281   %2 = bitcast %struct.A* %a to i8*
282   %3 = bitcast i8* %2 to i8**
283   %vtable.i = load i8*, i8** %3, align 8
284   %4 = getelementptr i8, i8* %vtable.i, i64 8
285   %5 = bitcast i8* %4 to i32 (%struct.A*)**
286   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %5, align 8
287   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
288   ret i32 %call.i
289 }
290 
291 declare i32 @_ZN1A3vf1Ev(%struct.A* %this)
292 declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
293 )IR");
294 
295   auto *GV = M->getNamedValue("_Z2g1v");
296   ASSERT_TRUE(GV);
297   auto *F = dyn_cast<Function>(GV);
298   ASSERT_TRUE(F);
299   Instruction *Inst = &F->front().front();
300   auto *AI = dyn_cast<AllocaInst>(Inst);
301   ASSERT_TRUE(AI);
302   Inst = &*++F->front().rbegin();
303   auto *CI = dyn_cast<CallInst>(Inst);
304   ASSERT_TRUE(CI);
305   CallSite CS1(CI);
306   ASSERT_FALSE(CS1.getCalledFunction());
307   bool IsPromoted1 = tryPromoteCall(CS1);
308   EXPECT_TRUE(IsPromoted1);
309   GV = M->getNamedValue("_ZN1A3vf1Ev");
310   ASSERT_TRUE(GV);
311   F = dyn_cast<Function>(GV);
312   EXPECT_EQ(F, CS1.getCalledFunction());
313 
314   GV = M->getNamedValue("_Z2g2v");
315   ASSERT_TRUE(GV);
316   F = dyn_cast<Function>(GV);
317   ASSERT_TRUE(F);
318   Inst = &F->front().front();
319   AI = dyn_cast<AllocaInst>(Inst);
320   ASSERT_TRUE(AI);
321   Inst = &*++F->front().rbegin();
322   CI = dyn_cast<CallInst>(Inst);
323   ASSERT_TRUE(CI);
324   CallSite CS2(CI);
325   ASSERT_FALSE(CS2.getCalledFunction());
326   bool IsPromoted2 = tryPromoteCall(CS2);
327   EXPECT_TRUE(IsPromoted2);
328   GV = M->getNamedValue("_ZN1A3vf2Ev");
329   ASSERT_TRUE(GV);
330   F = dyn_cast<Function>(GV);
331   EXPECT_EQ(F, CS2.getCalledFunction());
332 }
333 
334 // Check that it isn't crashing due to missing promotion legality.
335 TEST(CallPromotionUtilsTest, TryPromoteCall_Legality) {
336   LLVMContext C;
337   std::unique_ptr<Module> M = parseIR(C,
338                                       R"IR(
339 %struct1 = type <{ i32, i64 }>
340 %struct2 = type <{ i32, i64 }>
341 
342 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
343 %class.Interface = type { i32 (...)** }
344 
345 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (%struct2 (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
346 
347 define %struct1 @f() {
348 entry:
349   %o = alloca %class.Impl
350   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
351   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
352   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
353   store i32 3, i32* %f
354   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
355   %c = bitcast %class.Interface* %base.i to %struct1 (%class.Interface*)***
356   %vtable.i = load %struct1 (%class.Interface*)**, %struct1 (%class.Interface*)*** %c
357   %fp = load %struct1 (%class.Interface*)*, %struct1 (%class.Interface*)** %vtable.i
358   %rv = call %struct1 %fp(%class.Interface* nonnull %base.i)
359   ret %struct1 %rv
360 }
361 
362 declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
363 )IR");
364 
365   auto *GV = M->getNamedValue("f");
366   ASSERT_TRUE(GV);
367   auto *F = dyn_cast<Function>(GV);
368   ASSERT_TRUE(F);
369   Instruction *Inst = &F->front().front();
370   auto *AI = dyn_cast<AllocaInst>(Inst);
371   ASSERT_TRUE(AI);
372   Inst = &*++F->front().rbegin();
373   auto *CI = dyn_cast<CallInst>(Inst);
374   ASSERT_TRUE(CI);
375   CallSite CS(CI);
376   ASSERT_FALSE(CS.getCalledFunction());
377   bool IsPromoted = tryPromoteCall(CS);
378   EXPECT_FALSE(IsPromoted);
379 }
380