xref: /llvm-project/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp (revision 73c3b7337b0a3a8cb447f9801341d5648aebe9b2)
1 //===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
10 #include "llvm/Analysis/CtxProfAnalysis.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/IRBuilder.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/LLVMContext.h"
15 #include "llvm/IR/MDBuilder.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/NoFolder.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/ProfileData/PGOCtxProfReader.h"
20 #include "llvm/ProfileData/PGOCtxProfWriter.h"
21 #include "llvm/Support/JSON.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "llvm/Testing/Support/SupportHelpers.h"
25 #include "gtest/gtest.h"
26 
27 using namespace llvm;
28 
29 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
30   SMDiagnostic Err;
31   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
32   if (!Mod)
33     Err.print("UtilsTests", errs());
34   return Mod;
35 }
36 
37 // Returns a constant representing the vtable's address point specified by the
38 // offset.
39 static Constant *getVTableAddressPointOffset(GlobalVariable *VTable,
40                                              uint32_t AddressPointOffset) {
41   Module &M = *VTable->getParent();
42   LLVMContext &Context = M.getContext();
43   assert(AddressPointOffset <
44              M.getDataLayout().getTypeAllocSize(VTable->getValueType()) &&
45          "Out-of-bound access");
46 
47   return ConstantExpr::getInBoundsGetElementPtr(
48       Type::getInt8Ty(Context), VTable,
49       llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset));
50 }
51 
52 TEST(CallPromotionUtilsTest, TryPromoteCall) {
53   LLVMContext C;
54   std::unique_ptr<Module> M = parseIR(C,
55                                       R"IR(
56 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
57 %class.Interface = type { i32 (...)** }
58 
59 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
60 
61 define void @f() {
62 entry:
63   %o = alloca %class.Impl
64   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
65   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
66   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
67   store i32 3, i32* %f
68   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
69   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
70   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
71   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
72   call void %fp(%class.Interface* nonnull %base.i)
73   ret void
74 }
75 
76 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
77 )IR");
78 
79   auto *GV = M->getNamedValue("f");
80   ASSERT_TRUE(GV);
81   auto *F = dyn_cast<Function>(GV);
82   ASSERT_TRUE(F);
83   Instruction *Inst = &F->front().front();
84   auto *AI = dyn_cast<AllocaInst>(Inst);
85   ASSERT_TRUE(AI);
86   Inst = &*++F->front().rbegin();
87   auto *CI = dyn_cast<CallInst>(Inst);
88   ASSERT_TRUE(CI);
89   ASSERT_FALSE(CI->getCalledFunction());
90   bool IsPromoted = tryPromoteCall(*CI);
91   EXPECT_TRUE(IsPromoted);
92   GV = M->getNamedValue("_ZN4Impl3RunEv");
93   ASSERT_TRUE(GV);
94   auto *F1 = dyn_cast<Function>(GV);
95   EXPECT_EQ(F1, CI->getCalledFunction());
96 }
97 
98 TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
99   LLVMContext C;
100   std::unique_ptr<Module> M = parseIR(C,
101                                       R"IR(
102 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
103 %class.Interface = type { i32 (...)** }
104 
105 define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i) {
106 entry:
107   call void %fp(%class.Interface* nonnull %base.i)
108   ret void
109 }
110 )IR");
111 
112   auto *GV = M->getNamedValue("f");
113   ASSERT_TRUE(GV);
114   auto *F = dyn_cast<Function>(GV);
115   ASSERT_TRUE(F);
116   Instruction *Inst = &F->front().front();
117   auto *CI = dyn_cast<CallInst>(Inst);
118   ASSERT_TRUE(CI);
119   ASSERT_FALSE(CI->getCalledFunction());
120   bool IsPromoted = tryPromoteCall(*CI);
121   EXPECT_FALSE(IsPromoted);
122 }
123 
124 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTablePtrLoad) {
125   LLVMContext C;
126   std::unique_ptr<Module> M = parseIR(C,
127                                       R"IR(
128 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
129 %class.Interface = type { i32 (...)** }
130 
131 define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %base.i) {
132 entry:
133   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
134   call void %fp(%class.Interface* nonnull %base.i)
135   ret void
136 }
137 )IR");
138 
139   auto *GV = M->getNamedValue("f");
140   ASSERT_TRUE(GV);
141   auto *F = dyn_cast<Function>(GV);
142   ASSERT_TRUE(F);
143   Instruction *Inst = &*++F->front().rbegin();
144   auto *CI = dyn_cast<CallInst>(Inst);
145   ASSERT_TRUE(CI);
146   ASSERT_FALSE(CI->getCalledFunction());
147   bool IsPromoted = tryPromoteCall(*CI);
148   EXPECT_FALSE(IsPromoted);
149 }
150 
151 TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTableInitFound) {
152   LLVMContext C;
153   std::unique_ptr<Module> M = parseIR(C,
154                                       R"IR(
155 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
156 %class.Interface = type { i32 (...)** }
157 
158 define void @f() {
159 entry:
160   %o = alloca %class.Impl
161   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
162   store i32 3, i32* %f
163   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
164   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
165   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
166   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
167   call void %fp(%class.Interface* nonnull %base.i)
168   ret void
169 }
170 
171 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
172 )IR");
173 
174   auto *GV = M->getNamedValue("f");
175   ASSERT_TRUE(GV);
176   auto *F = dyn_cast<Function>(GV);
177   ASSERT_TRUE(F);
178   Instruction *Inst = &*++F->front().rbegin();
179   auto *CI = dyn_cast<CallInst>(Inst);
180   ASSERT_TRUE(CI);
181   ASSERT_FALSE(CI->getCalledFunction());
182   bool IsPromoted = tryPromoteCall(*CI);
183   EXPECT_FALSE(IsPromoted);
184 }
185 
186 TEST(CallPromotionUtilsTest, TryPromoteCall_EmptyVTable) {
187   LLVMContext C;
188   std::unique_ptr<Module> M = parseIR(C,
189                                       R"IR(
190 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
191 %class.Interface = type { i32 (...)** }
192 
193 @_ZTV4Impl = external global { [3 x i8*] }
194 
195 define void @f() {
196 entry:
197   %o = alloca %class.Impl
198   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
199   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
200   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
201   store i32 3, i32* %f
202   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
203   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
204   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
205   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
206   call void %fp(%class.Interface* nonnull %base.i)
207   ret void
208 }
209 
210 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
211 )IR");
212 
213   auto *GV = M->getNamedValue("f");
214   ASSERT_TRUE(GV);
215   auto *F = dyn_cast<Function>(GV);
216   ASSERT_TRUE(F);
217   Instruction *Inst = &F->front().front();
218   auto *AI = dyn_cast<AllocaInst>(Inst);
219   ASSERT_TRUE(AI);
220   Inst = &*++F->front().rbegin();
221   auto *CI = dyn_cast<CallInst>(Inst);
222   ASSERT_TRUE(CI);
223   ASSERT_FALSE(CI->getCalledFunction());
224   bool IsPromoted = tryPromoteCall(*CI);
225   EXPECT_FALSE(IsPromoted);
226 }
227 
228 TEST(CallPromotionUtilsTest, TryPromoteCall_NullFP) {
229   LLVMContext C;
230   std::unique_ptr<Module> M = parseIR(C,
231                                       R"IR(
232 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
233 %class.Interface = type { i32 (...)** }
234 
235 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* null] }
236 
237 define void @f() {
238 entry:
239   %o = alloca %class.Impl
240   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
241   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
242   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
243   store i32 3, i32* %f
244   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
245   %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
246   %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
247   %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
248   call void %fp(%class.Interface* nonnull %base.i)
249   ret void
250 }
251 
252 declare void @_ZN4Impl3RunEv(%class.Impl* %this)
253 )IR");
254 
255   auto *GV = M->getNamedValue("f");
256   ASSERT_TRUE(GV);
257   auto *F = dyn_cast<Function>(GV);
258   ASSERT_TRUE(F);
259   Instruction *Inst = &F->front().front();
260   auto *AI = dyn_cast<AllocaInst>(Inst);
261   ASSERT_TRUE(AI);
262   Inst = &*++F->front().rbegin();
263   auto *CI = dyn_cast<CallInst>(Inst);
264   ASSERT_TRUE(CI);
265   ASSERT_FALSE(CI->getCalledFunction());
266   bool IsPromoted = tryPromoteCall(*CI);
267   EXPECT_FALSE(IsPromoted);
268 }
269 
270 // Based on clang/test/CodeGenCXX/member-function-pointer-calls.cpp
271 TEST(CallPromotionUtilsTest, TryPromoteCall_MemberFunctionCalls) {
272   LLVMContext C;
273   std::unique_ptr<Module> M = parseIR(C,
274                                       R"IR(
275 %struct.A = type { i32 (...)** }
276 
277 @_ZTV1A = linkonce_odr unnamed_addr constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* null, i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf1Ev to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf2Ev to i8*)] }, align 8
278 
279 define i32 @_Z2g1v() {
280 entry:
281   %a = alloca %struct.A, align 8
282   %0 = bitcast %struct.A* %a to i8*
283   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
284   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
285   %2 = bitcast %struct.A* %a to i8*
286   %3 = bitcast i8* %2 to i8**
287   %vtable.i = load i8*, i8** %3, align 8
288   %4 = bitcast i8* %vtable.i to i32 (%struct.A*)**
289   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %4, align 8
290   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
291   ret i32 %call.i
292 }
293 
294 define i32 @_Z2g2v() {
295 entry:
296   %a = alloca %struct.A, align 8
297   %0 = bitcast %struct.A* %a to i8*
298   %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
299   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
300   %2 = bitcast %struct.A* %a to i8*
301   %3 = bitcast i8* %2 to i8**
302   %vtable.i = load i8*, i8** %3, align 8
303   %4 = getelementptr i8, i8* %vtable.i, i64 8
304   %5 = bitcast i8* %4 to i32 (%struct.A*)**
305   %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %5, align 8
306   %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
307   ret i32 %call.i
308 }
309 
310 declare i32 @_ZN1A3vf1Ev(%struct.A* %this)
311 declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
312 )IR");
313 
314   auto *GV = M->getNamedValue("_Z2g1v");
315   ASSERT_TRUE(GV);
316   auto *F = dyn_cast<Function>(GV);
317   ASSERT_TRUE(F);
318   Instruction *Inst = &F->front().front();
319   auto *AI = dyn_cast<AllocaInst>(Inst);
320   ASSERT_TRUE(AI);
321   Inst = &*++F->front().rbegin();
322   auto *CI = dyn_cast<CallInst>(Inst);
323   ASSERT_TRUE(CI);
324   ASSERT_FALSE(CI->getCalledFunction());
325   bool IsPromoted1 = tryPromoteCall(*CI);
326   EXPECT_TRUE(IsPromoted1);
327   GV = M->getNamedValue("_ZN1A3vf1Ev");
328   ASSERT_TRUE(GV);
329   F = dyn_cast<Function>(GV);
330   EXPECT_EQ(F, CI->getCalledFunction());
331 
332   GV = M->getNamedValue("_Z2g2v");
333   ASSERT_TRUE(GV);
334   F = dyn_cast<Function>(GV);
335   ASSERT_TRUE(F);
336   Inst = &F->front().front();
337   AI = dyn_cast<AllocaInst>(Inst);
338   ASSERT_TRUE(AI);
339   Inst = &*++F->front().rbegin();
340   CI = dyn_cast<CallInst>(Inst);
341   ASSERT_TRUE(CI);
342   ASSERT_FALSE(CI->getCalledFunction());
343   bool IsPromoted2 = tryPromoteCall(*CI);
344   EXPECT_TRUE(IsPromoted2);
345   GV = M->getNamedValue("_ZN1A3vf2Ev");
346   ASSERT_TRUE(GV);
347   F = dyn_cast<Function>(GV);
348   EXPECT_EQ(F, CI->getCalledFunction());
349 }
350 
351 // Check that it isn't crashing due to missing promotion legality.
352 TEST(CallPromotionUtilsTest, TryPromoteCall_Legality) {
353   LLVMContext C;
354   std::unique_ptr<Module> M = parseIR(C,
355                                       R"IR(
356 %struct1 = type <{ i32, i64 }>
357 %struct2 = type <{ i32, i64 }>
358 
359 %class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
360 %class.Interface = type { i32 (...)** }
361 
362 @_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (%struct2 (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
363 
364 define %struct1 @f() {
365 entry:
366   %o = alloca %class.Impl
367   %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
368   store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
369   %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
370   store i32 3, i32* %f
371   %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
372   %c = bitcast %class.Interface* %base.i to %struct1 (%class.Interface*)***
373   %vtable.i = load %struct1 (%class.Interface*)**, %struct1 (%class.Interface*)*** %c
374   %fp = load %struct1 (%class.Interface*)*, %struct1 (%class.Interface*)** %vtable.i
375   %rv = call %struct1 %fp(%class.Interface* nonnull %base.i)
376   ret %struct1 %rv
377 }
378 
379 declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
380 )IR");
381 
382   auto *GV = M->getNamedValue("f");
383   ASSERT_TRUE(GV);
384   auto *F = dyn_cast<Function>(GV);
385   ASSERT_TRUE(F);
386   Instruction *Inst = &F->front().front();
387   auto *AI = dyn_cast<AllocaInst>(Inst);
388   ASSERT_TRUE(AI);
389   Inst = &*++F->front().rbegin();
390   auto *CI = dyn_cast<CallInst>(Inst);
391   ASSERT_TRUE(CI);
392   ASSERT_FALSE(CI->getCalledFunction());
393   bool IsPromoted = tryPromoteCall(*CI);
394   EXPECT_FALSE(IsPromoted);
395 }
396 
397 TEST(CallPromotionUtilsTest, promoteCallWithVTableCmp) {
398   LLVMContext C;
399   std::unique_ptr<Module> M = parseIR(C,
400                                       R"IR(
401 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
402 target triple = "x86_64-unknown-linux-gnu"
403 
404 @_ZTV5Base1 = constant { [4 x ptr] } { [4 x ptr] [ptr null, ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !0
405 @_ZTV8Derived1 = constant { [4 x ptr], [3 x ptr] } { [4 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev], [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base25func2Ev] }, !type !0, !type !1, !type !2
406 @_ZTV8Derived2 = constant { [3 x ptr], [3 x ptr], [4 x ptr] } { [3 x ptr] [ptr null, ptr null, ptr @_ZN5Base35func3Ev], [3 x ptr] [ptr inttoptr (i64 -8 to ptr), ptr null, ptr @_ZN5Base25func2Ev], [4 x ptr] [ptr inttoptr (i64 -16 to ptr), ptr null, ptr @_ZN5Base15func0Ev, ptr @_ZN5Base15func1Ev] }, !type !3, !type !4, !type !5, !type !6
407 
408 define i32 @testfunc(ptr %d) {
409 entry:
410   %vtable = load ptr, ptr %d, !prof !7
411   %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
412   %0 = load ptr, ptr %vfn
413   %call = tail call i32 %0(ptr %d), !prof !8
414   ret i32 %call
415 }
416 
417 define i32 @_ZN5Base15func1Ev(ptr %this) {
418 entry:
419   ret i32 2
420 }
421 
422 declare i32 @_ZN5Base25func2Ev(ptr)
423 declare i32 @_ZN5Base15func0Ev(ptr)
424 declare void @_ZN5Base35func3Ev(ptr)
425 
426 !0 = !{i64 16, !"_ZTS5Base1"}
427 !1 = !{i64 48, !"_ZTS5Base2"}
428 !2 = !{i64 16, !"_ZTS8Derived1"}
429 !3 = !{i64 64, !"_ZTS5Base1"}
430 !4 = !{i64 40, !"_ZTS5Base2"}
431 !5 = !{i64 16, !"_ZTS5Base3"}
432 !6 = !{i64 16, !"_ZTS8Derived2"}
433 !7 = !{!"VP", i32 2, i64 1600, i64 -9064381665493407289, i64 800, i64 5035968517245772950, i64 500, i64 3215870116411581797, i64 300}
434 !8 = !{!"VP", i32 0, i64 1600, i64 6804820478065511155, i64 1600})IR");
435 
436   Function *F = M->getFunction("testfunc");
437   CallInst *CI = dyn_cast<CallInst>(&*std::next(F->front().rbegin()));
438   ASSERT_TRUE(CI && CI->isIndirectCall());
439 
440   // Create the constant and the branch weights
441   SmallVector<Constant *, 3> VTableAddressPoints;
442 
443   for (auto &[VTableName, AddressPointOffset] : {std::pair{"_ZTV5Base1", 16},
444                                                  {"_ZTV8Derived1", 16},
445                                                  {"_ZTV8Derived2", 64}})
446     VTableAddressPoints.push_back(getVTableAddressPointOffset(
447         M->getGlobalVariable(VTableName), AddressPointOffset));
448 
449   MDBuilder MDB(C);
450   MDNode *BranchWeights = MDB.createBranchWeights(1600, 0);
451 
452   size_t OrigEntryBBSize = F->front().size();
453 
454   LoadInst *VPtr = dyn_cast<LoadInst>(&*F->front().begin());
455 
456   Function *Callee = M->getFunction("_ZN5Base15func1Ev");
457   // Tests that promoted direct call is returned.
458   CallBase &DirectCB = promoteCallWithVTableCmp(
459       *CI, VPtr, Callee, VTableAddressPoints, BranchWeights);
460   EXPECT_EQ(DirectCB.getCalledOperand(), Callee);
461 
462   // Promotion inserts 3 icmp instructions and 2 or instructions, and removes
463   // 1 call instruction from the entry block.
464   EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4);
465 }
466 
467 TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) {
468   LLVMContext C;
469   std::unique_ptr<Module> M = parseIR(C,
470                                       R"IR(
471 define i32 @testfunc1(ptr %d) !guid !0 {
472   call void @llvm.instrprof.increment(ptr @testfunc1, i64 0, i32 1, i32 0)
473   call void @llvm.instrprof.callsite(ptr @testfunc1, i64 0, i32 1, i32 0, ptr %d)
474   %call = call i32 %d()
475   ret i32 %call
476 }
477 
478 define i32 @f1() !guid !1 {
479   call void @llvm.instrprof.increment(ptr @f1, i64 0, i32 1, i32 0)
480   ret i32 2
481 }
482 
483 define i32 @f2() !guid !2 {
484   call void @llvm.instrprof.increment(ptr @f2, i64 0, i32 1, i32 0)
485   call void @llvm.instrprof.callsite(ptr @f2, i64 0, i32 1, i32 0, ptr @f4)
486   %r = call i32 @f4()
487   ret i32 %r
488 }
489 
490 define i32 @testfunc2(ptr %p) !guid !4 {
491   call void @llvm.instrprof.increment(ptr @testfunc2, i64 0, i32 1, i32 0)
492   call void @llvm.instrprof.callsite(ptr @testfunc2, i64 0, i32 1, i32 0, ptr @testfunc1)
493   %r = call i32 @testfunc1(ptr %p)
494   ret i32 %r
495 }
496 
497 declare i32 @f3()
498 
499 define i32 @f4() !guid !3 {
500   ret i32 3
501 }
502 
503 !0 = !{i64 1000}
504 !1 = !{i64 1001}
505 !2 = !{i64 1002}
506 !3 = !{i64 1004}
507 !4 = !{i64 1005}
508 )IR");
509 
510   const char *Profile = R"json(
511     [
512     {
513       "Guid": 1000,
514       "Counters": [1],
515       "Callsites": [
516         [{ "Guid": 1001,
517             "Counters": [10]},
518           { "Guid": 1002,
519             "Counters": [11],
520             "Callsites": [[{"Guid": 1004, "Counters":[13]}]]
521           },
522           { "Guid": 1003,
523             "Counters": [12]
524           }]]
525     },
526     {
527       "Guid": 1005,
528       "Counters": [2],
529       "Callsites": [
530         [{ "Guid": 1000,
531             "Counters": [1],
532             "Callsites": [
533               [{ "Guid": 1001,
534                   "Counters": [101]},
535                 { "Guid": 1002,
536                   "Counters": [102],
537                   "Callsites": [[{"Guid": 1004, "Counters":[104]}]]
538                 },
539                 { "Guid": 1003,
540                   "Counters": [103]
541                 }]]}]]}]
542     )json";
543 
544   llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique=*/true);
545   {
546     std::error_code EC;
547     raw_fd_stream Out(ProfileFile.path(), EC);
548     ASSERT_FALSE(EC);
549     // "False" means no error.
550     ASSERT_FALSE(llvm::createCtxProfFromJSON(Profile, Out));
551   }
552 
553   ModuleAnalysisManager MAM;
554   MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); });
555   MAM.registerPass([&]() { return PassInstrumentationAnalysis(); });
556   auto &CtxProf = MAM.getResult<CtxProfAnalysis>(*M);
557   auto *Caller = M->getFunction("testfunc1");
558   ASSERT_NE(Caller, nullptr);
559   auto *Callee = M->getFunction("f2");
560   ASSERT_NE(Callee, nullptr);
561   auto *IndirectCS = [&]() -> CallBase * {
562     for (auto &BB : *Caller)
563       for (auto &I : BB)
564         if (auto *CB = dyn_cast<CallBase>(&I); CB && CB->isIndirectCall())
565           return CB;
566     return nullptr;
567   }();
568   ASSERT_NE(IndirectCS, nullptr);
569   promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf);
570 
571   std::string Str;
572   raw_string_ostream OS(Str);
573   CtxProfAnalysisPrinterPass Printer(
574       OS, CtxProfAnalysisPrinterPass::PrintMode::JSON);
575   Printer.run(*M, MAM);
576   const char *Expected = R"json(
577   [
578   {
579     "Guid": 1000,
580     "Counters": [1, 11, 22],
581     "Callsites": [
582       [{ "Guid": 1001,
583           "Counters": [10]},
584         { "Guid": 1003,
585           "Counters": [12]
586         }],
587         [{ "Guid": 1002,
588           "Counters": [11],
589           "Callsites": [
590           [{ "Guid": 1004,
591             "Counters": [13] }]]}]]
592   },
593   {
594     "Guid": 1005,
595     "Counters": [2],
596     "Callsites": [
597       [{ "Guid": 1000,
598          "Counters": [1, 102, 204],
599          "Callsites": [
600             [{ "Guid": 1001,
601                "Counters": [101]},
602              { "Guid": 1003,
603                "Counters": [103]}],
604             [{ "Guid": 1002,
605                "Counters": [102],
606                "Callsites": [
607             [{ "Guid": 1004,
608                "Counters": [104]}]]}]]}]]}
609 ])json";
610   auto ExpectedJSON = json::parse(Expected);
611   ASSERT_TRUE(!!ExpectedJSON);
612   auto ProducedJSON = json::parse(Str);
613   ASSERT_TRUE(!!ProducedJSON);
614   EXPECT_EQ(*ProducedJSON, *ExpectedJSON);
615 }
616