xref: /llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp (revision 9df71d7673b5c98e1032d01be83724a45b42fafc)
1 //===- CallPromotionUtils.cpp - Utilities for call promotion ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities useful for promoting indirect call sites to
10 // direct call sites.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Analysis/Loads.h"
17 #include "llvm/Analysis/TypeMetadataUtils.h"
18 #include "llvm/IR/AttributeMask.h"
19 #include "llvm/IR/Constant.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "call-promotion-utils"
27 
28 /// Fix-up phi nodes in an invoke instruction's normal destination.
29 ///
30 /// After versioning an invoke instruction, values coming from the original
31 /// block will now be coming from the "merge" block. For example, in the code
32 /// below:
33 ///
34 ///   then_bb:
35 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
36 ///
37 ///   else_bb:
38 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
39 ///
40 ///   merge_bb:
41 ///     %t2 = phi i32 [ %t0, %then_bb ], [ %t1, %else_bb ]
42 ///     br %normal_dst
43 ///
44 ///   normal_dst:
45 ///     %t3 = phi i32 [ %x, %orig_bb ], ...
46 ///
47 /// "orig_bb" is no longer a predecessor of "normal_dst", so the phi nodes in
48 /// "normal_dst" must be fixed to refer to "merge_bb":
49 ///
50 ///    normal_dst:
51 ///      %t3 = phi i32 [ %x, %merge_bb ], ...
52 ///
53 static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
54                                       BasicBlock *MergeBlock) {
55   for (PHINode &Phi : Invoke->getNormalDest()->phis()) {
56     int Idx = Phi.getBasicBlockIndex(OrigBlock);
57     if (Idx == -1)
58       continue;
59     Phi.setIncomingBlock(Idx, MergeBlock);
60   }
61 }
62 
63 /// Fix-up phi nodes in an invoke instruction's unwind destination.
64 ///
65 /// After versioning an invoke instruction, values coming from the original
66 /// block will now be coming from either the "then" block or the "else" block.
67 /// For example, in the code below:
68 ///
69 ///   then_bb:
70 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
71 ///
72 ///   else_bb:
73 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
74 ///
75 ///   unwind_dst:
76 ///     %t3 = phi i32 [ %x, %orig_bb ], ...
77 ///
78 /// "orig_bb" is no longer a predecessor of "unwind_dst", so the phi nodes in
79 /// "unwind_dst" must be fixed to refer to "then_bb" and "else_bb":
80 ///
81 ///   unwind_dst:
82 ///     %t3 = phi i32 [ %x, %then_bb ], [ %x, %else_bb ], ...
83 ///
84 static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock,
85                                       BasicBlock *ThenBlock,
86                                       BasicBlock *ElseBlock) {
87   for (PHINode &Phi : Invoke->getUnwindDest()->phis()) {
88     int Idx = Phi.getBasicBlockIndex(OrigBlock);
89     if (Idx == -1)
90       continue;
91     auto *V = Phi.getIncomingValue(Idx);
92     Phi.setIncomingBlock(Idx, ThenBlock);
93     Phi.addIncoming(V, ElseBlock);
94   }
95 }
96 
97 /// Create a phi node for the returned value of a call or invoke instruction.
98 ///
99 /// After versioning a call or invoke instruction that returns a value, we have
100 /// to merge the value of the original and new instructions. We do this by
101 /// creating a phi node and replacing uses of the original instruction with this
102 /// phi node.
103 ///
104 /// For example, if \p OrigInst is defined in "else_bb" and \p NewInst is
105 /// defined in "then_bb", we create the following phi node:
106 ///
107 ///   ; Uses of the original instruction are replaced by uses of the phi node.
108 ///   %t0 = phi i32 [ %orig_inst, %else_bb ], [ %new_inst, %then_bb ],
109 ///
110 static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst,
111                              BasicBlock *MergeBlock, IRBuilder<> &Builder) {
112 
113   if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty())
114     return;
115 
116   Builder.SetInsertPoint(MergeBlock, MergeBlock->begin());
117   PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0);
118   SmallVector<User *, 16> UsersToUpdate(OrigInst->users());
119   for (User *U : UsersToUpdate)
120     U->replaceUsesOfWith(OrigInst, Phi);
121   Phi->addIncoming(OrigInst, OrigInst->getParent());
122   Phi->addIncoming(NewInst, NewInst->getParent());
123 }
124 
125 /// Cast a call or invoke instruction to the given type.
126 ///
127 /// When promoting a call site, the return type of the call site might not match
128 /// that of the callee. If this is the case, we have to cast the returned value
129 /// to the correct type. The location of the cast depends on if we have a call
130 /// or invoke instruction.
131 ///
132 /// For example, if the call instruction below requires a bitcast after
133 /// promotion:
134 ///
135 ///   orig_bb:
136 ///     %t0 = call i32 @func()
137 ///     ...
138 ///
139 /// The bitcast is placed after the call instruction:
140 ///
141 ///   orig_bb:
142 ///     ; Uses of the original return value are replaced by uses of the bitcast.
143 ///     %t0 = call i32 @func()
144 ///     %t1 = bitcast i32 %t0 to ...
145 ///     ...
146 ///
147 /// A similar transformation is performed for invoke instructions. However,
148 /// since invokes are terminating, a new block is created for the bitcast. For
149 /// example, if the invoke instruction below requires a bitcast after promotion:
150 ///
151 ///   orig_bb:
152 ///     %t0 = invoke i32 @func() to label %normal_dst unwind label %unwind_dst
153 ///
154 /// The edge between the original block and the invoke's normal destination is
155 /// split, and the bitcast is placed there:
156 ///
157 ///   orig_bb:
158 ///     %t0 = invoke i32 @func() to label %split_bb unwind label %unwind_dst
159 ///
160 ///   split_bb:
161 ///     ; Uses of the original return value are replaced by uses of the bitcast.
162 ///     %t1 = bitcast i32 %t0 to ...
163 ///     br label %normal_dst
164 ///
165 static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
166 
167   // Save the users of the calling instruction. These uses will be changed to
168   // use the bitcast after we create it.
169   SmallVector<User *, 16> UsersToUpdate(CB.users());
170 
171   // Determine an appropriate location to create the bitcast for the return
172   // value. The location depends on if we have a call or invoke instruction.
173   BasicBlock::iterator InsertBefore;
174   if (auto *Invoke = dyn_cast<InvokeInst>(&CB))
175     InsertBefore =
176         SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->begin();
177   else
178     InsertBefore = std::next(CB.getIterator());
179 
180   // Bitcast the return value to the correct type.
181   auto *Cast = CastInst::CreateBitOrPointerCast(&CB, RetTy, "", InsertBefore);
182   if (RetBitCast)
183     *RetBitCast = Cast;
184 
185   // Replace all the original uses of the calling instruction with the bitcast.
186   for (User *U : UsersToUpdate)
187     U->replaceUsesOfWith(&CB, Cast);
188 }
189 
190 /// Predicate and clone the given call site.
191 ///
192 /// This function creates an if-then-else structure at the location of the call
193 /// site. The "if" condition is specified by `Cond`.
194 /// The original call site is moved into the "else" block, and a clone of the
195 /// call site is placed in the "then" block. The cloned instruction is returned.
196 ///
197 /// For example, the call instruction below:
198 ///
199 ///   orig_bb:
200 ///     %t0 = call i32 %ptr()
201 ///     ...
202 ///
203 /// Is replaced by the following:
204 ///
205 ///   orig_bb:
206 ///     %cond = Cond
207 ///     br i1 %cond, %then_bb, %else_bb
208 ///
209 ///   then_bb:
210 ///     ; The clone of the original call instruction is placed in the "then"
211 ///     ; block. It is not yet promoted.
212 ///     %t1 = call i32 %ptr()
213 ///     br merge_bb
214 ///
215 ///   else_bb:
216 ///     ; The original call instruction is moved to the "else" block.
217 ///     %t0 = call i32 %ptr()
218 ///     br merge_bb
219 ///
220 ///   merge_bb:
221 ///     ; Uses of the original call instruction are replaced by uses of the phi
222 ///     ; node.
223 ///     %t2 = phi i32 [ %t0, %else_bb ], [ %t1, %then_bb ]
224 ///     ...
225 ///
226 /// A similar transformation is performed for invoke instructions. However,
227 /// since invokes are terminating, more work is required. For example, the
228 /// invoke instruction below:
229 ///
230 ///   orig_bb:
231 ///     %t0 = invoke %ptr() to label %normal_dst unwind label %unwind_dst
232 ///
233 /// Is replaced by the following:
234 ///
235 ///   orig_bb:
236 ///     %cond = Cond
237 ///     br i1 %cond, %then_bb, %else_bb
238 ///
239 ///   then_bb:
240 ///     ; The clone of the original invoke instruction is placed in the "then"
241 ///     ; block, and its normal destination is set to the "merge" block. It is
242 ///     ; not yet promoted.
243 ///     %t1 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
244 ///
245 ///   else_bb:
246 ///     ; The original invoke instruction is moved into the "else" block, and
247 ///     ; its normal destination is set to the "merge" block.
248 ///     %t0 = invoke i32 %ptr() to label %merge_bb unwind label %unwind_dst
249 ///
250 ///   merge_bb:
251 ///     ; Uses of the original invoke instruction are replaced by uses of the
252 ///     ; phi node, and the merge block branches to the normal destination.
253 ///     %t2 = phi i32 [ %t0, %else_bb ], [ %t1, %then_bb ]
254 ///     br %normal_dst
255 ///
256 /// An indirect musttail call is processed slightly differently in that:
257 /// 1. No merge block needed for the original and the cloned callsite, since
258 ///    either one ends the flow. No phi node is needed either.
259 /// 2. The return statement following the original call site is duplicated too
260 ///    and placed immediately after the cloned call site per the IR convention.
261 ///
262 /// For example, the musttail call instruction below:
263 ///
264 ///   orig_bb:
265 ///     %t0 = musttail call i32 %ptr()
266 ///     ...
267 ///
268 /// Is replaced by the following:
269 ///
270 ///   cond_bb:
271 ///     %cond = Cond
272 ///     br i1 %cond, %then_bb, %orig_bb
273 ///
274 ///   then_bb:
275 ///     ; The clone of the original call instruction is placed in the "then"
276 ///     ; block. It is not yet promoted.
277 ///     %t1 = musttail call i32 %ptr()
278 ///     ret %t1
279 ///
280 ///   orig_bb:
281 ///     ; The original call instruction stays in its original block.
282 ///     %t0 = musttail call i32 %ptr()
283 ///     ret %t0
284 static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond,
285                                          MDNode *BranchWeights) {
286 
287   IRBuilder<> Builder(&CB);
288   CallBase *OrigInst = &CB;
289   BasicBlock *OrigBlock = OrigInst->getParent();
290 
291   if (OrigInst->isMustTailCall()) {
292     // Create an if-then structure. The original instruction stays in its block,
293     // and a clone of the original instruction is placed in the "then" block.
294     Instruction *ThenTerm =
295         SplitBlockAndInsertIfThen(Cond, &CB, false, BranchWeights);
296     BasicBlock *ThenBlock = ThenTerm->getParent();
297     ThenBlock->setName("if.true.direct_targ");
298     CallBase *NewInst = cast<CallBase>(OrigInst->clone());
299     NewInst->insertBefore(ThenTerm);
300 
301     // Place a clone of the optional bitcast after the new call site.
302     Value *NewRetVal = NewInst;
303     auto Next = OrigInst->getNextNode();
304     if (auto *BitCast = dyn_cast_or_null<BitCastInst>(Next)) {
305       assert(BitCast->getOperand(0) == OrigInst &&
306              "bitcast following musttail call must use the call");
307       auto NewBitCast = BitCast->clone();
308       NewBitCast->replaceUsesOfWith(OrigInst, NewInst);
309       NewBitCast->insertBefore(ThenTerm);
310       NewRetVal = NewBitCast;
311       Next = BitCast->getNextNode();
312     }
313 
314     // Place a clone of the return instruction after the new call site.
315     ReturnInst *Ret = dyn_cast_or_null<ReturnInst>(Next);
316     assert(Ret && "musttail call must precede a ret with an optional bitcast");
317     auto NewRet = Ret->clone();
318     if (Ret->getReturnValue())
319       NewRet->replaceUsesOfWith(Ret->getReturnValue(), NewRetVal);
320     NewRet->insertBefore(ThenTerm);
321 
322     // A return instructions is terminating, so we don't need the terminator
323     // instruction just created.
324     ThenTerm->eraseFromParent();
325 
326     return *NewInst;
327   }
328 
329   // Create an if-then-else structure. The original instruction is moved into
330   // the "else" block, and a clone of the original instruction is placed in the
331   // "then" block.
332   Instruction *ThenTerm = nullptr;
333   Instruction *ElseTerm = nullptr;
334   SplitBlockAndInsertIfThenElse(Cond, &CB, &ThenTerm, &ElseTerm, BranchWeights);
335   BasicBlock *ThenBlock = ThenTerm->getParent();
336   BasicBlock *ElseBlock = ElseTerm->getParent();
337   BasicBlock *MergeBlock = OrigInst->getParent();
338 
339   ThenBlock->setName("if.true.direct_targ");
340   ElseBlock->setName("if.false.orig_indirect");
341   MergeBlock->setName("if.end.icp");
342 
343   CallBase *NewInst = cast<CallBase>(OrigInst->clone());
344   OrigInst->moveBefore(ElseTerm);
345   NewInst->insertBefore(ThenTerm);
346 
347   // If the original call site is an invoke instruction, we have extra work to
348   // do since invoke instructions are terminating. We have to fix-up phi nodes
349   // in the invoke's normal and unwind destinations.
350   if (auto *OrigInvoke = dyn_cast<InvokeInst>(OrigInst)) {
351     auto *NewInvoke = cast<InvokeInst>(NewInst);
352 
353     // Invoke instructions are terminating, so we don't need the terminator
354     // instructions that were just created.
355     ThenTerm->eraseFromParent();
356     ElseTerm->eraseFromParent();
357 
358     // Branch from the "merge" block to the original normal destination.
359     Builder.SetInsertPoint(MergeBlock);
360     Builder.CreateBr(OrigInvoke->getNormalDest());
361 
362     // Fix-up phi nodes in the original invoke's normal and unwind destinations.
363     fixupPHINodeForNormalDest(OrigInvoke, OrigBlock, MergeBlock);
364     fixupPHINodeForUnwindDest(OrigInvoke, MergeBlock, ThenBlock, ElseBlock);
365 
366     // Now set the normal destinations of the invoke instructions to be the
367     // "merge" block.
368     OrigInvoke->setNormalDest(MergeBlock);
369     NewInvoke->setNormalDest(MergeBlock);
370   }
371 
372   // Create a phi node for the returned value of the call site.
373   createRetPHINode(OrigInst, NewInst, MergeBlock, Builder);
374 
375   return *NewInst;
376 }
377 
378 // Predicate and clone the given call site using condition `CB.callee ==
379 // Callee`. See the comment `versionCallSiteWithCond` for the transformation.
380 CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
381                                 MDNode *BranchWeights) {
382 
383   IRBuilder<> Builder(&CB);
384 
385   // Create the compare. The called value and callee must have the same type to
386   // be compared.
387   if (CB.getCalledOperand()->getType() != Callee->getType())
388     Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType());
389   auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee);
390 
391   return versionCallSiteWithCond(CB, Cond, BranchWeights);
392 }
393 
394 bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee,
395                             const char **FailureReason) {
396   assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
397 
398   auto &DL = Callee->getDataLayout();
399 
400   // Check the return type. The callee's return value type must be bitcast
401   // compatible with the call site's type.
402   Type *CallRetTy = CB.getType();
403   Type *FuncRetTy = Callee->getReturnType();
404   if (CallRetTy != FuncRetTy)
405     if (!CastInst::isBitOrNoopPointerCastable(FuncRetTy, CallRetTy, DL)) {
406       if (FailureReason)
407         *FailureReason = "Return type mismatch";
408       return false;
409     }
410 
411   // The number of formal arguments of the callee.
412   unsigned NumParams = Callee->getFunctionType()->getNumParams();
413 
414   // The number of actual arguments in the call.
415   unsigned NumArgs = CB.arg_size();
416 
417   // Check the number of arguments. The callee and call site must agree on the
418   // number of arguments.
419   if (NumArgs != NumParams && !Callee->isVarArg()) {
420     if (FailureReason)
421       *FailureReason = "The number of arguments mismatch";
422     return false;
423   }
424 
425   // Check the argument types. The callee's formal argument types must be
426   // bitcast compatible with the corresponding actual argument types of the call
427   // site.
428   unsigned I = 0;
429   for (; I < NumParams; ++I) {
430     // Make sure that the callee and call agree on byval/inalloca. The types do
431     // not have to match.
432     if (Callee->hasParamAttribute(I, Attribute::ByVal) !=
433         CB.getAttributes().hasParamAttr(I, Attribute::ByVal)) {
434       if (FailureReason)
435         *FailureReason = "byval mismatch";
436       return false;
437     }
438     if (Callee->hasParamAttribute(I, Attribute::InAlloca) !=
439         CB.getAttributes().hasParamAttr(I, Attribute::InAlloca)) {
440       if (FailureReason)
441         *FailureReason = "inalloca mismatch";
442       return false;
443     }
444 
445     Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I);
446     Type *ActualTy = CB.getArgOperand(I)->getType();
447     if (FormalTy == ActualTy)
448       continue;
449     if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) {
450       if (FailureReason)
451         *FailureReason = "Argument type mismatch";
452       return false;
453     }
454 
455     // MustTail call needs stricter type match. See
456     // Verifier::verifyMustTailCall().
457     if (CB.isMustTailCall()) {
458       PointerType *PF = dyn_cast<PointerType>(FormalTy);
459       PointerType *PA = dyn_cast<PointerType>(ActualTy);
460       if (!PF || !PA || PF->getAddressSpace() != PA->getAddressSpace()) {
461         if (FailureReason)
462           *FailureReason = "Musttail call Argument type mismatch";
463         return false;
464       }
465     }
466   }
467   for (; I < NumArgs; I++) {
468     // Vararg functions can have more arguments than parameters.
469     assert(Callee->isVarArg());
470     if (CB.paramHasAttr(I, Attribute::StructRet)) {
471       if (FailureReason)
472         *FailureReason = "SRet arg to vararg function";
473       return false;
474     }
475   }
476 
477   return true;
478 }
479 
480 CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
481                             CastInst **RetBitCast) {
482   assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
483 
484   // Set the called function of the call site to be the given callee (but don't
485   // change the type).
486   CB.setCalledOperand(Callee);
487 
488   // Since the call site will no longer be direct, we must clear metadata that
489   // is only appropriate for indirect calls. This includes !prof and !callees
490   // metadata.
491   CB.setMetadata(LLVMContext::MD_prof, nullptr);
492   CB.setMetadata(LLVMContext::MD_callees, nullptr);
493 
494   // If the function type of the call site matches that of the callee, no
495   // additional work is required.
496   if (CB.getFunctionType() == Callee->getFunctionType())
497     return CB;
498 
499   // Save the return types of the call site and callee.
500   Type *CallSiteRetTy = CB.getType();
501   Type *CalleeRetTy = Callee->getReturnType();
502 
503   // Change the function type of the call site the match that of the callee.
504   CB.mutateFunctionType(Callee->getFunctionType());
505 
506   // Inspect the arguments of the call site. If an argument's type doesn't
507   // match the corresponding formal argument's type in the callee, bitcast it
508   // to the correct type.
509   auto CalleeType = Callee->getFunctionType();
510   auto CalleeParamNum = CalleeType->getNumParams();
511 
512   LLVMContext &Ctx = Callee->getContext();
513   const AttributeList &CallerPAL = CB.getAttributes();
514   // The new list of argument attributes.
515   SmallVector<AttributeSet, 4> NewArgAttrs;
516   bool AttributeChanged = false;
517 
518   for (unsigned ArgNo = 0; ArgNo < CalleeParamNum; ++ArgNo) {
519     auto *Arg = CB.getArgOperand(ArgNo);
520     Type *FormalTy = CalleeType->getParamType(ArgNo);
521     Type *ActualTy = Arg->getType();
522     if (FormalTy != ActualTy) {
523       auto *Cast =
524           CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator());
525       CB.setArgOperand(ArgNo, Cast);
526 
527       // Remove any incompatible attributes for the argument.
528       AttrBuilder ArgAttrs(Ctx, CallerPAL.getParamAttrs(ArgNo));
529       ArgAttrs.remove(AttributeFuncs::typeIncompatible(FormalTy));
530 
531       // We may have a different byval/inalloca type.
532       if (ArgAttrs.getByValType())
533         ArgAttrs.addByValAttr(Callee->getParamByValType(ArgNo));
534       if (ArgAttrs.getInAllocaType())
535         ArgAttrs.addInAllocaAttr(Callee->getParamInAllocaType(ArgNo));
536 
537       NewArgAttrs.push_back(AttributeSet::get(Ctx, ArgAttrs));
538       AttributeChanged = true;
539     } else
540       NewArgAttrs.push_back(CallerPAL.getParamAttrs(ArgNo));
541   }
542 
543   // If the return type of the call site doesn't match that of the callee, cast
544   // the returned value to the appropriate type.
545   // Remove any incompatible return value attribute.
546   AttrBuilder RAttrs(Ctx, CallerPAL.getRetAttrs());
547   if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) {
548     createRetBitCast(CB, CallSiteRetTy, RetBitCast);
549     RAttrs.remove(AttributeFuncs::typeIncompatible(CalleeRetTy));
550     AttributeChanged = true;
551   }
552 
553   // Set the new callsite attribute.
554   if (AttributeChanged)
555     CB.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttrs(),
556                                         AttributeSet::get(Ctx, RAttrs),
557                                         NewArgAttrs));
558 
559   return CB;
560 }
561 
562 CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
563                                           MDNode *BranchWeights) {
564 
565   // Version the indirect call site. If the called value is equal to the given
566   // callee, 'NewInst' will be executed, otherwise the original call site will
567   // be executed.
568   CallBase &NewInst = versionCallSite(CB, Callee, BranchWeights);
569 
570   // Promote 'NewInst' so that it directly calls the desired function.
571   return promoteCall(NewInst, Callee);
572 }
573 
574 CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
575                                          Function *Callee,
576                                          ArrayRef<Constant *> AddressPoints,
577                                          MDNode *BranchWeights) {
578   assert(!AddressPoints.empty() && "Caller should guarantee");
579   IRBuilder<> Builder(&CB);
580   SmallVector<Value *, 2> ICmps;
581   for (auto &AddressPoint : AddressPoints)
582     ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint));
583 
584   // TODO: Perform tree height reduction if the number of ICmps is high.
585   Value *Cond = Builder.CreateOr(ICmps);
586 
587   // Version the indirect call site. If Cond is true, 'NewInst' will be
588   // executed, otherwise the original call site will be executed.
589   CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights);
590 
591   // Promote 'NewInst' so that it directly calls the desired function.
592   return promoteCall(NewInst, Callee);
593 }
594 
595 bool llvm::tryPromoteCall(CallBase &CB) {
596   assert(!CB.getCalledFunction());
597   Module *M = CB.getCaller()->getParent();
598   const DataLayout &DL = M->getDataLayout();
599   Value *Callee = CB.getCalledOperand();
600 
601   LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
602   if (!VTableEntryLoad)
603     return false; // Not a vtable entry load.
604   Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
605   APInt VTableOffset(DL.getTypeSizeInBits(VTableEntryPtr->getType()), 0);
606   Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
607       DL, VTableOffset, /* AllowNonInbounds */ true);
608   LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
609   if (!VTablePtrLoad)
610     return false; // Not a vtable load.
611   Value *Object = VTablePtrLoad->getPointerOperand();
612   APInt ObjectOffset(DL.getTypeSizeInBits(Object->getType()), 0);
613   Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
614       DL, ObjectOffset, /* AllowNonInbounds */ true);
615   if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0))
616     // Not an Alloca or the offset isn't zero.
617     return false;
618 
619   // Look for the vtable pointer store into the object by the ctor.
620   BasicBlock::iterator BBI(VTablePtrLoad);
621   Value *VTablePtr = FindAvailableLoadedValue(
622       VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
623   if (!VTablePtr)
624     return false; // No vtable found.
625   APInt VTableOffsetGVBase(DL.getTypeSizeInBits(VTablePtr->getType()), 0);
626   Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
627       DL, VTableOffsetGVBase, /* AllowNonInbounds */ true);
628   GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
629   if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
630     // Not in the form of a global constant variable with an initializer.
631     return false;
632 
633   APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
634   if (!(VTableGVOffset.getActiveBits() <= 64))
635     return false; // Out of range.
636 
637   Function *DirectCallee = nullptr;
638   std::tie(DirectCallee, std::ignore) =
639       getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
640   if (!DirectCallee)
641     return false; // No function pointer found.
642 
643   if (!isLegalToPromote(CB, DirectCallee))
644     return false;
645 
646   // Success.
647   promoteCall(CB, DirectCallee);
648   return true;
649 }
650 
651 #undef DEBUG_TYPE
652