xref: /llvm-project/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (revision 5d3f296733b66281a53dd451a983e69ae0bb482f)
1 //===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/IRBuilder.h"
12 #include "llvm/IR/Instructions.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/MDBuilder.h"
15 #include "llvm/IR/Module.h"
16 #include "llvm/IR/NoFolder.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "gtest/gtest.h"
19 
20 using namespace llvm;
21 
22 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
23   SMDiagnostic Err;
24   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
25   if (!Mod)
26     Err.print("UtilsTests", errs());
27   return Mod;
28 }
29 
30 // Returns a constant representing the vtable's address point specified by the
31 // offset.
32 static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
33                                              uint32_t AddressPointOffset) {
34   Module &M = *VTable->getParent();
35   LLVMContext &Context = M.getContext();
36   assert(AddressPointOffset <
37              M.getDataLayout().getTypeAllocSize(VTable->getValueType()) &&
38          "Out-of-bound access");
39 
40   return ConstantExpr::getInBoundsGetElementPtr(
41       Type::getInt8Ty(Context), VTable,
42       llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset));
43 }
44 
45 TEST(CallPromotionUtilsTest, TryPromoteCall) {
46   LLVMContext C;
47   std::unique_ptr<Module> M = parseIR(C,
48                                       R"IR(
49 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
50 %class.Interface = type { i32 (...)** }
51 
52 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
53 
54 define void @f() {
55 entry:
56   %o = alloca %class.Impl
57   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
58   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
59   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
60   store i32 3, i32* %f
61   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
62   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
63   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
64   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
65   call void %fp(%class.Interface* nonnull %base.i)
66   ret void
67 }
68 
69 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
70 )IR");
71 
72   auto *GV = M->getNamedValue("f");
73   ASSERT_TRUE(GV);
74   auto *F = dyn_cast<Function>(GV);
75   ASSERT_TRUE(F);
76   Instruction *Inst = &F->front().front();
77   auto *AI = dyn_cast<AllocaInst>(Inst);
78   ASSERT_TRUE(AI);
79   Inst = &*++F->front().rbegin();
80   auto *CI = dyn_cast<CallInst>(Inst);
81   ASSERT_TRUE(CI);
82   ASSERT_FALSE(CI->getCalledFunction());
83   bool IsPromoted = tryPromoteCall(*CI);
84   EXPECT_TRUE(IsPromoted);
85   GV = M->getNamedValue("_ZN4Impl3RunEv");
86   ASSERT_TRUE(GV);
87   auto *F1 = dyn_cast<Function>(GV);
88   EXPECT_EQ(F1, CI->getCalledFunction());
89 }
90 
91 TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
92   LLVMContext C;
93   std::unique_ptr<Module> M = parseIR(C,
94                                       R"IR(
95 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
96 %class.Interface = type { i32 (...)** }
97 
98 define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i) {
99 entry:
100   call void %fp(%class.Interface* nonnull %base.i)
101   ret void
102 }
103 )IR");
104 
105   auto *GV = M->getNamedValue("f");
106   ASSERT_TRUE(GV);
107   auto *F = dyn_cast<Function>(GV);
108   ASSERT_TRUE(F);
109   Instruction *Inst = &F->front().front();
110   auto *CI = dyn_cast<CallInst>(Inst);
111   ASSERT_TRUE(CI);
112   ASSERT_FALSE(CI->getCalledFunction());
113   bool IsPromoted = tryPromoteCall(*CI);
114   EXPECT_FALSE(IsPromoted);
115 }
116 
117 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTablePtrLoad) {
118   LLVMContext C;
119   std::unique_ptr<Module> M = parseIR(C,
120                                       R"IR(
121 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
122 %class.Interface = type { i32 (...)** }
123 
124 define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %base.i) {
125 entry:
126   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
127   call void %fp(%class.Interface* nonnull %base.i)
128   ret void
129 }
130 )IR");
131 
132   auto *GV = M->getNamedValue("f");
133   ASSERT_TRUE(GV);
134   auto *F = dyn_cast<Function>(GV);
135   ASSERT_TRUE(F);
136   Instruction *Inst = &*++F->front().rbegin();
137   auto *CI = dyn_cast<CallInst>(Inst);
138   ASSERT_TRUE(CI);
139   ASSERT_FALSE(CI->getCalledFunction());
140   bool IsPromoted = tryPromoteCall(*CI);
141   EXPECT_FALSE(IsPromoted);
142 }
143 
144 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTableInitFound) {
145   LLVMContext C;
146   std::unique_ptr<Module> M = parseIR(C,
147                                       R"IR(
148 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
149 %class.Interface = type { i32 (...)** }
150 
151 define void @f() {
152 entry:
153   %o = alloca %class.Impl
154   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
155   store i32 3, i32* %f
156   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
157   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
158   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
159   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
160   call void %fp(%class.Interface* nonnull %base.i)
161   ret void
162 }
163 
164 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
165 )IR");
166 
167   auto *GV = M->getNamedValue("f");
168   ASSERT_TRUE(GV);
169   auto *F = dyn_cast<Function>(GV);
170   ASSERT_TRUE(F);
171   Instruction *Inst = &*++F->front().rbegin();
172   auto *CI = dyn_cast<CallInst>(Inst);
173   ASSERT_TRUE(CI);
174   ASSERT_FALSE(CI->getCalledFunction());
175   bool IsPromoted = tryPromoteCall(*CI);
176   EXPECT_FALSE(IsPromoted);
177 }
178 
179 TEST(CallPromotionUtilsTest, TryPromoteCall_EmptyVTable) {
180   LLVMContext C;
181   std::unique_ptr<Module> M = parseIR(C,
182                                       R"IR(
183 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
184 %class.Interface = type { i32 (...)** }
185 
186 @_ZTV4Impl = external global { [3 x i8*] }
187 
188 define void @f() {
189 entry:
190   %o = alloca %class.Impl
191   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
192   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
193   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
194   store i32 3, i32* %f
195   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
196   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
197   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
198   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
199   call void %fp(%class.Interface* nonnull %base.i)
200   ret void
201 }
202 
203 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
204 )IR");
205 
206   auto *GV = M->getNamedValue("f");
207   ASSERT_TRUE(GV);
208   auto *F = dyn_cast<Function>(GV);
209   ASSERT_TRUE(F);
210   Instruction *Inst = &F->front().front();
211   auto *AI = dyn_cast<AllocaInst>(Inst);
212   ASSERT_TRUE(AI);
213   Inst = &*++F->front().rbegin();
214   auto *CI = dyn_cast<CallInst>(Inst);
215   ASSERT_TRUE(CI);
216   ASSERT_FALSE(CI->getCalledFunction());
217   bool IsPromoted = tryPromoteCall(*CI);
218   EXPECT_FALSE(IsPromoted);
219 }
220 
221 TEST(CallPromotionUtilsTest, TryPromoteCall_NullFP) {
222   LLVMContext C;
223   std::unique_ptr<Module> M = parseIR(C,
224                                       R"IR(
225 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
226 %class.Interface = type { i32 (...)** }
227 
228 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* null] }
229 
230 define void @f() {
231 entry:
232   %o = alloca %class.Impl
233   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
234   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
235   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
236   store i32 3, i32* %f
237   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
238   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
239   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
240   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
241   call void %fp(%class.Interface* nonnull %base.i)
242   ret void
243 }
244 
245 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
246 )IR");
247 
248   auto *GV = M->getNamedValue("f");
249   ASSERT_TRUE(GV);
250   auto *F = dyn_cast<Function>(GV);
251   ASSERT_TRUE(F);
252   Instruction *Inst = &F->front().front();
253   auto *AI = dyn_cast<AllocaInst>(Inst);
254   ASSERT_TRUE(AI);
255   Inst = &*++F->front().rbegin();
256   auto *CI = dyn_cast<CallInst>(Inst);
257   ASSERT_TRUE(CI);
258   ASSERT_FALSE(CI->getCalledFunction());
259   bool IsPromoted = tryPromoteCall(*CI);
260   EXPECT_FALSE(IsPromoted);
261 }
262 
263 // Based on clang/test/CodeGenCXX/member-function-pointer-calls.cpp
264 TEST(CallPromotionUtilsTest, TryPromoteCall_MemberFunctionCalls) {
265   LLVMContext C;
266   std::unique_ptr<Module> M = parseIR(C,
267                                       R"IR(
268 %struct.A = type { i32 (...)** }
269 
270 @_ZTV1A = linkonce_odr unnamed_addr constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* null, i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf1Ev to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf2Ev to i8*)] }, align 8
271 
272 define i32 @_Z2g1v() {
273 entry:
274   %a = alloca %struct.A, align 8
275   %0 = bitcast %struct.A* %a to i8*
276   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
277   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
278   %2 = bitcast %struct.A* %a to i8*
279   %3 = bitcast i8* %2 to i8**
280   %vtable.i = load i8*, i8** %3, align 8
281   %4 = bitcast i8* %vtable.i to i32 (%struct.A*)**
282   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %4, align 8
283   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
284   ret i32 %call.i
285 }
286 
287 define i32 @_Z2g2v() {
288 entry:
289   %a = alloca %struct.A, align 8
290   %0 = bitcast %struct.A* %a to i8*
291   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
292   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
293   %2 = bitcast %struct.A* %a to i8*
294   %3 = bitcast i8* %2 to i8**
295   %vtable.i = load i8*, i8** %3, align 8
296   %4 = getelementptr i8, i8* %vtable.i, i64 8
297   %5 = bitcast i8* %4 to i32 (%struct.A*)**
298   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %5, align 8
299   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
300   ret i32 %call.i
301 }
302 
303 declare i32 @_ZN1A3vf1Ev(%struct.A* %this)
304 declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
305 )IR");
306 
307   auto *GV = M->getNamedValue("_Z2g1v");
308   ASSERT_TRUE(GV);
309   auto *F = dyn_cast<Function>(GV);
310   ASSERT_TRUE(F);
311   Instruction *Inst = &F->front().front();
312   auto *AI = dyn_cast<AllocaInst>(Inst);
313   ASSERT_TRUE(AI);
314   Inst = &*++F->front().rbegin();
315   auto *CI = dyn_cast<CallInst>(Inst);
316   ASSERT_TRUE(CI);
317   ASSERT_FALSE(CI->getCalledFunction());
318   bool IsPromoted1 = tryPromoteCall(*CI);
319   EXPECT_TRUE(IsPromoted1);
320   GV = M->getNamedValue("_ZN1A3vf1Ev");
321   ASSERT_TRUE(GV);
322   F = dyn_cast<Function>(GV);
323   EXPECT_EQ(F, CI->getCalledFunction());
324 
325   GV = M->getNamedValue("_Z2g2v");
326   ASSERT_TRUE(GV);
327   F = dyn_cast<Function>(GV);
328   ASSERT_TRUE(F);
329   Inst = &F->front().front();
330   AI = dyn_cast<AllocaInst>(Inst);
331   ASSERT_TRUE(AI);
332   Inst = &*++F->front().rbegin();
333   CI = dyn_cast<CallInst>(Inst);
334   ASSERT_TRUE(CI);
335   ASSERT_FALSE(CI->getCalledFunction());
336   bool IsPromoted2 = tryPromoteCall(*CI);
337   EXPECT_TRUE(IsPromoted2);
338   GV = M->getNamedValue("_ZN1A3vf2Ev");
339   ASSERT_TRUE(GV);
340   F = dyn_cast<Function>(GV);
341   EXPECT_EQ(F, CI->getCalledFunction());
342 }
343 
344 // Check that it isn't crashing due to missing promotion legality.
345 TEST(CallPromotionUtilsTest, TryPromoteCall_Legality) {
346   LLVMContext C;
347   std::unique_ptr<Module> M = parseIR(C,
348                                       R"IR(
349 %struct1 = type <{ i32, i64 }>
350 %struct2 = type <{ i32, i64 }>
351 
352 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
353 %class.Interface = type { i32 (...)** }
354 
355 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (%struct2 (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
356 
357 define %struct1 @f() {
358 entry:
359   %o = alloca %class.Impl
360   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
361   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
362   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
363   store i32 3, i32* %f
364   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
365   %c = bitcast %class.Interface* %base.i to %struct1 (%class.Interface*)***
366   %vtable.i = load %struct1 (%class.Interface*)**, %struct1 (%class.Interface*)*** %c
367   %fp = load %struct1 (%class.Interface*)*, %struct1 (%class.Interface*)** %vtable.i
368   %rv = call %struct1 %fp(%class.Interface* nonnull %base.i)
369   ret %struct1 %rv
370 }
371 
372 declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
373 )IR");
374 
375   auto *GV = M->getNamedValue("f");
376   ASSERT_TRUE(GV);
377   auto *F = dyn_cast<Function>(GV);
378   ASSERT_TRUE(F);
379   Instruction *Inst = &F->front().front();
380   auto *AI = dyn_cast<AllocaInst>(Inst);
381   ASSERT_TRUE(AI);
382   Inst = &*++F->front().rbegin();
383   auto *CI = dyn_cast<CallInst>(Inst);
384   ASSERT_TRUE(CI);
385   ASSERT_FALSE(CI->getCalledFunction());
386   bool IsPromoted = tryPromoteCall(*CI);
387   EXPECT_FALSE(IsPromoted);
388 }
389 
390 TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
391   LLVMContext C;
392   std::unique_ptr<Module> M = parseIR(C,
393                                       R"IR(
394 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
395 target triple = "x86_64-unknown-linux-gnu"
396 
397 @_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
398 @_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
399 @_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
400 
401 define i32 @testfunc(ptr %d) {
402 entry:
403   %vtable = load ptr, ptr %d, !prof !7
404   %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
405   %0 = load ptr, ptr %vfn
406   %call = tail call i32 %0(ptr %d), !prof !8
407   ret i32 %call
408 }
409 
410 define i32 @_ZN5Base15func1Ev(ptr %this) {
411 entry:
412   ret i32 2
413 }
414 
415 declare i32 @_ZN5Base25func2Ev(ptr)
416 declare i32 @_ZN5Base15func0Ev(ptr)
417 declare void @_ZN5Base35func3Ev(ptr)
418 
419 !0 = !{i64 16, !"_ZTS5Base1"}
420 !1 = !{i64 48, !"_ZTS5Base2"}
421 !2 = !{i64 16, !"_ZTS8Derived1"}
422 !3 = !{i64 64, !"_ZTS5Base1"}
423 !4 = !{i64 40, !"_ZTS5Base2"}
424 !5 = !{i64 16, !"_ZTS5Base3"}
425 !6 = !{i64 16, !"_ZTS8Derived2"}
426 !7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
427 !8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
428 
429   Function *F = M->getFunction("testfunc");
430   CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
431   ASSERT_TRUE(CI && CI->isIndirectCall());
432 
433   // Create the constant and the branch weights
434   SmallVector<Constant *, 3> VTableAddressPoints;
435 
436   for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
437                                                  {"_ZTV8Derived1", 16},
438                                                  {"_ZTV8Derived2", 64}})
439     VTableAddressPoints.push_back(getVTableAddressPointOffset(
440         M->getGlobalVariable(VTableName), AddressPointOffset));
441 
442   MDBuilder MDB(C);
443   MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
444 
445   size_t OrigEntryBBSize = F->front().size();
446 
447   LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
448 
449   Function *Callee = M->getFunction("_ZN5Base15func1Ev");
450   // Tests that promoted direct call is returned.
451   CallBase &DirectCB = promoteCallWithVTableCmp(
452       *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
453   EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
454 
455   // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
456   // 1 call instruction from the entry block.
457   EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
458 }
459