xref: /llvm-project/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (revision b15845c0059b06f406e33f278127d7eb41ff5ab6)
1 //===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
10 #include "llvm/Analysis/CtxProfAnalysis.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/LLVMContext.h"
15 #include "llvm/IR/MDBuilder.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/NoFolder.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/ProfileData/PGOCtxProfReader.h"
20 #include "llvm/ProfileData/PGOCtxProfWriter.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/Testing/Support/SupportHelpers.h"
24 #include "gtest/gtest.h"
25 
26 using namespace llvm;
27 
28 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
29   SMDiagnostic Err;
30   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
31   if (!Mod)
32     Err.print("UtilsTests", errs());
33   return Mod;
34 }
35 
36 // Returns a constant representing the vtable's address point specified by the
37 // offset.
38 static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
39                                              uint32_t AddressPointOffset) {
40   Module &M = *VTable->getParent();
41   LLVMContext &Context = M.getContext();
42   assert(AddressPointOffset <
43              M.getDataLayout().getTypeAllocSize(VTable->getValueType()) &&
44          "Out-of-bound access");
45 
46   return ConstantExpr::getInBoundsGetElementPtr(
47       Type::getInt8Ty(Context), VTable,
48       llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset));
49 }
50 
51 TEST(CallPromotionUtilsTest, TryPromoteCall) {
52   LLVMContext C;
53   std::unique_ptr<Module> M = parseIR(C,
54                                       R"IR(
55 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
56 %class.Interface = type { i32 (...)** }
57 
58 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
59 
60 define void @f() {
61 entry:
62   %o = alloca %class.Impl
63   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
64   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
65   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
66   store i32 3, i32* %f
67   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
68   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
69   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
70   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
71   call void %fp(%class.Interface* nonnull %base.i)
72   ret void
73 }
74 
75 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
76 )IR");
77 
78   auto *GV = M->getNamedValue("f");
79   ASSERT_TRUE(GV);
80   auto *F = dyn_cast<Function>(GV);
81   ASSERT_TRUE(F);
82   Instruction *Inst = &F->front().front();
83   auto *AI = dyn_cast<AllocaInst>(Inst);
84   ASSERT_TRUE(AI);
85   Inst = &*++F->front().rbegin();
86   auto *CI = dyn_cast<CallInst>(Inst);
87   ASSERT_TRUE(CI);
88   ASSERT_FALSE(CI->getCalledFunction());
89   bool IsPromoted = tryPromoteCall(*CI);
90   EXPECT_TRUE(IsPromoted);
91   GV = M->getNamedValue("_ZN4Impl3RunEv");
92   ASSERT_TRUE(GV);
93   auto *F1 = dyn_cast<Function>(GV);
94   EXPECT_EQ(F1, CI->getCalledFunction());
95 }
96 
97 TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
98   LLVMContext C;
99   std::unique_ptr<Module> M = parseIR(C,
100                                       R"IR(
101 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
102 %class.Interface = type { i32 (...)** }
103 
104 define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i) {
105 entry:
106   call void %fp(%class.Interface* nonnull %base.i)
107   ret void
108 }
109 )IR");
110 
111   auto *GV = M->getNamedValue("f");
112   ASSERT_TRUE(GV);
113   auto *F = dyn_cast<Function>(GV);
114   ASSERT_TRUE(F);
115   Instruction *Inst = &F->front().front();
116   auto *CI = dyn_cast<CallInst>(Inst);
117   ASSERT_TRUE(CI);
118   ASSERT_FALSE(CI->getCalledFunction());
119   bool IsPromoted = tryPromoteCall(*CI);
120   EXPECT_FALSE(IsPromoted);
121 }
122 
123 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTablePtrLoad) {
124   LLVMContext C;
125   std::unique_ptr<Module> M = parseIR(C,
126                                       R"IR(
127 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
128 %class.Interface = type { i32 (...)** }
129 
130 define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %base.i) {
131 entry:
132   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
133   call void %fp(%class.Interface* nonnull %base.i)
134   ret void
135 }
136 )IR");
137 
138   auto *GV = M->getNamedValue("f");
139   ASSERT_TRUE(GV);
140   auto *F = dyn_cast<Function>(GV);
141   ASSERT_TRUE(F);
142   Instruction *Inst = &*++F->front().rbegin();
143   auto *CI = dyn_cast<CallInst>(Inst);
144   ASSERT_TRUE(CI);
145   ASSERT_FALSE(CI->getCalledFunction());
146   bool IsPromoted = tryPromoteCall(*CI);
147   EXPECT_FALSE(IsPromoted);
148 }
149 
150 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTableInitFound) {
151   LLVMContext C;
152   std::unique_ptr<Module> M = parseIR(C,
153                                       R"IR(
154 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
155 %class.Interface = type { i32 (...)** }
156 
157 define void @f() {
158 entry:
159   %o = alloca %class.Impl
160   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
161   store i32 3, i32* %f
162   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
163   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
164   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
165   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
166   call void %fp(%class.Interface* nonnull %base.i)
167   ret void
168 }
169 
170 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
171 )IR");
172 
173   auto *GV = M->getNamedValue("f");
174   ASSERT_TRUE(GV);
175   auto *F = dyn_cast<Function>(GV);
176   ASSERT_TRUE(F);
177   Instruction *Inst = &*++F->front().rbegin();
178   auto *CI = dyn_cast<CallInst>(Inst);
179   ASSERT_TRUE(CI);
180   ASSERT_FALSE(CI->getCalledFunction());
181   bool IsPromoted = tryPromoteCall(*CI);
182   EXPECT_FALSE(IsPromoted);
183 }
184 
185 TEST(CallPromotionUtilsTest, TryPromoteCall_EmptyVTable) {
186   LLVMContext C;
187   std::unique_ptr<Module> M = parseIR(C,
188                                       R"IR(
189 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
190 %class.Interface = type { i32 (...)** }
191 
192 @_ZTV4Impl = external global { [3 x i8*] }
193 
194 define void @f() {
195 entry:
196   %o = alloca %class.Impl
197   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
198   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
199   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
200   store i32 3, i32* %f
201   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
202   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
203   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
204   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
205   call void %fp(%class.Interface* nonnull %base.i)
206   ret void
207 }
208 
209 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
210 )IR");
211 
212   auto *GV = M->getNamedValue("f");
213   ASSERT_TRUE(GV);
214   auto *F = dyn_cast<Function>(GV);
215   ASSERT_TRUE(F);
216   Instruction *Inst = &F->front().front();
217   auto *AI = dyn_cast<AllocaInst>(Inst);
218   ASSERT_TRUE(AI);
219   Inst = &*++F->front().rbegin();
220   auto *CI = dyn_cast<CallInst>(Inst);
221   ASSERT_TRUE(CI);
222   ASSERT_FALSE(CI->getCalledFunction());
223   bool IsPromoted = tryPromoteCall(*CI);
224   EXPECT_FALSE(IsPromoted);
225 }
226 
227 TEST(CallPromotionUtilsTest, TryPromoteCall_NullFP) {
228   LLVMContext C;
229   std::unique_ptr<Module> M = parseIR(C,
230                                       R"IR(
231 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
232 %class.Interface = type { i32 (...)** }
233 
234 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* null] }
235 
236 define void @f() {
237 entry:
238   %o = alloca %class.Impl
239   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
240   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
241   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
242   store i32 3, i32* %f
243   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
244   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
245   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
246   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
247   call void %fp(%class.Interface* nonnull %base.i)
248   ret void
249 }
250 
251 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
252 )IR");
253 
254   auto *GV = M->getNamedValue("f");
255   ASSERT_TRUE(GV);
256   auto *F = dyn_cast<Function>(GV);
257   ASSERT_TRUE(F);
258   Instruction *Inst = &F->front().front();
259   auto *AI = dyn_cast<AllocaInst>(Inst);
260   ASSERT_TRUE(AI);
261   Inst = &*++F->front().rbegin();
262   auto *CI = dyn_cast<CallInst>(Inst);
263   ASSERT_TRUE(CI);
264   ASSERT_FALSE(CI->getCalledFunction());
265   bool IsPromoted = tryPromoteCall(*CI);
266   EXPECT_FALSE(IsPromoted);
267 }
268 
269 // Based on clang/test/CodeGenCXX/member-function-pointer-calls.cpp
270 TEST(CallPromotionUtilsTest, TryPromoteCall_MemberFunctionCalls) {
271   LLVMContext C;
272   std::unique_ptr<Module> M = parseIR(C,
273                                       R"IR(
274 %struct.A = type { i32 (...)** }
275 
276 @_ZTV1A = linkonce_odr unnamed_addr constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* null, i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf1Ev to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf2Ev to i8*)] }, align 8
277 
278 define i32 @_Z2g1v() {
279 entry:
280   %a = alloca %struct.A, align 8
281   %0 = bitcast %struct.A* %a to i8*
282   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
283   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
284   %2 = bitcast %struct.A* %a to i8*
285   %3 = bitcast i8* %2 to i8**
286   %vtable.i = load i8*, i8** %3, align 8
287   %4 = bitcast i8* %vtable.i to i32 (%struct.A*)**
288   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %4, align 8
289   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
290   ret i32 %call.i
291 }
292 
293 define i32 @_Z2g2v() {
294 entry:
295   %a = alloca %struct.A, align 8
296   %0 = bitcast %struct.A* %a to i8*
297   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
298   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
299   %2 = bitcast %struct.A* %a to i8*
300   %3 = bitcast i8* %2 to i8**
301   %vtable.i = load i8*, i8** %3, align 8
302   %4 = getelementptr i8, i8* %vtable.i, i64 8
303   %5 = bitcast i8* %4 to i32 (%struct.A*)**
304   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %5, align 8
305   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
306   ret i32 %call.i
307 }
308 
309 declare i32 @_ZN1A3vf1Ev(%struct.A* %this)
310 declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
311 )IR");
312 
313   auto *GV = M->getNamedValue("_Z2g1v");
314   ASSERT_TRUE(GV);
315   auto *F = dyn_cast<Function>(GV);
316   ASSERT_TRUE(F);
317   Instruction *Inst = &F->front().front();
318   auto *AI = dyn_cast<AllocaInst>(Inst);
319   ASSERT_TRUE(AI);
320   Inst = &*++F->front().rbegin();
321   auto *CI = dyn_cast<CallInst>(Inst);
322   ASSERT_TRUE(CI);
323   ASSERT_FALSE(CI->getCalledFunction());
324   bool IsPromoted1 = tryPromoteCall(*CI);
325   EXPECT_TRUE(IsPromoted1);
326   GV = M->getNamedValue("_ZN1A3vf1Ev");
327   ASSERT_TRUE(GV);
328   F = dyn_cast<Function>(GV);
329   EXPECT_EQ(F, CI->getCalledFunction());
330 
331   GV = M->getNamedValue("_Z2g2v");
332   ASSERT_TRUE(GV);
333   F = dyn_cast<Function>(GV);
334   ASSERT_TRUE(F);
335   Inst = &F->front().front();
336   AI = dyn_cast<AllocaInst>(Inst);
337   ASSERT_TRUE(AI);
338   Inst = &*++F->front().rbegin();
339   CI = dyn_cast<CallInst>(Inst);
340   ASSERT_TRUE(CI);
341   ASSERT_FALSE(CI->getCalledFunction());
342   bool IsPromoted2 = tryPromoteCall(*CI);
343   EXPECT_TRUE(IsPromoted2);
344   GV = M->getNamedValue("_ZN1A3vf2Ev");
345   ASSERT_TRUE(GV);
346   F = dyn_cast<Function>(GV);
347   EXPECT_EQ(F, CI->getCalledFunction());
348 }
349 
350 // Check that it isn't crashing due to missing promotion legality.
351 TEST(CallPromotionUtilsTest, TryPromoteCall_Legality) {
352   LLVMContext C;
353   std::unique_ptr<Module> M = parseIR(C,
354                                       R"IR(
355 %struct1 = type <{ i32, i64 }>
356 %struct2 = type <{ i32, i64 }>
357 
358 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
359 %class.Interface = type { i32 (...)** }
360 
361 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (%struct2 (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
362 
363 define %struct1 @f() {
364 entry:
365   %o = alloca %class.Impl
366   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
367   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
368   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
369   store i32 3, i32* %f
370   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
371   %c = bitcast %class.Interface* %base.i to %struct1 (%class.Interface*)***
372   %vtable.i = load %struct1 (%class.Interface*)**, %struct1 (%class.Interface*)*** %c
373   %fp = load %struct1 (%class.Interface*)*, %struct1 (%class.Interface*)** %vtable.i
374   %rv = call %struct1 %fp(%class.Interface* nonnull %base.i)
375   ret %struct1 %rv
376 }
377 
378 declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
379 )IR");
380 
381   auto *GV = M->getNamedValue("f");
382   ASSERT_TRUE(GV);
383   auto *F = dyn_cast<Function>(GV);
384   ASSERT_TRUE(F);
385   Instruction *Inst = &F->front().front();
386   auto *AI = dyn_cast<AllocaInst>(Inst);
387   ASSERT_TRUE(AI);
388   Inst = &*++F->front().rbegin();
389   auto *CI = dyn_cast<CallInst>(Inst);
390   ASSERT_TRUE(CI);
391   ASSERT_FALSE(CI->getCalledFunction());
392   bool IsPromoted = tryPromoteCall(*CI);
393   EXPECT_FALSE(IsPromoted);
394 }
395 
396 TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
397   LLVMContext C;
398   std::unique_ptr<Module> M = parseIR(C,
399                                       R"IR(
400 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
401 target triple = "x86_64-unknown-linux-gnu"
402 
403 @_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
404 @_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
405 @_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
406 
407 define i32 @testfunc(ptr %d) {
408 entry:
409   %vtable = load ptr, ptr %d, !prof !7
410   %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
411   %0 = load ptr, ptr %vfn
412   %call = tail call i32 %0(ptr %d), !prof !8
413   ret i32 %call
414 }
415 
416 define i32 @_ZN5Base15func1Ev(ptr %this) {
417 entry:
418   ret i32 2
419 }
420 
421 declare i32 @_ZN5Base25func2Ev(ptr)
422 declare i32 @_ZN5Base15func0Ev(ptr)
423 declare void @_ZN5Base35func3Ev(ptr)
424 
425 !0 = !{i64 16, !"_ZTS5Base1"}
426 !1 = !{i64 48, !"_ZTS5Base2"}
427 !2 = !{i64 16, !"_ZTS8Derived1"}
428 !3 = !{i64 64, !"_ZTS5Base1"}
429 !4 = !{i64 40, !"_ZTS5Base2"}
430 !5 = !{i64 16, !"_ZTS5Base3"}
431 !6 = !{i64 16, !"_ZTS8Derived2"}
432 !7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
433 !8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
434 
435   Function *F = M->getFunction("testfunc");
436   CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
437   ASSERT_TRUE(CI && CI->isIndirectCall());
438 
439   // Create the constant and the branch weights
440   SmallVector<Constant *, 3> VTableAddressPoints;
441 
442   for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
443                                                  {"_ZTV8Derived1", 16},
444                                                  {"_ZTV8Derived2", 64}})
445     VTableAddressPoints.push_back(getVTableAddressPointOffset(
446         M->getGlobalVariable(VTableName), AddressPointOffset));
447 
448   MDBuilder MDB(C);
449   MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
450 
451   size_t OrigEntryBBSize = F->front().size();
452 
453   LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
454 
455   Function *Callee = M->getFunction("_ZN5Base15func1Ev");
456   // Tests that promoted direct call is returned.
457   CallBase &DirectCB = promoteCallWithVTableCmp(
458       *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
459   EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
460 
461   // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
462   // 1 call instruction from the entry block.
463   EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
464 }
465 
466 TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
467   LLVMContext C;
468   std::unique_ptr<Module> M = parseIR(C,
469                                       R"IR(
470 define i32 @testfunc1(ptr %d) !guid !0 {
471   call void @llvm.instrprof.increment(ptr @testfunc1, i64 0, i32 1, i32 0)
472   call void @llvm.instrprof.callsite(ptr @testfunc1, i64 0, i32 1, i32 0, ptr %d)
473   %call = call i32 %d()
474   ret i32 %call
475 }
476 
477 define i32 @f1() !guid !1 {
478   call void @llvm.instrprof.increment(ptr @f1, i64 0, i32 1, i32 0)
479   ret i32 2
480 }
481 
482 define i32 @f2() !guid !2 {
483   call void @llvm.instrprof.increment(ptr @f2, i64 0, i32 1, i32 0)
484   call void @llvm.instrprof.callsite(ptr @f2, i64 0, i32 1, i32 0, ptr @f4)
485   %r = call i32 @f4()
486   ret i32 %r
487 }
488 
489 define i32 @testfunc2(ptr %p) !guid !4 {
490   call void @llvm.instrprof.increment(ptr @testfunc2, i64 0, i32 1, i32 0)
491   call void @llvm.instrprof.callsite(ptr @testfunc2, i64 0, i32 1, i32 0, ptr @testfunc1)
492   %r = call i32 @testfunc1(ptr %p)
493   ret i32 %r
494 }
495 
496 declare i32 @f3()
497 
498 define i32 @f4() !guid !3 {
499   ret i32 3
500 }
501 
502 !0 = !{i64 1000}
503 !1 = !{i64 1001}
504 !2 = !{i64 1002}
505 !3 = !{i64 1004}
506 !4 = !{i64 1005}
507 )IR");
508 
509   const char *Profile = R"json(
510     [
511     {
512       "Guid": 1000,
513       "Counters": [1],
514       "Callsites": [
515         [{ "Guid": 1001,
516             "Counters": [10]},
517           { "Guid": 1002,
518             "Counters": [11],
519             "Callsites": [[{"Guid": 1004, "Counters":[13]}]]
520           },
521           { "Guid": 1003,
522             "Counters": [12]
523           }]]
524     },
525     {
526       "Guid": 1005,
527       "Counters": [2],
528       "Callsites": [
529         [{ "Guid": 1000,
530             "Counters": [1],
531             "Callsites": [
532               [{ "Guid": 1001,
533                   "Counters": [101]},
534                 { "Guid": 1002,
535                   "Counters": [102],
536                   "Callsites": [[{"Guid": 1004, "Counters":[104]}]]
537                 },
538                 { "Guid": 1003,
539                   "Counters": [103]
540                 }]]}]]}]
541     )json";
542 
543   llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique=*/true);
544   {
545     std::error_code EC;
546     raw_fd_stream Out(ProfileFile.path(), EC);
547     ASSERT_FALSE(EC);
548     // "False" means no error.
549     ASSERT_FALSE(llvm::createCtxProfFromYAML(Profile, Out));
550   }
551 
552   ModuleAnalysisManager MAM;
553   MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
554   MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
555   auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
556   auto *Caller = M->getFunction("testfunc1");
557   ASSERT_NE(Caller, nullptr);
558   auto *Callee = M->getFunction("f2");
559   ASSERT_NE(Callee, nullptr);
560   auto *IndirectCS = [&]() -> CallBase * {
561     for (auto &BB : *Caller)
562       for (auto &I : BB)
563         if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
564           return CB;
565     return nullptr;
566   }();
567   ASSERT_NE(IndirectCS, nullptr);
568   promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf);
569 
570   std::string Str;
571   raw_string_ostream OS(Str);
572   CtxProfAnalysisPrinterPass Printer(OS);
573   Printer.run(*M, MAM);
574   const char *Expected = R"yaml(
575 - Guid:            1000
576   Counters:        [ 1, 11, 22 ]
577   Callsites:
578     - - Guid:            1001
579         Counters:        [ 10 ]
580       - Guid:            1003
581         Counters:        [ 12 ]
582     - - Guid:            1002
583         Counters:        [ 11 ]
584         Callsites:
585           - - Guid:            1004
586               Counters:        [ 13 ]
587 - Guid:            1005
588   Counters:        [ 2 ]
589   Callsites:
590     - - Guid:            1000
591         Counters:        [ 1, 102, 204 ]
592         Callsites:
593           - - Guid:            1001
594               Counters:        [ 101 ]
595             - Guid:            1003
596               Counters:        [ 103 ]
597           - - Guid:            1002
598               Counters:        [ 102 ]
599               Callsites:
600                 - - Guid:            1004
601                     Counters:        [ 104 ]
602 )yaml";
603   EXPECT_EQ(Expected, Str);
604 }
605