xref: /llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp (revision e1aa1e43decf9275175845bea970ef6d7c2b1af6)
1 //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements whole program optimization of virtual calls in cases
10 // where we know (via !type metadata) that the list of callees is fixed. This
11 // includes the following:
12 // - Single implementation devirtualization: if a virtual call has a single
13 //   possible callee, replace all calls with a direct call to that callee.
14 // - Virtual constant propagation: if the virtual function's return type is an
15 //   integer <=64 bits and all possible callees are readnone, for each class and
16 //   each list of constant arguments: evaluate the function, store the return
17 //   value alongside the virtual table, and rewrite each virtual call as a load
18 //   from the virtual table.
19 // - Uniform return value optimization: if the conditions for virtual constant
20 //   propagation hold and each function returns the same constant value, replace
21 //   each virtual call with that constant.
22 // - Unique return value optimization for i1 return values: if the conditions
23 //   for virtual constant propagation hold and a single vtable's function
24 //   returns 0, or a single vtable's function returns 1, replace each virtual
25 //   call with a comparison of the vptr against that vtable's address.
26 //
27 // This pass is intended to be used during the regular and thin LTO pipelines:
28 //
29 // During regular LTO, the pass determines the best optimization for each
30 // virtual call and applies the resolutions directly to virtual calls that are
31 // eligible for virtual call optimization (i.e. calls that use either of the
32 // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
33 //
34 // During hybrid Regular/ThinLTO, the pass operates in two phases:
35 // - Export phase: this is run during the thin link over a single merged module
36 //   that contains all vtables with !type metadata that participate in the link.
37 //   The pass computes a resolution for each virtual call and stores it in the
38 //   type identifier summary.
39 // - Import phase: this is run during the thin backends over the individual
40 //   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
42 //
43 // During ThinLTO, the pass operates in two phases:
44 // - Export phase: this is run during the thin link over the index which
45 //   contains a summary of all vtables with !type metadata that participate in
46 //   the link. It computes a resolution for each virtual call and stores it in
47 //   the type identifier summary. Only single implementation devirtualization
48 //   is supported.
49 // - Import phase: (same as with hybrid case above).
50 //
51 //===----------------------------------------------------------------------===//
52 
53 #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
54 #include "llvm/ADT/ArrayRef.h"
55 #include "llvm/ADT/DenseMap.h"
56 #include "llvm/ADT/DenseMapInfo.h"
57 #include "llvm/ADT/DenseSet.h"
58 #include "llvm/ADT/MapVector.h"
59 #include "llvm/ADT/SmallVector.h"
60 #include "llvm/ADT/Statistic.h"
61 #include "llvm/Analysis/AssumptionCache.h"
62 #include "llvm/Analysis/BasicAliasAnalysis.h"
63 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
64 #include "llvm/Analysis/TypeMetadataUtils.h"
65 #include "llvm/Bitcode/BitcodeReader.h"
66 #include "llvm/Bitcode/BitcodeWriter.h"
67 #include "llvm/IR/Constants.h"
68 #include "llvm/IR/DataLayout.h"
69 #include "llvm/IR/DebugLoc.h"
70 #include "llvm/IR/DerivedTypes.h"
71 #include "llvm/IR/Dominators.h"
72 #include "llvm/IR/Function.h"
73 #include "llvm/IR/GlobalAlias.h"
74 #include "llvm/IR/GlobalVariable.h"
75 #include "llvm/IR/IRBuilder.h"
76 #include "llvm/IR/InstrTypes.h"
77 #include "llvm/IR/Instruction.h"
78 #include "llvm/IR/Instructions.h"
79 #include "llvm/IR/Intrinsics.h"
80 #include "llvm/IR/LLVMContext.h"
81 #include "llvm/IR/MDBuilder.h"
82 #include "llvm/IR/Metadata.h"
83 #include "llvm/IR/Module.h"
84 #include "llvm/IR/ModuleSummaryIndexYAML.h"
85 #include "llvm/Support/Casting.h"
86 #include "llvm/Support/CommandLine.h"
87 #include "llvm/Support/Errc.h"
88 #include "llvm/Support/Error.h"
89 #include "llvm/Support/FileSystem.h"
90 #include "llvm/Support/GlobPattern.h"
91 #include "llvm/TargetParser/Triple.h"
92 #include "llvm/Transforms/IPO.h"
93 #include "llvm/Transforms/IPO/FunctionAttrs.h"
94 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
95 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
96 #include "llvm/Transforms/Utils/Evaluator.h"
97 #include <algorithm>
98 #include <cstddef>
99 #include <map>
100 #include <set>
101 #include <string>
102 
using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

// Pass-wide statistics: one counter per devirtualization strategy applied.
STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets");
STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations");
STATISTIC(NumBranchFunnel, "Number of branch funnels");
STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations");
STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations");
STATISTIC(NumVirtConstProp1Bit,
          "Number of 1 bit virtual constant propagations");
STATISTIC(NumVirtConstProp, "Number of virtual constant propagations");

/// Testing knob: whether the pass should import typeid resolutions from, or
/// export them to, the summary when run standalone (see runForTesting).
static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

/// Testing knob: summary file to read before running the pass.
static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc(
        "Read summary from given bitcode or YAML file before running pass"),
    cl::Hidden);

/// Testing knob: summary file to write after running the pass.
static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given bitcode or YAML file after running pass. "
             "Output file format is deduced from extension: *.bc means writing "
             "bitcode, otherwise YAML"),
    cl::Hidden);

/// Upper bound on the number of targets for which a branch funnel is built.
static cl::opt<unsigned>
    ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
                cl::init(10),
                cl::desc("Maximum number of call targets per "
                         "call site to enable branch funnels"));

static cl::opt<bool>
    PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
                       cl::desc("Print index-based devirtualization messages"));

/// Provide a way to force enable whole program visibility in tests.
/// This is needed to support legacy tests that don't contain
/// !vcall_visibility metadata (the mere presence of type tests
/// previously implied hidden visibility).
static cl::opt<bool>
    WholeProgramVisibility("whole-program-visibility", cl::Hidden,
                           cl::desc("Enable whole program visibility"));

/// Provide a way to force disable whole program for debugging or workarounds,
/// when enabled via the linker.
static cl::opt<bool> DisableWholeProgramVisibility(
    "disable-whole-program-visibility", cl::Hidden,
    cl::desc("Disable whole program visibility (overrides enabling options)"));

/// Provide way to prevent certain function from being devirtualized
static cl::list<std::string>
    SkipFunctionNames("wholeprogramdevirt-skip",
                      cl::desc("Prevent function(s) from being devirtualized"),
                      cl::Hidden, cl::CommaSeparated);
170 /// With Clang, a pure virtual class's deleting destructor is emitted as a
171 /// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
172 /// context of whole program devirtualization, the deleting destructor of a pure
173 /// virtual class won't be invoked by the source code so safe to skip as a
174 /// devirtualize target.
175 ///
176 /// However, not all unreachable functions are safe to skip. In some cases, the
177 /// program intends to run such functions and terminate, for instance, a unit
178 /// test may run a death test. A non-test program might (or allowed to) invoke
179 /// such functions to report failures (whether/when it's a good practice or not
180 /// is a different topic).
181 ///
182 /// This option is enabled to keep an unreachable function as a possible
183 /// devirtualize target to conservatively keep the program behavior.
184 ///
185 /// TODO: Make a pure virtual class's deleting destructor precisely identifiable
186 /// in Clang's codegen for more devirtualization in LLVM.
187 static cl::opt<bool> WholeProgramDevirtKeepUnreachableFunction(
188     "wholeprogramdevirt-keep-unreachable-function",
189     cl::desc("Regard unreachable functions as possible devirtualize targets."),
190     cl::Hidden, cl::init(true));
191 
/// If explicitly specified, the devirt module pass will stop transformation
/// once the total number of devirtualizations reach the cutoff value. Setting
/// this option to 0 explicitly will do 0 devirtualization.
static cl::opt<unsigned> WholeProgramDevirtCutoff(
    "wholeprogramdevirt-cutoff",
    cl::desc("Max number of devirtualizations for devirt module pass"),
    cl::init(0));

/// Mechanism to add runtime checking of devirtualization decisions, optionally
/// trapping or falling back to indirect call on any that are not correct.
/// Trapping mode is useful for debugging undefined behavior leading to failures
/// with WPD. Fallback mode is useful for ensuring safety when whole program
/// visibility may be compromised.
enum WPDCheckMode { None, Trap, Fallback };
// Defaults to WPDCheckMode::None (the first enumerator) when unspecified.
static cl::opt<WPDCheckMode> DevirtCheckMode(
    "wholeprogramdevirt-check", cl::Hidden,
    cl::desc("Type of checking for incorrect devirtualizations"),
    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
               clEnumValN(WPDCheckMode::Fallback, "fallback",
                          "Fallback to indirect when incorrect")));
213 
214 namespace {
215 struct PatternList {
216   std::vector<GlobPattern> Patterns;
217   template <class T> void init(const T &StringList) {
218     for (const auto &S : StringList)
219       if (Expected<GlobPattern> Pat = GlobPattern::create(S))
220         Patterns.push_back(std::move(*Pat));
221   }
222   bool match(StringRef S) {
223     for (const GlobPattern &P : Patterns)
224       if (P.match(S))
225         return true;
226     return false;
227   }
228 };
229 } // namespace
230 
// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset before the object, otherwise look for an
// offset after the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  // MinByte is the largest such per-target minimum, i.e. the first byte that
  // is available in every vtable.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of the
  // used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
  // this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used. Bytes past the end of a slice
    // are implicitly free, so the loop is guaranteed to terminate.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        // Some bit in byte I is free in every vtable; return the bit index of
        // the lowest such free bit.
        return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed));
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          // A used byte inside the candidate window: restart the scan at the
          // next starting offset.
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      // The window [I, I + Size/8) is free in every vtable.
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}
306 
307 void wholeprogramdevirt::setBeforeReturnValues(
308     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
309     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
310   if (BitWidth == 1)
311     OffsetByte = -(AllocBefore / 8 + 1);
312   else
313     OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
314   OffsetBit = AllocBefore % 8;
315 
316   for (VirtualCallTarget &Target : Targets) {
317     if (BitWidth == 1)
318       Target.setBeforeBit(AllocBefore);
319     else
320       Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
321   }
322 }
323 
324 void wholeprogramdevirt::setAfterReturnValues(
325     MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
326     unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
327   if (BitWidth == 1)
328     OffsetByte = AllocAfter / 8;
329   else
330     OffsetByte = (AllocAfter + 7) / 8;
331   OffsetBit = AllocAfter % 8;
332 
333   for (VirtualCallTarget &Target : Targets) {
334     if (BitWidth == 1)
335       Target.setAfterBit(AllocAfter);
336     else
337       Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
338   }
339 }
340 
// Construct a call target for function Fn found via type member TM; the
// endianness flag is taken from Fn's module data layout.
VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getDataLayout().isBigEndian()),
      WasDevirt(false) {}
345 
namespace {

// Tracks the number of devirted calls in the IR transformation.
// NOTE(review): presumably compared against WholeProgramDevirtCutoff in the
// transformation code — confirm in the apply* functions.
static unsigned NumDevirtCalls = 0;

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace
360 
namespace llvm {

// DenseMap key traits for VTableSlot: hash is the XOR of the per-field hashes,
// and the empty/tombstone keys are built from the per-field sentinels.
template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS,
                      const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

// Same traits for the summary-based slot representation, whose TypeID is a
// StringRef rather than a Metadata pointer.
template <> struct DenseMapInfo<VTableSlotSummary> {
  static VTableSlotSummary getEmptyKey() {
    return {DenseMapInfo<StringRef>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlotSummary getTombstoneKey() {
    return {DenseMapInfo<StringRef>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlotSummary &I) {
    return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlotSummary &LHS,
                      const VTableSlotSummary &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // end namespace llvm
402 
403 // Returns true if the function must be unreachable based on ValueInfo.
404 //
405 // In particular, identifies a function as unreachable in the following
406 // conditions
407 //   1) All summaries are live.
408 //   2) All function summaries indicate it's unreachable
409 //   3) There is no non-function with the same GUID (which is rare)
410 static bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
411   if (WholeProgramDevirtKeepUnreachableFunction)
412     return false;
413 
414   if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
415     // Returns false if ValueInfo is absent, or the summary list is empty
416     // (e.g., function declarations).
417     return false;
418   }
419 
420   for (const auto &Summary : TheFnVI.getSummaryList()) {
421     // Conservatively returns false if any non-live functions are seen.
422     // In general either all summaries should be live or all should be dead.
423     if (!Summary->isLive())
424       return false;
425     if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) {
426       if (!FS->fflags().MustBeUnreachable)
427         return false;
428     }
429     // Be conservative if a non-function has the same GUID (which is rare).
430     else
431       return false;
432   }
433   // All function summaries are live and all of them agree that the function is
434   // unreachble.
435   return true;
436 }
437 
namespace {
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable = nullptr;
  CallBase &CB;

  // If non-null, this field points to the associated unsafe use count stored in
  // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
  // of that field for details.
  unsigned *NumUnsafeUses = nullptr;

  // Emit an optimization remark in the caller, attributing optimization
  // OptName of this call site to devirtualized target TargetName.
  void
  emitRemark(const StringRef OptName, const StringRef TargetName,
             function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
    Function *F = CB.getCaller();
    DebugLoc DLoc = CB.getDebugLoc();
    BasicBlock *Block = CB.getParent();

    using namespace ore;
    OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
                      << NV("Optimization", OptName)
                      << ": devirtualized a call to "
                      << NV("FunctionName", TargetName));
  }

  // Replace all uses of the call with New, erase the call instruction, and
  // optionally emit a remark first.
  void replaceAndErase(
      const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
      function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
      Value *New) {
    if (RemarksEnabled)
      emitRemark(OptName, TargetName, OREGetter);
    CB.replaceAllUsesWith(New);
    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
      // Replace an invoke with an unconditional branch to its normal
      // destination; the unwind edge disappears, so detach this block from
      // the unwind destination's predecessor list.
      BranchInst::Create(II->getNormalDest(), CB.getIterator());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CB.eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};
481 
// Call site information collected for a specific VTableSlot and possibly a list
// of constant integer arguments. The grouping by arguments is handled by the
// VTableSlotInfo class.
struct CallSiteInfo {
  /// The set of call sites for this slot. Used during regular LTO and the
  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
  /// call sites that appear in the merged module itself); in each of these
  /// cases we are directly operating on the call sites at the IR level.
  std::vector<VirtualCallSite> CallSites;

  /// Whether all call sites represented by this CallSiteInfo, including those
  /// in summaries, have been devirtualized. This starts off as true because a
  /// default constructed CallSiteInfo represents no call sites.
  bool AllCallSitesDevirted = true;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
  /// this slot.
  bool SummaryHasTypeTestAssumeUsers = false;

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful. If devirtualization was successful, the
  /// pass will clear this vector by calling markDevirt(). If at the end of the
  /// pass the vector is non-empty, we will need to add a use of llvm.type.test
  /// to each of the function summaries in the vector.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;

  /// True if some summarized user depends on this slot's resolution, so the
  /// resolution must be exported.
  bool isExported() const {
    return SummaryHasTypeTestAssumeUsers ||
           !SummaryTypeCheckedLoadUsers.empty();
  }

  /// Record a summarized llvm.type.checked.load user; a new user means not
  /// every call site is devirtualized anymore.
  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
    SummaryTypeCheckedLoadUsers.push_back(FS);
    AllCallSitesDevirted = false;
  }

  /// Record a summarized llvm.assume(llvm.type.test) user.
  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
    SummaryTypeTestAssumeUsers.push_back(FS);
    SummaryHasTypeTestAssumeUsers = true;
    AllCallSitesDevirted = false;
  }

  /// Mark every call site (IR and summary) for this slot as devirtualized.
  void markDevirt() {
    AllCallSitesDevirted = true;

    // As explained in the comment for SummaryTypeCheckedLoadUsers.
    SummaryTypeCheckedLoadUsers.clear();
  }
};
537 
// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  // Record a call site whose function pointer was loaded from VTable.
  // NumUnsafeUses may be null (see VirtualCallSite::NumUnsafeUses).
  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);

private:
  // Select the bucket (CSInfo or a ConstCSInfo entry) that CB belongs to.
  CallSiteInfo &findCallSiteInfo(CallBase &CB);
};
553 
554 CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
555   std::vector<uint64_t> Args;
556   auto *CBType = dyn_cast<IntegerType>(CB.getType());
557   if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
558     return CSInfo;
559   for (auto &&Arg : drop_begin(CB.args())) {
560     auto *CI = dyn_cast<ConstantInt>(Arg);
561     if (!CI || CI->getBitWidth() > 64)
562       return CSInfo;
563     Args.push_back(CI->getZExtValue());
564   }
565   return ConstCSInfo[Args];
566 }
567 
568 void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
569                                  unsigned *NumUnsafeUses) {
570   auto &CSI = findCallSiteInfo(CB);
571   CSI.AllCallSitesDevirted = false;
572   CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
573 }
574 
575 struct DevirtModule {
576   Module &M;
577   function_ref<AAResults &(Function &)> AARGetter;
578   function_ref<DominatorTree &(Function &)> LookupDomTree;
579 
580   ModuleSummaryIndex *ExportSummary;
581   const ModuleSummaryIndex *ImportSummary;
582 
583   IntegerType *Int8Ty;
584   PointerType *Int8PtrTy;
585   IntegerType *Int32Ty;
586   IntegerType *Int64Ty;
587   IntegerType *IntPtrTy;
588   /// Sizeless array type, used for imported vtables. This provides a signal
589   /// to analyzers that these imports may alias, as they do for example
590   /// when multiple unique return values occur in the same vtable.
591   ArrayType *Int8Arr0Ty;
592 
593   bool RemarksEnabled;
594   function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;
595 
596   MapVector<VTableSlot, VTableSlotInfo> CallSlots;
597 
598   // Calls that have already been optimized. We may add a call to multiple
599   // VTableSlotInfos if vtable loads are coalesced and need to make sure not to
600   // optimize a call more than once.
601   SmallPtrSet<CallBase *, 8> OptimizedCalls;
602 
603   // Store calls that had their ptrauth bundle removed. They are to be deleted
604   // at the end of the optimization.
605   SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;
606 
607   // This map keeps track of the number of "unsafe" uses of a loaded function
608   // pointer. The key is the associated llvm.type.test intrinsic call generated
609   // by this pass. An unsafe use is one that calls the loaded function pointer
610   // directly. Every time we eliminate an unsafe use (for example, by
611   // devirtualizing it or by applying virtual constant propagation), we
612   // decrement the value stored in this map. If a value reaches zero, we can
613   // eliminate the type check by RAUWing the associated llvm.type.test call with
614   // true.
615   std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
616   PatternList FunctionsToSkip;
617 
618   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
619                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
620                function_ref<DominatorTree &(Function &)> LookupDomTree,
621                ModuleSummaryIndex *ExportSummary,
622                const ModuleSummaryIndex *ImportSummary)
623       : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
624         ExportSummary(ExportSummary), ImportSummary(ImportSummary),
625         Int8Ty(Type::getInt8Ty(M.getContext())),
626         Int8PtrTy(PointerType::getUnqual(M.getContext())),
627         Int32Ty(Type::getInt32Ty(M.getContext())),
628         Int64Ty(Type::getInt64Ty(M.getContext())),
629         IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
630         Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
631         RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
632     assert(!(ExportSummary && ImportSummary));
633     FunctionsToSkip.init(SkipFunctionNames);
634   }
635 
636   bool areRemarksEnabled();
637 
638   void
639   scanTypeTestUsers(Function *TypeTestFunc,
640                     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
641   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
642 
643   void buildTypeIdentifierMap(
644       std::vector<VTableBits> &Bits,
645       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
646 
647   bool
648   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
649                             const std::set<TypeMemberInfo> &TypeMemberInfos,
650                             uint64_t ByteOffset,
651                             ModuleSummaryIndex *ExportSummary);
652 
653   void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
654                              bool &IsExported);
655   bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
656                            MutableArrayRef<VirtualCallTarget> TargetsForSlot,
657                            VTableSlotInfo &SlotInfo,
658                            WholeProgramDevirtResolution *Res);
659 
660   void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
661                               bool &IsExported);
662   void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
663                             VTableSlotInfo &SlotInfo,
664                             WholeProgramDevirtResolution *Res, VTableSlot Slot);
665 
666   bool tryEvaluateFunctionsWithArgs(
667       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
668       ArrayRef<uint64_t> Args);
669 
670   void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
671                              uint64_t TheRetVal);
  // Attempt the uniform return value optimization for the call sites in
  // CSInfo: applies when every possible callee returns the same constant
  // (see the file header for the full description).
  bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           CallSiteInfo &CSInfo,
                           WholeProgramDevirtResolution::ByArg *Res);

  // Returns the global symbol name that is used to export information about the
  // given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  // Returns true if exported constants should be represented as absolute
  // symbols rather than ordinary globals.
  bool shouldExportConstantsAsAbsoluteSymbols();

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);
  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                      uint32_t Const, uint32_t &Storage);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name);
  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                           StringRef Name, IntegerType *IntTy,
                           uint32_t Storage);

  // Returns the address of the vtable member described by M as a Constant.
  Constant *getMemberAddr(const TypeMemberInfo *M);

  // Apply/attempt the unique return value optimization for i1 return values:
  // when exactly one vtable's function returns 0 (or 1), each virtual call is
  // replaced with a comparison of the vptr against that vtable's address
  // (UniqueMemberAddr); see the file header for details.
  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  // Apply/attempt virtual constant propagation: rewrite each eligible virtual
  // call as a load of the precomputed return value stored alongside the
  // vtable (Byte/Bit locate the stored value).
  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  // Rebuild the global variable represented by B.
  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  // Entry point: performs whole program devirtualization on the module.
  // Returns true if the module was changed.
  bool run();

  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
  //
  // Caller guarantees that `ExportSummary` is not nullptr.
  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
                                           ModuleSummaryIndex *ExportSummary);

  // Returns true if the function definition must be unreachable.
  //
  // Note if this helper function returns true, `F` is guaranteed
  // to be unreachable; if it returns false, `F` might still
  // be unreachable but not covered by this helper function.
  //
  // Implementation-wise, if function definition is present, IR is analyzed; if
  // not, look up function flags from ExportSummary as a fallback.
  static bool mustBeUnreachableFunction(Function *const F,
                                        ModuleSummaryIndex *ExportSummary);

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool
  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
                function_ref<DominatorTree &(Function &)> LookupDomTree);
749 };
750 
// Driver for index-based (ThinLTO) whole program devirtualization; operates
// purely on the combined summary index rather than on module IR.
struct DevirtIndex {
  ModuleSummaryIndex &ExportSummary;
  // The set in which to record GUIDs exported from their module by
  // devirtualization, used by client to ensure they are not internalized.
  std::set<GlobalValue::GUID> &ExportedGUIDs;
  // A map in which to record the information necessary to locate the WPD
  // resolution for local targets in case they are exported by cross module
  // importing.
  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;

  // Maps each vtable slot summary to the information about its call sites.
  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;

  // Patterns (initialized from -wholeprogramdevirt-skip) naming functions
  // that must not be devirtualized.
  PatternList FunctionsToSkip;

  DevirtIndex(
      ModuleSummaryIndex &ExportSummary,
      std::set<GlobalValue::GUID> &ExportedGUIDs,
      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
        LocalWPDTargetsMap(LocalWPDTargetsMap) {
    FunctionsToSkip.init(SkipFunctionNames);
  }

  // Populate TargetsForSlot with the possible devirtualization targets for
  // the vtables described by TIdInfo at the given byte offset. Returns false
  // if any vtable is unsuitable for WPD or no targets were found.
  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
                                 const TypeIdCompatibleVtableInfo TIdInfo,
                                 uint64_t ByteOffset);

  // If every entry of TargetsForSlot is the same function, record a
  // SingleImpl resolution in Res; returns true on success.
  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                           VTableSlotSummary &SlotSummary,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res,
                           std::set<ValueInfo> &DevirtTargets);

  // Entry point: perform index-based whole program devirtualization.
  void run();
};
786 } // end anonymous namespace
787 
788 PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
789                                               ModuleAnalysisManager &AM) {
790   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
791   auto AARGetter = [&](Function &F) -> AAResults & {
792     return FAM.getResult<AAManager>(F);
793   };
794   auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
795     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
796   };
797   auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
798     return FAM.getResult<DominatorTreeAnalysis>(F);
799   };
800   if (UseCommandLine) {
801     if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
802       return PreservedAnalyses::all();
803     return PreservedAnalyses::none();
804   }
805   if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
806                     ImportSummary)
807            .run())
808     return PreservedAnalyses::all();
809   return PreservedAnalyses::none();
810 }
811 
812 // Enable whole program visibility if enabled by client (e.g. linker) or
813 // internal option, and not force disabled.
814 bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
815   return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
816          !DisableWholeProgramVisibility;
817 }
818 
819 static bool
820 typeIDVisibleToRegularObj(StringRef TypeID,
821                           function_ref<bool(StringRef)> IsVisibleToRegularObj) {
822   // TypeID for member function pointer type is an internal construct
823   // and won't exist in IsVisibleToRegularObj. The full TypeID
824   // will be present and participate in invalidation.
825   if (TypeID.ends_with(".virtual"))
826     return false;
827 
828   // TypeID that doesn't start with Itanium mangling (_ZTS) will be
829   // non-externally visible types which cannot interact with
830   // external native files. See CodeGenModule::CreateMetadataIdentifierImpl.
831   if (!TypeID.consume_front("_ZTS"))
832     return false;
833 
834   // TypeID is keyed off the type name symbol (_ZTS). However, the native
835   // object may not contain this symbol if it does not contain a key
836   // function for the base type and thus only contains a reference to the
837   // type info (_ZTI). To catch this case we query using the type info
838   // symbol corresponding to the TypeID.
839   std::string typeInfo = ("_ZTI" + TypeID).str();
840   return IsVisibleToRegularObj(typeInfo);
841 }
842 
843 static bool
844 skipUpdateDueToValidation(GlobalVariable &GV,
845                           function_ref<bool(StringRef)> IsVisibleToRegularObj) {
846   SmallVector<MDNode *, 2> Types;
847   GV.getMetadata(LLVMContext::MD_type, Types);
848 
849   for (auto Type : Types)
850     if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get()))
851       return typeIDVisibleToRegularObj(TypeID->getString(),
852                                        IsVisibleToRegularObj);
853 
854   return false;
855 }
856 
857 /// If whole program visibility asserted, then upgrade all public vcall
858 /// visibility metadata on vtable definitions to linkage unit visibility in
859 /// Module IR (for regular or hybrid LTO).
860 void llvm::updateVCallVisibilityInModule(
861     Module &M, bool WholeProgramVisibilityEnabledInLTO,
862     const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
863     bool ValidateAllVtablesHaveTypeInfos,
864     function_ref<bool(StringRef)> IsVisibleToRegularObj) {
865   if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
866     return;
867   for (GlobalVariable &GV : M.globals()) {
868     // Add linkage unit visibility to any variable with type metadata, which are
869     // the vtable definitions. We won't have an existing vcall_visibility
870     // metadata on vtable definitions with public visibility.
871     if (GV.hasMetadata(LLVMContext::MD_type) &&
872         GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
873         // Don't upgrade the visibility for symbols exported to the dynamic
874         // linker, as we have no information on their eventual use.
875         !DynamicExportSymbols.count(GV.getGUID()) &&
876         // With validation enabled, we want to exclude symbols visible to
877         // regular objects. Local symbols will be in this group due to the
878         // current implementation but those with VCallVisibilityTranslationUnit
879         // will have already been marked in clang so are unaffected.
880         !(ValidateAllVtablesHaveTypeInfos &&
881           skipUpdateDueToValidation(GV, IsVisibleToRegularObj)))
882       GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
883   }
884 }
885 
886 void llvm::updatePublicTypeTestCalls(Module &M,
887                                      bool WholeProgramVisibilityEnabledInLTO) {
888   Function *PublicTypeTestFunc =
889       Intrinsic::getDeclarationIfExists(&M, Intrinsic::public_type_test);
890   if (!PublicTypeTestFunc)
891     return;
892   if (hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) {
893     Function *TypeTestFunc =
894         Intrinsic::getOrInsertDeclaration(&M, Intrinsic::type_test);
895     for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
896       auto *CI = cast<CallInst>(U.getUser());
897       auto *NewCI = CallInst::Create(
898           TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)}, {}, "",
899           CI->getIterator());
900       CI->replaceAllUsesWith(NewCI);
901       CI->eraseFromParent();
902     }
903   } else {
904     auto *True = ConstantInt::getTrue(M.getContext());
905     for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
906       auto *CI = cast<CallInst>(U.getUser());
907       CI->replaceAllUsesWith(True);
908       CI->eraseFromParent();
909     }
910   }
911 }
912 
913 /// Based on typeID string, get all associated vtable GUIDS that are
914 /// visible to regular objects.
915 void llvm::getVisibleToRegularObjVtableGUIDs(
916     ModuleSummaryIndex &Index,
917     DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols,
918     function_ref<bool(StringRef)> IsVisibleToRegularObj) {
919   for (const auto &typeID : Index.typeIdCompatibleVtableMap()) {
920     if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj))
921       for (const TypeIdOffsetVtableInfo &P : typeID.second)
922         VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID());
923   }
924 }
925 
926 /// If whole program visibility asserted, then upgrade all public vcall
927 /// visibility metadata on vtable definition summaries to linkage unit
928 /// visibility in Module summary index (for ThinLTO).
929 void llvm::updateVCallVisibilityInIndex(
930     ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
931     const DenseSet<GlobalValue::GUID> &DynamicExportSymbols,
932     const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) {
933   if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
934     return;
935   for (auto &P : Index) {
936     // Don't upgrade the visibility for symbols exported to the dynamic
937     // linker, as we have no information on their eventual use.
938     if (DynamicExportSymbols.count(P.first))
939       continue;
940     for (auto &S : P.second.SummaryList) {
941       auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
942       if (!GVar ||
943           GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
944         continue;
945       // With validation enabled, we want to exclude symbols visible to regular
946       // objects. Local symbols will be in this group due to the current
947       // implementation but those with VCallVisibilityTranslationUnit will have
948       // already been marked in clang so are unaffected.
949       if (VisibleToRegularObjSymbols.count(P.first))
950         continue;
951       GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
952     }
953   }
954 }
955 
956 void llvm::runWholeProgramDevirtOnIndex(
957     ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
958     std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
959   DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
960 }
961 
962 void llvm::updateIndexWPDForExports(
963     ModuleSummaryIndex &Summary,
964     function_ref<bool(StringRef, ValueInfo)> isExported,
965     std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
966   for (auto &T : LocalWPDTargetsMap) {
967     auto &VI = T.first;
968     // This was enforced earlier during trySingleImplDevirt.
969     assert(VI.getSummaryList().size() == 1 &&
970            "Devirt of local target has more than one copy");
971     auto &S = VI.getSummaryList()[0];
972     if (!isExported(S->modulePath(), VI))
973       continue;
974 
975     // It's been exported by a cross module import.
976     for (auto &SlotSummary : T.second) {
977       auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
978       assert(TIdSum);
979       auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
980       assert(WPDRes != TIdSum->WPDRes.end());
981       WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
982           WPDRes->second.SingleImplName,
983           Summary.getModuleHash(S->modulePath()));
984     }
985   }
986 }
987 
988 static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
989   // Check that summary index contains regular LTO module when performing
990   // export to prevent occasional use of index from pure ThinLTO compilation
991   // (-fno-split-lto-module). This kind of summary index is passed to
992   // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
993   const auto &ModPaths = Summary->modulePaths();
994   if (ClSummaryAction != PassSummaryAction::Import &&
995       !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName()))
996     return createStringError(
997         errc::invalid_argument,
998         "combined summary should contain Regular LTO module");
999   return ErrorSuccess();
1000 }
1001 
1002 bool DevirtModule::runForTesting(
1003     Module &M, function_ref<AAResults &(Function &)> AARGetter,
1004     function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
1005     function_ref<DominatorTree &(Function &)> LookupDomTree) {
1006   std::unique_ptr<ModuleSummaryIndex> Summary =
1007       std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);
1008 
1009   // Handle the command-line summary arguments. This code is for testing
1010   // purposes only, so we handle errors directly.
1011   if (!ClReadSummary.empty()) {
1012     ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
1013                           ": ");
1014     auto ReadSummaryFile =
1015         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
1016     if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
1017             getModuleSummaryIndex(*ReadSummaryFile)) {
1018       Summary = std::move(*SummaryOrErr);
1019       ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
1020     } else {
1021       // Try YAML if we've failed with bitcode.
1022       consumeError(SummaryOrErr.takeError());
1023       yaml::Input In(ReadSummaryFile->getBuffer());
1024       In >> *Summary;
1025       ExitOnErr(errorCodeToError(In.error()));
1026     }
1027   }
1028 
1029   bool Changed =
1030       DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
1031                    ClSummaryAction == PassSummaryAction::Export ? Summary.get()
1032                                                                 : nullptr,
1033                    ClSummaryAction == PassSummaryAction::Import ? Summary.get()
1034                                                                 : nullptr)
1035           .run();
1036 
1037   if (!ClWriteSummary.empty()) {
1038     ExitOnError ExitOnErr(
1039         "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
1040     std::error_code EC;
1041     if (StringRef(ClWriteSummary).ends_with(".bc")) {
1042       raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
1043       ExitOnErr(errorCodeToError(EC));
1044       writeIndexToFile(*Summary, OS);
1045     } else {
1046       raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
1047       ExitOnErr(errorCodeToError(EC));
1048       yaml::Output Out(OS);
1049       Out << *Summary;
1050     }
1051   }
1052 
1053   return Changed;
1054 }
1055 
1056 void DevirtModule::buildTypeIdentifierMap(
1057     std::vector<VTableBits> &Bits,
1058     DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
1059   DenseMap<GlobalVariable *, VTableBits *> GVToBits;
1060   Bits.reserve(M.global_size());
1061   SmallVector<MDNode *, 2> Types;
1062   for (GlobalVariable &GV : M.globals()) {
1063     Types.clear();
1064     GV.getMetadata(LLVMContext::MD_type, Types);
1065     if (GV.isDeclaration() || Types.empty())
1066       continue;
1067 
1068     VTableBits *&BitsPtr = GVToBits[&GV];
1069     if (!BitsPtr) {
1070       Bits.emplace_back();
1071       Bits.back().GV = &GV;
1072       Bits.back().ObjectSize =
1073           M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
1074       BitsPtr = &Bits.back();
1075     }
1076 
1077     for (MDNode *Type : Types) {
1078       auto TypeID = Type->getOperand(1).get();
1079 
1080       uint64_t Offset =
1081           cast<ConstantInt>(
1082               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
1083               ->getZExtValue();
1084 
1085       TypeIdMap[TypeID].insert({BitsPtr, Offset});
1086     }
1087   }
1088 }
1089 
1090 bool DevirtModule::tryFindVirtualCallTargets(
1091     std::vector<VirtualCallTarget> &TargetsForSlot,
1092     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
1093     ModuleSummaryIndex *ExportSummary) {
1094   for (const TypeMemberInfo &TM : TypeMemberInfos) {
1095     if (!TM.Bits->GV->isConstant())
1096       return false;
1097 
1098     // We cannot perform whole program devirtualization analysis on a vtable
1099     // with public LTO visibility.
1100     if (TM.Bits->GV->getVCallVisibility() ==
1101         GlobalObject::VCallVisibilityPublic)
1102       return false;
1103 
1104     Function *Fn = nullptr;
1105     Constant *C = nullptr;
1106     std::tie(Fn, C) =
1107         getFunctionAtVTableOffset(TM.Bits->GV, TM.Offset + ByteOffset, M);
1108 
1109     if (!Fn)
1110       return false;
1111 
1112     if (FunctionsToSkip.match(Fn->getName()))
1113       return false;
1114 
1115     // We can disregard __cxa_pure_virtual as a possible call target, as
1116     // calls to pure virtuals are UB.
1117     if (Fn->getName() == "__cxa_pure_virtual")
1118       continue;
1119 
1120     // We can disregard unreachable functions as possible call targets, as
1121     // unreachable functions shouldn't be called.
1122     if (mustBeUnreachableFunction(Fn, ExportSummary))
1123       continue;
1124 
1125     // Save the symbol used in the vtable to use as the devirtualization
1126     // target.
1127     auto GV = dyn_cast<GlobalValue>(C);
1128     assert(GV);
1129     TargetsForSlot.push_back({GV, &TM});
1130   }
1131 
1132   // Give up if we couldn't find any targets.
1133   return !TargetsForSlot.empty();
1134 }
1135 
1136 bool DevirtIndex::tryFindVirtualCallTargets(
1137     std::vector<ValueInfo> &TargetsForSlot,
1138     const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
1139   for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
1140     // Find a representative copy of the vtable initializer.
1141     // We can have multiple available_externally, linkonce_odr and weak_odr
1142     // vtable initializers. We can also have multiple external vtable
1143     // initializers in the case of comdats, which we cannot check here.
1144     // The linker should give an error in this case.
1145     //
1146     // Also, handle the case of same-named local Vtables with the same path
1147     // and therefore the same GUID. This can happen if there isn't enough
1148     // distinguishing path when compiling the source file. In that case we
1149     // conservatively return false early.
1150     const GlobalVarSummary *VS = nullptr;
1151     bool LocalFound = false;
1152     for (const auto &S : P.VTableVI.getSummaryList()) {
1153       if (GlobalValue::isLocalLinkage(S->linkage())) {
1154         if (LocalFound)
1155           return false;
1156         LocalFound = true;
1157       }
1158       auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
1159       if (!CurVS->vTableFuncs().empty() ||
1160           // Previously clang did not attach the necessary type metadata to
1161           // available_externally vtables, in which case there would not
1162           // be any vtable functions listed in the summary and we need
1163           // to treat this case conservatively (in case the bitcode is old).
1164           // However, we will also not have any vtable functions in the
1165           // case of a pure virtual base class. In that case we do want
1166           // to set VS to avoid treating it conservatively.
1167           !GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
1168         VS = CurVS;
1169         // We cannot perform whole program devirtualization analysis on a vtable
1170         // with public LTO visibility.
1171         if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
1172           return false;
1173       }
1174     }
1175     // There will be no VS if all copies are available_externally having no
1176     // type metadata. In that case we can't safely perform WPD.
1177     if (!VS)
1178       return false;
1179     if (!VS->isLive())
1180       continue;
1181     for (auto VTP : VS->vTableFuncs()) {
1182       if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
1183         continue;
1184 
1185       if (mustBeUnreachableFunction(VTP.FuncVI))
1186         continue;
1187 
1188       TargetsForSlot.push_back(VTP.FuncVI);
1189     }
1190   }
1191 
1192   // Give up if we couldn't find any targets.
1193   return !TargetsForSlot.empty();
1194 }
1195 
// Rewrite every eligible virtual call recorded in SlotInfo into a direct call
// to TheFn, honoring the -wholeprogramdevirt-check mode (none/trap/fallback).
// IsExported is set when any rewritten call site list is marked exported.
void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName()))
    return;
  // Rewrites all not-yet-optimized call sites in one CallSiteInfo.
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      // Each call site is rewritten at most once across the whole pass run.
      if (!OptimizedCalls.insert(&VCallSite.CB).second)
        continue;

      // Stop when the number of devirted calls reaches the cutoff.
      // (This returns from the current Apply invocation only; remaining
      // CallSiteInfos are still visited below.)
      if (WholeProgramDevirtCutoff.getNumOccurrences() > 0 &&
          NumDevirtCalls >= WholeProgramDevirtCutoff)
        return;

      if (RemarksEnabled)
        VCallSite.emitRemark("single-impl",
                             TheFn->stripPointerCasts()->getName(), OREGetter);
      NumSingleImpl++;
      NumDevirtCalls++;
      auto &CB = VCallSite.CB;
      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
      IRBuilder<> Builder(&CB);
      // Cast the callee to the type the call site expects.
      Value *Callee =
          Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());

      // If trap checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // perform a debug trap.
      if (DevirtCheckMode == WPDCheckMode::Trap) {
        auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
        Instruction *ThenTerm = SplitBlockAndInsertIfThen(
            Cond, &CB, /*Unreachable=*/false,
            MDBuilder(M.getContext()).createUnlikelyBranchWeights());
        Builder.SetInsertPoint(ThenTerm);
        Function *TrapFn =
            Intrinsic::getOrInsertDeclaration(&M, Intrinsic::debugtrap);
        auto *CallTrap = Builder.CreateCall(TrapFn);
        CallTrap->setDebugLoc(CB.getDebugLoc());
      }

      // If fallback checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // fall back to indirect call.
      if (DevirtCheckMode == WPDCheckMode::Fallback) {
        MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights();
        // Version the indirect call site. If the called value is equal to the
        // given callee, 'NewInst' will be executed, otherwise the original call
        // site will be executed.
        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
        NewInst.setCalledOperand(Callee);
        // Since the new call site is direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
        // Additionally, we should remove them from the fallback indirect call,
        // so that we don't attempt to perform indirect call promotion later.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // In either trapping or non-checking mode, devirtualize original call.
      else {
        // Devirtualize unconditionally.
        CB.setCalledOperand(Callee);
        // Since the call site is now direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
        // A ptrauth bundle no longer applies to a direct call; strip it by
        // cloning the call without the bundle.
        if (CB.getCalledOperand() &&
            CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
          auto *NewCS = CallBase::removeOperandBundle(
              &CB, LLVMContext::OB_ptrauth, CB.getIterator());
          CB.replaceAllUsesWith(NewCS);
          // Schedule for deletion at the end of pass run.
          CallsWithPtrAuthBundleRemoved.push_back(&CB);
        }
      }

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported())
      IsExported = true;
    CSInfo.markDevirt();
  };
  // Apply to the unconditional call sites and to each constant-argument
  // grouping.
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}
1290 
1291 static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
1292   // We can't add calls if we haven't seen a definition
1293   if (Callee.getSummaryList().empty())
1294     return false;
1295 
1296   // Insert calls into the summary index so that the devirtualized targets
1297   // are eligible for import.
1298   // FIXME: Annotate type tests with hotness. For now, mark these as hot
1299   // to better ensure we have the opportunity to inline them.
1300   bool IsExported = false;
1301   auto &S = Callee.getSummaryList()[0];
1302   CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false,
1303                 /* RelBF = */ 0);
1304   auto AddCalls = [&](CallSiteInfo &CSInfo) {
1305     for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
1306       FS->addCall({Callee, CI});
1307       IsExported |= S->modulePath() != FS->modulePath();
1308     }
1309     for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
1310       FS->addCall({Callee, CI});
1311       IsExported |= S->modulePath() != FS->modulePath();
1312     }
1313   };
1314   AddCalls(SlotInfo.CSInfo);
1315   for (auto &P : SlotInfo.ConstCSInfo)
1316     AddCalls(P.second);
1317   return IsExported;
1318 }
1319 
1320 bool DevirtModule::trySingleImplDevirt(
1321     ModuleSummaryIndex *ExportSummary,
1322     MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
1323     WholeProgramDevirtResolution *Res) {
1324   // See if the program contains a single implementation of this virtual
1325   // function.
1326   auto *TheFn = TargetsForSlot[0].Fn;
1327   for (auto &&Target : TargetsForSlot)
1328     if (TheFn != Target.Fn)
1329       return false;
1330 
1331   // If so, update each call site to call that implementation directly.
1332   if (RemarksEnabled || AreStatisticsEnabled())
1333     TargetsForSlot[0].WasDevirt = true;
1334 
1335   bool IsExported = false;
1336   applySingleImplDevirt(SlotInfo, TheFn, IsExported);
1337   if (!IsExported)
1338     return false;
1339 
1340   // If the only implementation has local linkage, we must promote to external
1341   // to make it visible to thin LTO objects. We can only get here during the
1342   // ThinLTO export phase.
1343   if (TheFn->hasLocalLinkage()) {
1344     std::string NewName = (TheFn->getName() + ".llvm.merged").str();
1345 
1346     // Since we are renaming the function, any comdats with the same name must
1347     // also be renamed. This is required when targeting COFF, as the comdat name
1348     // must match one of the names of the symbols in the comdat.
1349     if (Comdat *C = TheFn->getComdat()) {
1350       if (C->getName() == TheFn->getName()) {
1351         Comdat *NewC = M.getOrInsertComdat(NewName);
1352         NewC->setSelectionKind(C->getSelectionKind());
1353         for (GlobalObject &GO : M.global_objects())
1354           if (GO.getComdat() == C)
1355             GO.setComdat(NewC);
1356       }
1357     }
1358 
1359     TheFn->setLinkage(GlobalValue::ExternalLinkage);
1360     TheFn->setVisibility(GlobalValue::HiddenVisibility);
1361     TheFn->setName(NewName);
1362   }
1363   if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
1364     // Any needed promotion of 'TheFn' has already been done during
1365     // LTO unit split, so we can ignore return value of AddCalls.
1366     AddCalls(SlotInfo, TheFnVI);
1367 
1368   Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
1369   Res->SingleImplName = std::string(TheFn->getName());
1370 
1371   return true;
1372 }
1373 
/// Index-based (ThinLTO thin link) version of single implementation
/// devirtualization: if every summarized target for this vtable slot
/// resolves to the same function, record a SingleImpl resolution in the
/// combined summary so the import phase can devirtualize the calls.
/// Returns true if a resolution was recorded.
bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                                      VTableSlotSummary &SlotSummary,
                                      VTableSlotInfo &SlotInfo,
                                      WholeProgramDevirtResolution *Res,
                                      std::set<ValueInfo> &DevirtTargets) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto TheFn = TargetsForSlot[0];
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target)
      return false;

  // Don't devirtualize if we don't have target definition.
  auto Size = TheFn.getSummaryList().size();
  if (!Size)
    return false;

  // Don't devirtualize function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn.name()))
    return false;

  // If the summary list contains multiple summaries where at least one is
  // a local, give up, as we won't know which (possibly promoted) name to use.
  for (const auto &S : TheFn.getSummaryList())
    if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
      return false;

  // Collect functions devirtualized at least for one call site for stats.
  if (PrintSummaryDevirt || AreStatisticsEnabled())
    DevirtTargets.insert(TheFn);

  auto &S = TheFn.getSummaryList()[0];
  // If devirtualization makes the target visible outside this module,
  // record its GUID so it is preserved/exported.
  bool IsExported = AddCalls(SlotInfo, TheFn);
  if (IsExported)
    ExportedGUIDs.insert(TheFn.getGUID());

  // Record in summary for use in devirtualization during the ThinLTO import
  // step.
  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  if (GlobalValue::isLocalLinkage(S->linkage())) {
    if (IsExported)
      // If target is a local function and we are exporting it by
      // devirtualizing a call in another module, we need to record the
      // promoted name.
      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
    else {
      // Non-exported local target: remember the slot so later processing of
      // local WPD targets can account for it.
      LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
      Res->SingleImplName = std::string(TheFn.name());
    }
  } else
    Res->SingleImplName = std::string(TheFn.name());

  // Name will be empty if this thin link driven off of serialized combined
  // index (e.g. llvm-lto). However, WPD is not supported/invoked for the
  // legacy LTO API anyway.
  assert(!Res->SingleImplName.empty());

  return true;
}
1435 
/// Attempt to build a branch funnel for this slot: a small dispatch function
/// built around the llvm.icall.branch.funnel intrinsic that selects among
/// the possible targets based on the vtable address.
void DevirtModule::tryICallBranchFunnel(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // Branch funnels are only generated for x86_64 here.
  Triple T(M.getTargetTriple());
  if (T.getArch() != Triple::x86_64)
    return;

  // Give up if the funnel would have to dispatch over too many targets.
  if (TargetsForSlot.size() > ClThreshold)
    return;

  // Only worthwhile if at least one call site (plain or constant-arg) has
  // not already been devirtualized by an earlier optimization.
  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
  if (!HasNonDevirt)
    for (auto &P : SlotInfo.ConstCSInfo)
      if (!P.second.AllCallSitesDevirted) {
        HasNonDevirt = true;
        break;
      }

  if (!HasNonDevirt)
    return;

  // The funnel takes the vtable address as its first parameter (marked
  // 'nest' below) and forwards the rest as varargs.
  FunctionType *FT =
      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
  Function *JT;
  if (isa<MDString>(Slot.TypeID)) {
    // String type id: give the funnel a predictable hidden external name so
    // importing modules can reference it via the summary resolution.
    JT = Function::Create(FT, Function::ExternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          getGlobalName(Slot, {}, "branch_funnel"), &M);
    JT->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    JT = Function::Create(FT, Function::InternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          "branch_funnel", &M);
  }
  JT->addParamAttr(0, Attribute::Nest);

  // Intrinsic operands: the incoming vtable pointer followed by
  // (vtable member address, target function) pairs.
  std::vector<Value *> JTArgs;
  JTArgs.push_back(JT->arg_begin());
  for (auto &T : TargetsForSlot) {
    JTArgs.push_back(getMemberAddr(T.TM));
    JTArgs.push_back(T.Fn);
  }

  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
  Function *Intr = Intrinsic::getOrInsertDeclaration(
      &M, llvm::Intrinsic::icall_branch_funnel, {});

  // The intrinsic is emitted as a musttail call followed by a void return.
  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
  CI->setTailCallKind(CallInst::TCK_MustTail);
  ReturnInst::Create(M.getContext(), nullptr, BB);

  bool IsExported = false;
  applyICallBranchFunnel(SlotInfo, JT, IsExported);
  if (IsExported)
    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
}
1492 
/// Rewrite this slot's virtual call sites (only in callers built with the
/// retpoline mitigation) to call the branch funnel JT, passing the vtable
/// address in the 'nest' parameter. Sets IsExported if any rewritten
/// call-site group is visible to other modules.
void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
                                          Constant *JT, bool &IsExported) {
  auto Apply = [&](CallSiteInfo &CSInfo) {
    if (CSInfo.isExported())
      IsExported = true;
    if (CSInfo.AllCallSitesDevirted)
      return;

    // Map original call -> replacement. RAUW/erase is deferred to the loop
    // below so duplicate entries in CallSites never touch a freed call.
    std::map<CallBase *, CallBase *> CallBases;
    for (auto &&VCallSite : CSInfo.CallSites) {
      CallBase &CB = VCallSite.CB;

      if (CallBases.find(&CB) != CallBases.end()) {
        // When finding devirtualizable calls, it's possible to find the same
        // vtable passed to multiple llvm.type.test or llvm.type.checked.load
        // calls, which can cause duplicate call sites to be recorded in
        // [Const]CallSites. If we've already found one of these
        // call instances, just ignore it. It will be replaced later.
        continue;
      }

      // Jump tables are only profitable if the retpoline mitigation is enabled.
      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
      if (!FSAttr.isValid() ||
          !FSAttr.getValueAsString().contains("+retpoline"))
        continue;

      NumBranchFunnel++;
      if (RemarksEnabled)
        VCallSite.emitRemark("branch-funnel",
                             JT->stripPointerCasts()->getName(), OREGetter);

      // Pass the address of the vtable in the nest register, which is r10 on
      // x86_64.
      std::vector<Type *> NewArgs;
      NewArgs.push_back(Int8PtrTy);
      append_range(NewArgs, CB.getFunctionType()->params());
      FunctionType *NewFT =
          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
                            CB.getFunctionType()->isVarArg());
      IRBuilder<> IRB(&CB);
      std::vector<Value *> Args;
      Args.push_back(VCallSite.VTable);
      llvm::append_range(Args, CB.args());

      // Mirror the original call's kind (call vs. invoke) and convention.
      CallBase *NewCS = nullptr;
      if (isa<CallInst>(CB))
        NewCS = IRB.CreateCall(NewFT, JT, Args);
      else
        NewCS =
            IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(),
                             cast<InvokeInst>(CB).getUnwindDest(), Args);
      NewCS->setCallingConv(CB.getCallingConv());

      AttributeList Attrs = CB.getAttributes();
      std::vector<AttributeSet> NewArgAttrs;
      // Parameter 0 of the funnel carries the vtable pointer in 'nest'.
      NewArgAttrs.push_back(AttributeSet::get(
          M.getContext(), ArrayRef<Attribute>{Attribute::get(
                              M.getContext(), Attribute::Nest)}));
      // Copy the original parameter attribute sets, shifted down one slot
      // to account for the new leading vtable parameter.
      for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
        NewArgAttrs.push_back(Attrs.getParamAttrs(I));
      NewCS->setAttributes(
          AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
                             Attrs.getRetAttrs(), NewArgAttrs));

      CallBases[&CB] = NewCS;

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    // Don't mark as devirtualized because there may be callers compiled without
    // retpoline mitigation, which would mean that they are lowered to
    // llvm.type.test and therefore require an llvm.type.test resolution for the
    // type identifier.

    for (auto &[Old, New] : CallBases) {
      Old->replaceAllUsesWith(New);
      Old->eraseFromParent();
    }
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}
1578 
1579 bool DevirtModule::tryEvaluateFunctionsWithArgs(
1580     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
1581     ArrayRef<uint64_t> Args) {
1582   // Evaluate each function and store the result in each target's RetVal
1583   // field.
1584   for (VirtualCallTarget &Target : TargetsForSlot) {
1585     // TODO: Skip for now if the vtable symbol was an alias to a function,
1586     // need to evaluate whether it would be correct to analyze the aliasee
1587     // function for this optimization.
1588     auto Fn = dyn_cast<Function>(Target.Fn);
1589     if (!Fn)
1590       return false;
1591 
1592     if (Fn->arg_size() != Args.size() + 1)
1593       return false;
1594 
1595     Evaluator Eval(M.getDataLayout(), nullptr);
1596     SmallVector<Constant *, 2> EvalArgs;
1597     EvalArgs.push_back(
1598         Constant::getNullValue(Fn->getFunctionType()->getParamType(0)));
1599     for (unsigned I = 0; I != Args.size(); ++I) {
1600       auto *ArgTy =
1601           dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1));
1602       if (!ArgTy)
1603         return false;
1604       EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
1605     }
1606 
1607     Constant *RetVal;
1608     if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) ||
1609         !isa<ConstantInt>(RetVal))
1610       return false;
1611     Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
1612   }
1613   return true;
1614 }
1615 
1616 void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
1617                                          uint64_t TheRetVal) {
1618   for (auto Call : CSInfo.CallSites) {
1619     if (!OptimizedCalls.insert(&Call.CB).second)
1620       continue;
1621     NumUniformRetVal++;
1622     Call.replaceAndErase(
1623         "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
1624         ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
1625   }
1626   CSInfo.markDevirt();
1627 }
1628 
1629 bool DevirtModule::tryUniformRetValOpt(
1630     MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
1631     WholeProgramDevirtResolution::ByArg *Res) {
1632   // Uniform return value optimization. If all functions return the same
1633   // constant, replace all calls with that constant.
1634   uint64_t TheRetVal = TargetsForSlot[0].RetVal;
1635   for (const VirtualCallTarget &Target : TargetsForSlot)
1636     if (Target.RetVal != TheRetVal)
1637       return false;
1638 
1639   if (CSInfo.isExported()) {
1640     Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
1641     Res->Info = TheRetVal;
1642   }
1643 
1644   applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
1645   if (RemarksEnabled || AreStatisticsEnabled())
1646     for (auto &&Target : TargetsForSlot)
1647       Target.WasDevirt = true;
1648   return true;
1649 }
1650 
1651 std::string DevirtModule::getGlobalName(VTableSlot Slot,
1652                                         ArrayRef<uint64_t> Args,
1653                                         StringRef Name) {
1654   std::string FullName = "__typeid_";
1655   raw_string_ostream OS(FullName);
1656   OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
1657   for (uint64_t Arg : Args)
1658     OS << '_' << Arg;
1659   OS << '_' << Name;
1660   return FullName;
1661 }
1662 
1663 bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
1664   Triple T(M.getTargetTriple());
1665   return T.isX86() && T.getObjectFormat() == Triple::ELF;
1666 }
1667 
1668 void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1669                                 StringRef Name, Constant *C) {
1670   GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
1671                                         getGlobalName(Slot, Args, Name), C, &M);
1672   GA->setVisibility(GlobalValue::HiddenVisibility);
1673 }
1674 
1675 void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
1676                                   StringRef Name, uint32_t Const,
1677                                   uint32_t &Storage) {
1678   if (shouldExportConstantsAsAbsoluteSymbols()) {
1679     exportGlobal(
1680         Slot, Args, Name,
1681         ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
1682     return;
1683   }
1684 
1685   Storage = Const;
1686 }
1687 
1688 Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
1689                                      StringRef Name) {
1690   Constant *C =
1691       M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
1692   auto *GV = dyn_cast<GlobalVariable>(C);
1693   if (GV)
1694     GV->setVisibility(GlobalValue::HiddenVisibility);
1695   return C;
1696 }
1697 
/// Import a constant exported by exportConstant: either read it directly
/// from the summary (Storage), or, on absolute-symbol targets, reference the
/// exported symbol and convert its address to an integer of type IntTy.
Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                       StringRef Name, IntegerType *IntTy,
                                       uint32_t Storage) {
  if (!shouldExportConstantsAsAbsoluteSymbols())
    return ConstantInt::get(IntTy, Storage);

  Constant *C = importGlobal(Slot, Args, Name);
  auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
  C = ConstantExpr::getPtrToInt(C, IntTy);

  // We only need to set metadata if the global is newly created, in which
  // case it would not have hidden visibility.
  if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
    return C;

  // Attach !absolute_symbol metadata describing the [Min, Max) range the
  // symbol's address may take.
  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
    GV->setMetadata(LLVMContext::MD_absolute_symbol,
                    MDNode::get(M.getContext(), {MinC, MaxC}));
  };
  unsigned AbsWidth = IntTy->getBitWidth();
  if (AbsWidth == IntPtrTy->getBitWidth())
    SetAbsRange(~0ull, ~0ull); // Full set.
  else
    SetAbsRange(0, 1ull << AbsWidth);
  return C;
}
1726 
1727 void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
1728                                         bool IsOne,
1729                                         Constant *UniqueMemberAddr) {
1730   for (auto &&Call : CSInfo.CallSites) {
1731     if (!OptimizedCalls.insert(&Call.CB).second)
1732       continue;
1733     IRBuilder<> B(&Call.CB);
1734     Value *Cmp =
1735         B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
1736                      B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
1737     Cmp = B.CreateZExt(Cmp, Call.CB.getType());
1738     NumUniqueRetVal++;
1739     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
1740                          Cmp);
1741   }
1742   CSInfo.markDevirt();
1743 }
1744 
1745 Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
1746   return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV,
1747                                         ConstantInt::get(Int64Ty, M->Offset));
1748 }
1749 
1750 bool DevirtModule::tryUniqueRetValOpt(
1751     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
1752     CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
1753     VTableSlot Slot, ArrayRef<uint64_t> Args) {
1754   // IsOne controls whether we look for a 0 or a 1.
1755   auto tryUniqueRetValOptFor = [&](bool IsOne) {
1756     const TypeMemberInfo *UniqueMember = nullptr;
1757     for (const VirtualCallTarget &Target : TargetsForSlot) {
1758       if (Target.RetVal == (IsOne ? 1 : 0)) {
1759         if (UniqueMember)
1760           return false;
1761         UniqueMember = Target.TM;
1762       }
1763     }
1764 
1765     // We should have found a unique member or bailed out by now. We already
1766     // checked for a uniform return value in tryUniformRetValOpt.
1767     assert(UniqueMember);
1768 
1769     Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
1770     if (CSInfo.isExported()) {
1771       Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
1772       Res->Info = IsOne;
1773 
1774       exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
1775     }
1776 
1777     // Replace each call with the comparison.
1778     applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
1779                          UniqueMemberAddr);
1780 
1781     // Update devirtualization statistics for targets.
1782     if (RemarksEnabled || AreStatisticsEnabled())
1783       for (auto &&Target : TargetsForSlot)
1784         Target.WasDevirt = true;
1785 
1786     return true;
1787   };
1788 
1789   if (BitWidth == 1) {
1790     if (tryUniqueRetValOptFor(true))
1791       return true;
1792     if (tryUniqueRetValOptFor(false))
1793       return true;
1794   }
1795   return false;
1796 }
1797 
/// Rewrite each call site as a load of its precomputed return value from
/// alongside the vtable: Byte is the offset from the vtable address; Bit is
/// the mask to test for i1 returns (which are packed into bytes).
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    // Skip calls that an earlier optimization already rewrote.
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    auto *RetType = cast<IntegerType>(Call.CB.getType());
    IRBuilder<> B(&Call.CB);
    Value *Addr = B.CreatePtrAdd(Call.VTable, Byte);
    if (RetType->getBitWidth() == 1) {
      // i1 return: load the byte and test the relevant bit.
      Value *Bits = B.CreateLoad(Int8Ty, Addr);
      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
      NumVirtConstProp1Bit++;
      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
                           OREGetter, IsBitSet);
    } else {
      // Wider return: the value is stored directly at the byte offset.
      Value *Val = B.CreateLoad(RetType, Addr);
      NumVirtConstProp++;
      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
                           OREGetter, Val);
    }
  }
  CSInfo.markDevirt();
}
1822 
/// Virtual constant propagation: for each constant-argument list seen at the
/// slot's call sites, evaluate every implementation; then try the uniform
/// and unique return value optimizations, and otherwise store each vtable's
/// result alongside the vtable and rewrite the calls as loads.
bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // TODO: Skip for now if the vtable symbol was an alias to a function,
  // need to evaluate whether it would be correct to analyze the aliasee
  // function for this optimization.
  auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn);
  if (!Fn)
    return false;
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(Fn->getReturnType());
  if (!RetType)
    return false;
  // Results wider than 64 bits cannot be stored/propagated here.
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    // TODO: Skip for now if the vtable symbol was an alias to a function,
    // need to evaluate whether it would be correct to analyze the aliasee
    // function for this optimization.
    auto Fn = dyn_cast<Function>(Target.Fn);
    if (!Fn)
      return false;

    if (Fn->isDeclaration() ||
        !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn))
             .doesNotAccessMemory() ||
        Fn->arg_empty() || !Fn->arg_begin()->use_empty() ||
        Fn->getReturnType() != RetType)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    // Skip this argument list if any implementation fails to evaluate.
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    // Prefer the cheaper optimizations when they apply.
    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
                           ResByArg, Slot, CSByConstantArg.first))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;


    // NOTE(review): ResByArg is only non-null when Res is; exported
    // call-site groups are presumably always analyzed with a resolution
    // (Res != nullptr) — confirm with callers.
    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
                     ResByArg->Byte);
      exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
                     ResByArg->Bit);
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
    applyVirtualConstProp(CSByConstantArg.second,
                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
  }
  return true;
}
1935 
/// Splice the bytes allocated before/after a vtable into its storage: build
/// a new private global {before bytes, original initializer, after bytes},
/// then alias the original name to the middle element so all existing
/// references keep pointing at the original vtable contents.
void DevirtModule::rebuildGlobal(VTableBits &B) {
  // Nothing was allocated around this vtable; leave it untouched.
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align the before byte array to the global's minimum alignment so that we
  // don't break any alignment requirements on the global.
  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
      B.GV->getAlign(), B.GV->getValueType());
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  // Preserve placement-affecting properties of the original global.
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());
  NewGV->setAlignment(B.GV->getAlign());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getInBoundsGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  // Redirect every use of the old global to the alias and drop the original.
  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}
1982 
1983 bool DevirtModule::areRemarksEnabled() {
1984   const auto &FL = M.getFunctionList();
1985   for (const Function &Fn : FL) {
1986     if (Fn.empty())
1987       continue;
1988     auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front());
1989     return DI.isEnabled();
1990   }
1991   return false;
1992 }
1993 
/// Scan users of llvm.type.test to find devirtualizable virtual calls,
/// populating CallSlots, and clean up type test assume sequences that later
/// LowerTypeTests passes would otherwise lower incorrectly.
void DevirtModule::scanTypeTestUsers(
    Function *TypeTestFunc,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

    Metadata *TypeId =
        cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
    // If we found any, add them to CallSlots.
    if (!Assumes.empty()) {
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      for (DevirtCallSite Call : DevirtCalls)
        // nullptr: assume-based call sites carry no unsafe-use counter.
        CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
    }

    auto RemoveTypeTestAssumes = [&]() {
      // We no longer need the assumes or the type test.
      for (auto *Assume : Assumes)
        Assume->eraseFromParent();
      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
      // may use the vtable argument later.
      if (CI->use_empty())
        CI->eraseFromParent();
    };

    // At this point we could remove all type test assume sequences, as they
    // were originally inserted for WPD. However, we can keep these in the
    // code stream for later analysis (e.g. to help drive more efficient ICP
    // sequences). They will eventually be removed by a second LowerTypeTests
    // invocation that cleans them up. In order to do this correctly, the first
    // LowerTypeTests invocation needs to know that they have "Unknown" type
    // test resolution, so that they aren't treated as Unsat and lowered to
    // False, which will break any uses on assumes. Below we remove any type
    // test assumes that will not be treated as Unknown by LTT.

    // The type test assumes will be treated by LTT as Unsat if the type id is
    // not used on a global (in which case it has no entry in the TypeIdMap).
    if (!TypeIdMap.count(TypeId))
      RemoveTypeTestAssumes();

    // For ThinLTO importing, we need to remove the type test assumes if this is
    // an MDString type id without a corresponding TypeIdSummary. Any
    // non-MDString type ids are ignored and treated as Unknown by LTT, so their
    // type test assumes can be kept. If the MDString type id is missing a
    // TypeIdSummary (e.g. because there was no use on a vcall, preventing the
    // exporting phase of WPD from analyzing it), then it would be treated as
    // Unsat by LTT and we need to remove its type test assumes here. If not
    // used on a vcall we don't need them for later optimization use in any
    // case.
    else if (ImportSummary && isa<MDString>(TypeId)) {
      const TypeIdSummary *TidSummary =
          ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
      if (!TidSummary)
        RemoveTypeTestAssumes();
      else
        // If one was created it should not be Unsat, because if we reached here
        // the type id was used on a global.
        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
    }
  }
}
2068 
/// Lower each use of llvm.type.checked.load (and its relative variant) to an
/// explicit vtable load plus an llvm.type.test, recording the discovered
/// virtual call sites in CallSlots together with an unsafe-use counter.
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc =
      Intrinsic::getOrInsertDeclaration(&M, Intrinsic::type_test);

  for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    // Intrinsic operands: vtable pointer, byte offset, type id.
    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI, DT);

    // Start by generating "pessimistic" code that explicitly loads the function
    // pointer from the vtable and performs the type check. If possible, we will
    // eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);

    Value *LoadedValue = nullptr;
    if (TypeCheckedLoadFunc->getIntrinsicID() ==
        Intrinsic::type_checked_load_relative) {
      // Relative variant: load a 32-bit offset, sign-extend it, and add it
      // to the address it was loaded from to form the function pointer.
      Value *GEP = LoadB.CreatePtrAdd(Ptr, Offset);
      LoadedValue = LoadB.CreateLoad(Int32Ty, GEP);
      LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
      GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
      LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
      LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
    } else {
      // Plain variant: the slot directly holds the function pointer.
      Value *GEP = LoadB.CreatePtrAdd(Ptr, Offset);
      LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP);
    }

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = PoisonValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the type
    // check, as one of those users may eventually call the pointer. Increment
    // the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}
2157 
2158 void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
2159   auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
2160   if (!TypeId)
2161     return;
2162   const TypeIdSummary *TidSummary =
2163       ImportSummary->getTypeIdSummary(TypeId->getString());
2164   if (!TidSummary)
2165     return;
2166   auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
2167   if (ResI == TidSummary->WPDRes.end())
2168     return;
2169   const WholeProgramDevirtResolution &Res = ResI->second;
2170 
2171   if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
2172     assert(!Res.SingleImplName.empty());
2173     // The type of the function in the declaration is irrelevant because every
2174     // call site will cast it to the correct type.
2175     Constant *SingleImpl =
2176         cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
2177                                              Type::getVoidTy(M.getContext()))
2178                            .getCallee());
2179 
2180     // This is the import phase so we should not be exporting anything.
2181     bool IsExported = false;
2182     applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
2183     assert(!IsExported);
2184   }
2185 
2186   for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
2187     auto I = Res.ResByArg.find(CSByConstantArg.first);
2188     if (I == Res.ResByArg.end())
2189       continue;
2190     auto &ResByArg = I->second;
2191     // FIXME: We should figure out what to do about the "function name" argument
2192     // to the apply* functions, as the function names are unavailable during the
2193     // importing phase. For now we just pass the empty string. This does not
2194     // impact correctness because the function names are just used for remarks.
2195     switch (ResByArg.TheKind) {
2196     case WholeProgramDevirtResolution::ByArg::UniformRetVal:
2197       applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
2198       break;
2199     case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
2200       Constant *UniqueMemberAddr =
2201           importGlobal(Slot, CSByConstantArg.first, "unique_member");
2202       applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
2203                            UniqueMemberAddr);
2204       break;
2205     }
2206     case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
2207       Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
2208                                       Int32Ty, ResByArg.Byte);
2209       Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
2210                                      ResByArg.Bit);
2211       applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
2212       break;
2213     }
2214     default:
2215       break;
2216     }
2217   }
2218 
2219   if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
2220     // The type of the function is irrelevant, because it's bitcast at calls
2221     // anyhow.
2222     Constant *JT = cast<Constant>(
2223         M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
2224                               Type::getVoidTy(M.getContext()))
2225             .getCallee());
2226     bool IsExported = false;
2227     applyICallBranchFunnel(SlotInfo, JT, IsExported);
2228     assert(!IsExported);
2229   }
2230 }
2231 
2232 void DevirtModule::removeRedundantTypeTests() {
2233   auto True = ConstantInt::getTrue(M.getContext());
2234   for (auto &&U : NumUnsafeUsesForTypeTest) {
2235     if (U.second == 0) {
2236       U.first->replaceAllUsesWith(True);
2237       U.first->eraseFromParent();
2238     }
2239   }
2240 }
2241 
2242 ValueInfo
2243 DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
2244                                       ModuleSummaryIndex *ExportSummary) {
2245   assert((ExportSummary != nullptr) &&
2246          "Caller guarantees ExportSummary is not nullptr");
2247 
2248   const auto TheFnGUID = TheFn->getGUID();
2249   const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName());
2250   // Look up ValueInfo with the GUID in the current linkage.
2251   ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID);
2252   // If no entry is found and GUID is different from GUID computed using
2253   // exported name, look up ValueInfo with the exported name unconditionally.
2254   // This is a fallback.
2255   //
2256   // The reason to have a fallback:
2257   // 1. LTO could enable global value internalization via
2258   // `enable-lto-internalization`.
2259   // 2. The GUID in ExportedSummary is computed using exported name.
2260   if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
2261     TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName);
2262   }
2263   return TheFnVI;
2264 }
2265 
2266 bool DevirtModule::mustBeUnreachableFunction(
2267     Function *const F, ModuleSummaryIndex *ExportSummary) {
2268   if (WholeProgramDevirtKeepUnreachableFunction)
2269     return false;
2270   // First, learn unreachability by analyzing function IR.
2271   if (!F->isDeclaration()) {
2272     // A function must be unreachable if its entry block ends with an
2273     // 'unreachable'.
2274     return isa<UnreachableInst>(F->getEntryBlock().getTerminator());
2275   }
2276   // Learn unreachability from ExportSummary if ExportSummary is present.
2277   return ExportSummary &&
2278          ::mustBeUnreachableFunction(
2279              DevirtModule::lookUpFunctionValueInfo(F, ExportSummary));
2280 }
2281 
/// Main driver for module-level WPD. Scans the devirtualization intrinsics,
/// then either applies imported resolutions (ThinLTO import phase) or
/// attempts the devirtualization optimizations directly (regular LTO /
/// export phase). Returns true if the module was modified.
bool DevirtModule::run() {
  // If only some of the modules were split, we cannot correctly perform
  // this transformation. We already checked for the presence of type tests
  // with partially split modules during the thin link, and would have emitted
  // an error if any were found, so here we can simply return.
  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
    return false;

  // Look up declarations of the intrinsics this pass lowers, if the module
  // has them at all.
  Function *TypeTestFunc =
      Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_test);
  Function *TypeCheckedLoadFunc =
      Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_checked_load);
  Function *TypeCheckedLoadRelativeFunc = Intrinsic::getDeclarationIfExists(
      &M, Intrinsic::type_checked_load_relative);
  Function *AssumeFunc =
      Intrinsic::getDeclarationIfExists(&M, Intrinsic::assume);

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also need
  // to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
      (!TypeCheckedLoadRelativeFunc ||
       TypeCheckedLoadRelativeFunc->use_empty()))
    return false;

  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);

  // Populate CallSlots from the IR: type tests feeding assumes, plus both
  // flavors of llvm.type.checked.load.
  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, TypeIdMap);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (TypeCheckedLoadRelativeFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc);

  if (ImportSummary) {
    // ThinLTO import phase: apply the resolutions computed during the thin
    // link to the call slots found above.
    for (auto &S : CallSlots)
      importResolution(S.first, S.second);

    removeRedundantTypeTests();

    // We have lowered or deleted the type intrinsics, so we will no longer have
    // enough information to reason about the liveness of virtual function
    // pointers in GlobalDCE.
    for (GlobalVariable &GV : M.globals())
      GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  // No type metadata in this module: nothing to devirtualize against.
  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    // Summaries identify type ids by GUID; build a reverse map so summary
    // vcall records can be matched to the Metadata-keyed CallSlots.
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(P.first))
        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
            TypeId);
    }

    // Register each function summary's virtual calls (plain and
    // constant-argument variants) as users of the corresponding slot.
    for (auto &P : *ExportSummary) {
      for (auto &S : P.second.SummaryList) {
        auto *FS = dyn_cast<FunctionSummary>(S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeCheckedLoadUser(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, GlobalValue *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    WholeProgramDevirtResolution *Res = nullptr;
    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
    if (ExportSummary && isa<MDString>(S.first.TypeID) &&
        TypeMemberInfos.size())
      // For any type id used on a global's type metadata, create the type id
      // summary resolution regardless of whether we can devirtualize, so that
      // lower type tests knows the type id is not Unsat. If it was not used on
      // a global's type metadata, the TypeIdMap entry set will be empty, and
      // we don't want to create an entry (with the default Unknown type
      // resolution), which can prevent detection of the Unsat.
      Res = &ExportSummary
                 ->getOrInsertTypeIdSummary(
                     cast<MDString>(S.first.TypeID)->getString())
                 .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
                                  S.first.ByteOffset, ExportSummary)) {

      // Prefer single-implementation devirtualization; only if that fails,
      // fall back to virtual constant propagation and branch funnels.
      if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
        DidVirtualConstProp |=
            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);

        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
      }

      // Collect functions devirtualized at least for one call site for stats.
      if (RemarksEnabled || AreStatisticsEnabled())
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
      auto GUID =
          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
      for (auto *FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
        FS->addTypeTest(GUID);
      for (auto &CCS : S.second.ConstCSInfo)
        for (auto *FS : CCS.second.SummaryTypeCheckedLoadUsers)
          FS->addTypeTest(GUID);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      GlobalValue *GV = DT.second;
      auto F = dyn_cast<Function>(GV);
      if (!F) {
        // A target may be a GlobalAlias of a function; remark on the aliasee.
        auto A = dyn_cast<GlobalAlias>(GV);
        assert(A && isa<Function>(A->getAliasee()));
        F = dyn_cast<Function>(A->getAliasee());
        assert(F);
      }

      using namespace ore;
      OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
                        << "devirtualized "
                        << NV("FunctionName", DT.first));
    }
  }

  NumDevirtTargets += DevirtTargets.size();

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  // We have lowered or deleted the type intrinsics, so we will no longer have
  // enough information to reason about the liveness of virtual function
  // pointers in GlobalDCE.
  for (GlobalVariable &GV : M.globals())
    GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

  // Erase calls queued in CallsWithPtrAuthBundleRemoved (populated elsewhere
  // in this pass while scanning; they are no longer needed at this point).
  for (auto *CI : CallsWithPtrAuthBundleRemoved)
    CI->eraseFromParent();

  return true;
}
2484 
/// Thin-link (index-only) driver: collect virtual call users from the
/// combined summary and attempt single-implementation devirtualization,
/// recording resolutions in the summary for the import phase to apply.
void DevirtIndex::run() {
  if (ExportSummary.typeIdCompatibleVtableMap().empty())
    return;

  // Summary vcall records identify type ids by GUID; map each GUID back to
  // the type name(s) that produced it (a vector, since several names may
  // hash to the same GUID).
  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
  for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
    // Create the type id summary resolution regardless of whether we can
    // devirtualize, so that lower type tests knows the type id is used on
    // a global and not Unsat. We do this here rather than in the loop over the
    // CallSlots, since that handling will only see type tests that directly
    // feed assumes, and we would miss any that aren't currently handled by WPD
    // (such as type tests that feed assumes via phis).
    ExportSummary.getOrInsertTypeIdSummary(P.first);
  }

  // Collect information from summary about which calls to try to devirtualize.
  for (auto &P : ExportSummary) {
    for (auto &S : P.second.SummaryList) {
      auto *FS = dyn_cast<FunctionSummary>(S.get());
      if (!FS)
        continue;
      // FIXME: Only add live functions.
      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_test_assume_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_checked_load_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeCheckedLoadUser(FS);
        }
      }
    }
  }

  std::set<ValueInfo> DevirtTargets;
  // For each (type, offset) pair:
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<ValueInfo> TargetsForSlot;
    auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
    assert(TidSummary);
    // The type id summary would have been created while building the NameByGUID
    // map earlier.
    WholeProgramDevirtResolution *Res =
        &ExportSummary.getTypeIdSummary(S.first.TypeID)
             ->WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
                                  S.first.ByteOffset)) {

      // Only single-implementation devirtualization is attempted at the thin
      // link; the other optimizations require IR and run in DevirtModule.
      if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
                               DevirtTargets))
        continue;
    }
  }

  // Optionally have the thin link print message for each devirtualized
  // function.
  if (PrintSummaryDevirt)
    for (const auto &DT : DevirtTargets)
      errs() << "Devirtualized call to " << DT << "\n";

  NumDevirtTargets += DevirtTargets.size();
}
2568