1 //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass implements whole program optimization of virtual calls in cases 10 // where we know (via !type metadata) that the list of callees is fixed. This 11 // includes the following: 12 // - Single implementation devirtualization: if a virtual call has a single 13 // possible callee, replace all calls with a direct call to that callee. 14 // - Virtual constant propagation: if the virtual function's return type is an 15 // integer <=64 bits and all possible callees are readnone, for each class and 16 // each list of constant arguments: evaluate the function, store the return 17 // value alongside the virtual table, and rewrite each virtual call as a load 18 // from the virtual table. 19 // - Uniform return value optimization: if the conditions for virtual constant 20 // propagation hold and each function returns the same constant value, replace 21 // each virtual call with that constant. 22 // - Unique return value optimization for i1 return values: if the conditions 23 // for virtual constant propagation hold and a single vtable's function 24 // returns 0, or a single vtable's function returns 1, replace each virtual 25 // call with a comparison of the vptr against that vtable's address. 26 // 27 // This pass is intended to be used during the regular and thin LTO pipelines: 28 // 29 // During regular LTO, the pass determines the best optimization for each 30 // virtual call and applies the resolutions directly to virtual calls that are 31 // eligible for virtual call optimization (i.e. calls that use either of the 32 // llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). 33 // 34 // During hybrid Regular/ThinLTO, the pass operates in two phases: 35 // - Export phase: this is run during the thin link over a single merged module 36 // that contains all vtables with !type metadata that participate in the link. 37 // The pass computes a resolution for each virtual call and stores it in the 38 // type identifier summary. 39 // - Import phase: this is run during the thin backends over the individual 40 // modules. The pass applies the resolutions previously computed during the 41 // import phase to each eligible virtual call. 42 // 43 // During ThinLTO, the pass operates in two phases: 44 // - Export phase: this is run during the thin link over the index which 45 // contains a summary of all vtables with !type metadata that participate in 46 // the link. It computes a resolution for each virtual call and stores it in 47 // the type identifier summary. Only single implementation devirtualization 48 // is supported. 49 // - Import phase: (same as with hybrid case above). 50 // 51 //===----------------------------------------------------------------------===// 52 53 #include "llvm/Transforms/IPO/WholeProgramDevirt.h" 54 #include "llvm/ADT/ArrayRef.h" 55 #include "llvm/ADT/DenseMap.h" 56 #include "llvm/ADT/DenseMapInfo.h" 57 #include "llvm/ADT/DenseSet.h" 58 #include "llvm/ADT/MapVector.h" 59 #include "llvm/ADT/SmallVector.h" 60 #include "llvm/ADT/Statistic.h" 61 #include "llvm/Analysis/AssumptionCache.h" 62 #include "llvm/Analysis/BasicAliasAnalysis.h" 63 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 64 #include "llvm/Analysis/TypeMetadataUtils.h" 65 #include "llvm/Bitcode/BitcodeReader.h" 66 #include "llvm/Bitcode/BitcodeWriter.h" 67 #include "llvm/IR/Constants.h" 68 #include "llvm/IR/DataLayout.h" 69 #include "llvm/IR/DebugLoc.h" 70 #include "llvm/IR/DerivedTypes.h" 71 #include "llvm/IR/Dominators.h" 72 #include "llvm/IR/Function.h" 73 #include "llvm/IR/GlobalAlias.h" 74 #include "llvm/IR/GlobalVariable.h" 75 #include "llvm/IR/IRBuilder.h" 76 #include "llvm/IR/InstrTypes.h" 77 #include "llvm/IR/Instruction.h" 78 #include "llvm/IR/Instructions.h" 79 #include "llvm/IR/Intrinsics.h" 80 #include "llvm/IR/LLVMContext.h" 81 #include "llvm/IR/MDBuilder.h" 82 #include "llvm/IR/Metadata.h" 83 #include "llvm/IR/Module.h" 84 #include "llvm/IR/ModuleSummaryIndexYAML.h" 85 #include "llvm/Support/Casting.h" 86 #include "llvm/Support/CommandLine.h" 87 #include "llvm/Support/Errc.h" 88 #include "llvm/Support/Error.h" 89 #include "llvm/Support/FileSystem.h" 90 #include "llvm/Support/GlobPattern.h" 91 #include "llvm/TargetParser/Triple.h" 92 #include "llvm/Transforms/IPO.h" 93 #include "llvm/Transforms/IPO/FunctionAttrs.h" 94 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 95 #include "llvm/Transforms/Utils/CallPromotionUtils.h" 96 #include "llvm/Transforms/Utils/Evaluator.h" 97 #include <algorithm> 98 #include <cstddef> 99 #include <map> 100 #include <set> 101 #include <string> 102 103 using namespace llvm; 104 using namespace wholeprogramdevirt; 105 106 #define DEBUG_TYPE "wholeprogramdevirt" 107 108 STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets"); 109 STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations"); 110 STATISTIC(NumBranchFunnel, "Number of branch funnels"); 111 STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations"); 112 STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations"); 113 STATISTIC(NumVirtConstProp1Bit, 114 "Number of 1 bit virtual constant propagations"); 115 STATISTIC(NumVirtConstProp, "Number of virtual constant propagations"); 116 117 static cl::opt<PassSummaryAction> ClSummaryAction( 118 "wholeprogramdevirt-summary-action", 119 cl::desc("What to do with the summary when running this pass"), 120 cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), 121 clEnumValN(PassSummaryAction::Import, "import", 122 "Import typeid resolutions from summary and globals"), 123 clEnumValN(PassSummaryAction::Export, "export", 124 "Export typeid resolutions to summary and globals")), 125 cl::Hidden); 126 127 static cl::opt<std::string> ClReadSummary( 128 "wholeprogramdevirt-read-summary", 129 cl::desc( 130 "Read summary from given bitcode or YAML file before running pass"), 131 cl::Hidden); 132 133 static cl::opt<std::string> ClWriteSummary( 134 "wholeprogramdevirt-write-summary", 135 cl::desc("Write summary to given bitcode or YAML file after running pass. " 136 "Output file format is deduced from extension: *.bc means writing " 137 "bitcode, otherwise YAML"), 138 cl::Hidden); 139 140 static cl::opt<unsigned> 141 ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, 142 cl::init(10), 143 cl::desc("Maximum number of call targets per " 144 "call site to enable branch funnels")); 145 146 static cl::opt<bool> 147 PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden, 148 cl::desc("Print index-based devirtualization messages")); 149 150 /// Provide a way to force enable whole program visibility in tests. 151 /// This is needed to support legacy tests that don't contain 152 /// !vcall_visibility metadata (the mere presense of type tests 153 /// previously implied hidden visibility). 154 static cl::opt<bool> 155 WholeProgramVisibility("whole-program-visibility", cl::Hidden, 156 cl::desc("Enable whole program visibility")); 157 158 /// Provide a way to force disable whole program for debugging or workarounds, 159 /// when enabled via the linker. 160 static cl::opt<bool> DisableWholeProgramVisibility( 161 "disable-whole-program-visibility", cl::Hidden, 162 cl::desc("Disable whole program visibility (overrides enabling options)")); 163 164 /// Provide way to prevent certain function from being devirtualized 165 static cl::list<std::string> 166 SkipFunctionNames("wholeprogramdevirt-skip", 167 cl::desc("Prevent function(s) from being devirtualized"), 168 cl::Hidden, cl::CommaSeparated); 169 170 /// With Clang, a pure virtual class's deleting destructor is emitted as a 171 /// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the 172 /// context of whole program devirtualization, the deleting destructor of a pure 173 /// virtual class won't be invoked by the source code so safe to skip as a 174 /// devirtualize target. 175 /// 176 /// However, not all unreachable functions are safe to skip. In some cases, the 177 /// program intends to run such functions and terminate, for instance, a unit 178 /// test may run a death test. A non-test program might (or allowed to) invoke 179 /// such functions to report failures (whether/when it's a good practice or not 180 /// is a different topic). 181 /// 182 /// This option is enabled to keep an unreachable function as a possible 183 /// devirtualize target to conservatively keep the program behavior. 184 /// 185 /// TODO: Make a pure virtual class's deleting destructor precisely identifiable 186 /// in Clang's codegen for more devirtualization in LLVM. 187 static cl::opt<bool> WholeProgramDevirtKeepUnreachableFunction( 188 "wholeprogramdevirt-keep-unreachable-function", 189 cl::desc("Regard unreachable functions as possible devirtualize targets."), 190 cl::Hidden, cl::init(true)); 191 192 /// If explicitly specified, the devirt module pass will stop transformation 193 /// once the total number of devirtualizations reach the cutoff value. Setting 194 /// this option to 0 explicitly will do 0 devirtualization. 195 static cl::opt<unsigned> WholeProgramDevirtCutoff( 196 "wholeprogramdevirt-cutoff", 197 cl::desc("Max number of devirtualizations for devirt module pass"), 198 cl::init(0)); 199 200 /// Mechanism to add runtime checking of devirtualization decisions, optionally 201 /// trapping or falling back to indirect call on any that are not correct. 202 /// Trapping mode is useful for debugging undefined behavior leading to failures 203 /// with WPD. Fallback mode is useful for ensuring safety when whole program 204 /// visibility may be compromised. 205 enum WPDCheckMode { None, Trap, Fallback }; 206 static cl::opt<WPDCheckMode> DevirtCheckMode( 207 "wholeprogramdevirt-check", cl::Hidden, 208 cl::desc("Type of checking for incorrect devirtualizations"), 209 cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"), 210 clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"), 211 clEnumValN(WPDCheckMode::Fallback, "fallback", 212 "Fallback to indirect when incorrect"))); 213 214 namespace { 215 struct PatternList { 216 std::vector<GlobPattern> Patterns; 217 template <class T> void init(const T &StringList) { 218 for (const auto &S : StringList) 219 if (Expected<GlobPattern> Pat = GlobPattern::create(S)) 220 Patterns.push_back(std::move(*Pat)); 221 } 222 bool match(StringRef S) { 223 for (const GlobPattern &P : Patterns) 224 if (P.match(S)) 225 return true; 226 return false; 227 } 228 }; 229 } // namespace 230 231 // Find the minimum offset that we may store a value of size Size bits at. If 232 // IsAfter is set, look for an offset before the object, otherwise look for an 233 // offset after the object. 234 uint64_t 235 wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets, 236 bool IsAfter, uint64_t Size) { 237 // Find a minimum offset taking into account only vtable sizes. 238 uint64_t MinByte = 0; 239 for (const VirtualCallTarget &Target : Targets) { 240 if (IsAfter) 241 MinByte = std::max(MinByte, Target.minAfterBytes()); 242 else 243 MinByte = std::max(MinByte, Target.minBeforeBytes()); 244 } 245 246 // Build a vector of arrays of bytes covering, for each target, a slice of the 247 // used region (see AccumBitVector::BytesUsed in 248 // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively, 249 // this aligns the used regions to start at MinByte. 250 // 251 // In this example, A, B and C are vtables, # is a byte already allocated for 252 // a virtual function pointer, AAAA... (etc.) are the used regions for the 253 // vtables and Offset(X) is the value computed for the Offset variable below 254 // for X. 255 // 256 // Offset(A) 257 // | | 258 // |MinByte 259 // A: ################AAAAAAAA|AAAAAAAA 260 // B: ########BBBBBBBBBBBBBBBB|BBBB 261 // C: ########################|CCCCCCCCCCCCCCCC 262 // | Offset(B) | 263 // 264 // This code produces the slices of A, B and C that appear after the divider 265 // at MinByte. 266 std::vector<ArrayRef<uint8_t>> Used; 267 for (const VirtualCallTarget &Target : Targets) { 268 ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed 269 : Target.TM->Bits->Before.BytesUsed; 270 uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes() 271 : MinByte - Target.minBeforeBytes(); 272 273 // Disregard used regions that are smaller than Offset. These are 274 // effectively all-free regions that do not need to be checked. 275 if (VTUsed.size() > Offset) 276 Used.push_back(VTUsed.slice(Offset)); 277 } 278 279 if (Size == 1) { 280 // Find a free bit in each member of Used. 281 for (unsigned I = 0;; ++I) { 282 uint8_t BitsUsed = 0; 283 for (auto &&B : Used) 284 if (I < B.size()) 285 BitsUsed |= B[I]; 286 if (BitsUsed != 0xff) 287 return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed)); 288 } 289 } else { 290 // Find a free (Size/8) byte region in each member of Used. 291 // FIXME: see if alignment helps. 292 for (unsigned I = 0;; ++I) { 293 for (auto &&B : Used) { 294 unsigned Byte = 0; 295 while ((I + Byte) < B.size() && Byte < (Size / 8)) { 296 if (B[I + Byte]) 297 goto NextI; 298 ++Byte; 299 } 300 } 301 return (MinByte + I) * 8; 302 NextI:; 303 } 304 } 305 } 306 307 void wholeprogramdevirt::setBeforeReturnValues( 308 MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore, 309 unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 310 if (BitWidth == 1) 311 OffsetByte = -(AllocBefore / 8 + 1); 312 else 313 OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8); 314 OffsetBit = AllocBefore % 8; 315 316 for (VirtualCallTarget &Target : Targets) { 317 if (BitWidth == 1) 318 Target.setBeforeBit(AllocBefore); 319 else 320 Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8); 321 } 322 } 323 324 void wholeprogramdevirt::setAfterReturnValues( 325 MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter, 326 unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) { 327 if (BitWidth == 1) 328 OffsetByte = AllocAfter / 8; 329 else 330 OffsetByte = (AllocAfter + 7) / 8; 331 OffsetBit = AllocAfter % 8; 332 333 for (VirtualCallTarget &Target : Targets) { 334 if (BitWidth == 1) 335 Target.setAfterBit(AllocAfter); 336 else 337 Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8); 338 } 339 } 340 341 VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM) 342 : Fn(Fn), TM(TM), 343 IsBigEndian(Fn->getDataLayout().isBigEndian()), 344 WasDevirt(false) {} 345 346 namespace { 347 348 // Tracks the number of devirted calls in the IR transformation. 349 static unsigned NumDevirtCalls = 0; 350 351 // A slot in a set of virtual tables. The TypeID identifies the set of virtual 352 // tables, and the ByteOffset is the offset in bytes from the address point to 353 // the virtual function pointer. 354 struct VTableSlot { 355 Metadata *TypeID; 356 uint64_t ByteOffset; 357 }; 358 359 } // end anonymous namespace 360 361 namespace llvm { 362 363 template <> struct DenseMapInfo<VTableSlot> { 364 static VTableSlot getEmptyKey() { 365 return {DenseMapInfo<Metadata *>::getEmptyKey(), 366 DenseMapInfo<uint64_t>::getEmptyKey()}; 367 } 368 static VTableSlot getTombstoneKey() { 369 return {DenseMapInfo<Metadata *>::getTombstoneKey(), 370 DenseMapInfo<uint64_t>::getTombstoneKey()}; 371 } 372 static unsigned getHashValue(const VTableSlot &I) { 373 return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^ 374 DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); 375 } 376 static bool isEqual(const VTableSlot &LHS, 377 const VTableSlot &RHS) { 378 return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; 379 } 380 }; 381 382 template <> struct DenseMapInfo<VTableSlotSummary> { 383 static VTableSlotSummary getEmptyKey() { 384 return {DenseMapInfo<StringRef>::getEmptyKey(), 385 DenseMapInfo<uint64_t>::getEmptyKey()}; 386 } 387 static VTableSlotSummary getTombstoneKey() { 388 return {DenseMapInfo<StringRef>::getTombstoneKey(), 389 DenseMapInfo<uint64_t>::getTombstoneKey()}; 390 } 391 static unsigned getHashValue(const VTableSlotSummary &I) { 392 return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^ 393 DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset); 394 } 395 static bool isEqual(const VTableSlotSummary &LHS, 396 const VTableSlotSummary &RHS) { 397 return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset; 398 } 399 }; 400 401 } // end namespace llvm 402 403 // Returns true if the function must be unreachable based on ValueInfo. 404 // 405 // In particular, identifies a function as unreachable in the following 406 // conditions 407 // 1) All summaries are live. 408 // 2) All function summaries indicate it's unreachable 409 // 3) There is no non-function with the same GUID (which is rare) 410 static bool mustBeUnreachableFunction(ValueInfo TheFnVI) { 411 if (WholeProgramDevirtKeepUnreachableFunction) 412 return false; 413 414 if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) { 415 // Returns false if ValueInfo is absent, or the summary list is empty 416 // (e.g., function declarations). 417 return false; 418 } 419 420 for (const auto &Summary : TheFnVI.getSummaryList()) { 421 // Conservatively returns false if any non-live functions are seen. 422 // In general either all summaries should be live or all should be dead. 423 if (!Summary->isLive()) 424 return false; 425 if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) { 426 if (!FS->fflags().MustBeUnreachable) 427 return false; 428 } 429 // Be conservative if a non-function has the same GUID (which is rare). 430 else 431 return false; 432 } 433 // All function summaries are live and all of them agree that the function is 434 // unreachble. 435 return true; 436 } 437 438 namespace { 439 // A virtual call site. VTable is the loaded virtual table pointer, and CS is 440 // the indirect virtual call. 441 struct VirtualCallSite { 442 Value *VTable = nullptr; 443 CallBase &CB; 444 445 // If non-null, this field points to the associated unsafe use count stored in 446 // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description 447 // of that field for details. 448 unsigned *NumUnsafeUses = nullptr; 449 450 void 451 emitRemark(const StringRef OptName, const StringRef TargetName, 452 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { 453 Function *F = CB.getCaller(); 454 DebugLoc DLoc = CB.getDebugLoc(); 455 BasicBlock *Block = CB.getParent(); 456 457 using namespace ore; 458 OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) 459 << NV("Optimization", OptName) 460 << ": devirtualized a call to " 461 << NV("FunctionName", TargetName)); 462 } 463 464 void replaceAndErase( 465 const StringRef OptName, const StringRef TargetName, bool RemarksEnabled, 466 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 467 Value *New) { 468 if (RemarksEnabled) 469 emitRemark(OptName, TargetName, OREGetter); 470 CB.replaceAllUsesWith(New); 471 if (auto *II = dyn_cast<InvokeInst>(&CB)) { 472 BranchInst::Create(II->getNormalDest(), CB.getIterator()); 473 II->getUnwindDest()->removePredecessor(II->getParent()); 474 } 475 CB.eraseFromParent(); 476 // This use is no longer unsafe. 477 if (NumUnsafeUses) 478 --*NumUnsafeUses; 479 } 480 }; 481 482 // Call site information collected for a specific VTableSlot and possibly a list 483 // of constant integer arguments. The grouping by arguments is handled by the 484 // VTableSlotInfo class. 485 struct CallSiteInfo { 486 /// The set of call sites for this slot. Used during regular LTO and the 487 /// import phase of ThinLTO (as well as the export phase of ThinLTO for any 488 /// call sites that appear in the merged module itself); in each of these 489 /// cases we are directly operating on the call sites at the IR level. 490 std::vector<VirtualCallSite> CallSites; 491 492 /// Whether all call sites represented by this CallSiteInfo, including those 493 /// in summaries, have been devirtualized. This starts off as true because a 494 /// default constructed CallSiteInfo represents no call sites. 495 bool AllCallSitesDevirted = true; 496 497 // These fields are used during the export phase of ThinLTO and reflect 498 // information collected from function summaries. 499 500 /// Whether any function summary contains an llvm.assume(llvm.type.test) for 501 /// this slot. 502 bool SummaryHasTypeTestAssumeUsers = false; 503 504 /// CFI-specific: a vector containing the list of function summaries that use 505 /// the llvm.type.checked.load intrinsic and therefore will require 506 /// resolutions for llvm.type.test in order to implement CFI checks if 507 /// devirtualization was unsuccessful. If devirtualization was successful, the 508 /// pass will clear this vector by calling markDevirt(). If at the end of the 509 /// pass the vector is non-empty, we will need to add a use of llvm.type.test 510 /// to each of the function summaries in the vector. 511 std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; 512 std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers; 513 514 bool isExported() const { 515 return SummaryHasTypeTestAssumeUsers || 516 !SummaryTypeCheckedLoadUsers.empty(); 517 } 518 519 void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { 520 SummaryTypeCheckedLoadUsers.push_back(FS); 521 AllCallSitesDevirted = false; 522 } 523 524 void addSummaryTypeTestAssumeUser(FunctionSummary *FS) { 525 SummaryTypeTestAssumeUsers.push_back(FS); 526 SummaryHasTypeTestAssumeUsers = true; 527 AllCallSitesDevirted = false; 528 } 529 530 void markDevirt() { 531 AllCallSitesDevirted = true; 532 533 // As explained in the comment for SummaryTypeCheckedLoadUsers. 534 SummaryTypeCheckedLoadUsers.clear(); 535 } 536 }; 537 538 // Call site information collected for a specific VTableSlot. 539 struct VTableSlotInfo { 540 // The set of call sites which do not have all constant integer arguments 541 // (excluding "this"). 542 CallSiteInfo CSInfo; 543 544 // The set of call sites with all constant integer arguments (excluding 545 // "this"), grouped by argument list. 546 std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; 547 548 void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses); 549 550 private: 551 CallSiteInfo &findCallSiteInfo(CallBase &CB); 552 }; 553 554 CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) { 555 std::vector<uint64_t> Args; 556 auto *CBType = dyn_cast<IntegerType>(CB.getType()); 557 if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty()) 558 return CSInfo; 559 for (auto &&Arg : drop_begin(CB.args())) { 560 auto *CI = dyn_cast<ConstantInt>(Arg); 561 if (!CI || CI->getBitWidth() > 64) 562 return CSInfo; 563 Args.push_back(CI->getZExtValue()); 564 } 565 return ConstCSInfo[Args]; 566 } 567 568 void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB, 569 unsigned *NumUnsafeUses) { 570 auto &CSI = findCallSiteInfo(CB); 571 CSI.AllCallSitesDevirted = false; 572 CSI.CallSites.push_back({VTable, CB, NumUnsafeUses}); 573 } 574 575 struct DevirtModule { 576 Module &M; 577 function_ref<AAResults &(Function &)> AARGetter; 578 function_ref<DominatorTree &(Function &)> LookupDomTree; 579 580 ModuleSummaryIndex *ExportSummary; 581 const ModuleSummaryIndex *ImportSummary; 582 583 IntegerType *Int8Ty; 584 PointerType *Int8PtrTy; 585 IntegerType *Int32Ty; 586 IntegerType *Int64Ty; 587 IntegerType *IntPtrTy; 588 /// Sizeless array type, used for imported vtables. This provides a signal 589 /// to analyzers that these imports may alias, as they do for example 590 /// when multiple unique return values occur in the same vtable. 591 ArrayType *Int8Arr0Ty; 592 593 bool RemarksEnabled; 594 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter; 595 596 MapVector<VTableSlot, VTableSlotInfo> CallSlots; 597 598 // Calls that have already been optimized. We may add a call to multiple 599 // VTableSlotInfos if vtable loads are coalesced and need to make sure not to 600 // optimize a call more than once. 601 SmallPtrSet<CallBase *, 8> OptimizedCalls; 602 603 // Store calls that had their ptrauth bundle removed. They are to be deleted 604 // at the end of the optimization. 605 SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved; 606 607 // This map keeps track of the number of "unsafe" uses of a loaded function 608 // pointer. The key is the associated llvm.type.test intrinsic call generated 609 // by this pass. An unsafe use is one that calls the loaded function pointer 610 // directly. Every time we eliminate an unsafe use (for example, by 611 // devirtualizing it or by applying virtual constant propagation), we 612 // decrement the value stored in this map. If a value reaches zero, we can 613 // eliminate the type check by RAUWing the associated llvm.type.test call with 614 // true. 615 std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; 616 PatternList FunctionsToSkip; 617 618 DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, 619 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 620 function_ref<DominatorTree &(Function &)> LookupDomTree, 621 ModuleSummaryIndex *ExportSummary, 622 const ModuleSummaryIndex *ImportSummary) 623 : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), 624 ExportSummary(ExportSummary), ImportSummary(ImportSummary), 625 Int8Ty(Type::getInt8Ty(M.getContext())), 626 Int8PtrTy(PointerType::getUnqual(M.getContext())), 627 Int32Ty(Type::getInt32Ty(M.getContext())), 628 Int64Ty(Type::getInt64Ty(M.getContext())), 629 IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), 630 Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)), 631 RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) { 632 assert(!(ExportSummary && ImportSummary)); 633 FunctionsToSkip.init(SkipFunctionNames); 634 } 635 636 bool areRemarksEnabled(); 637 638 void 639 scanTypeTestUsers(Function *TypeTestFunc, 640 DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); 641 void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); 642 643 void buildTypeIdentifierMap( 644 std::vector<VTableBits> &Bits, 645 DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); 646 647 bool 648 tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, 649 const std::set<TypeMemberInfo> &TypeMemberInfos, 650 uint64_t ByteOffset, 651 ModuleSummaryIndex *ExportSummary); 652 653 void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, 654 bool &IsExported); 655 bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary, 656 MutableArrayRef<VirtualCallTarget> TargetsForSlot, 657 VTableSlotInfo &SlotInfo, 658 WholeProgramDevirtResolution *Res); 659 660 void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT, 661 bool &IsExported); 662 void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 663 VTableSlotInfo &SlotInfo, 664 WholeProgramDevirtResolution *Res, VTableSlot Slot); 665 666 bool tryEvaluateFunctionsWithArgs( 667 MutableArrayRef<VirtualCallTarget> TargetsForSlot, 668 ArrayRef<uint64_t> Args); 669 670 void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 671 uint64_t TheRetVal); 672 bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 673 CallSiteInfo &CSInfo, 674 WholeProgramDevirtResolution::ByArg *Res); 675 676 // Returns the global symbol name that is used to export information about the 677 // given vtable slot and list of arguments. 678 std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, 679 StringRef Name); 680 681 bool shouldExportConstantsAsAbsoluteSymbols(); 682 683 // This function is called during the export phase to create a symbol 684 // definition containing information about the given vtable slot and list of 685 // arguments. 686 void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, 687 Constant *C); 688 void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, 689 uint32_t Const, uint32_t &Storage); 690 691 // This function is called during the import phase to create a reference to 692 // the symbol definition created during the export phase. 693 Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 694 StringRef Name); 695 Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 696 StringRef Name, IntegerType *IntTy, 697 uint32_t Storage); 698 699 Constant *getMemberAddr(const TypeMemberInfo *M); 700 701 void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, 702 Constant *UniqueMemberAddr); 703 bool tryUniqueRetValOpt(unsigned BitWidth, 704 MutableArrayRef<VirtualCallTarget> TargetsForSlot, 705 CallSiteInfo &CSInfo, 706 WholeProgramDevirtResolution::ByArg *Res, 707 VTableSlot Slot, ArrayRef<uint64_t> Args); 708 709 void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 710 Constant *Byte, Constant *Bit); 711 bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, 712 VTableSlotInfo &SlotInfo, 713 WholeProgramDevirtResolution *Res, VTableSlot Slot); 714 715 void rebuildGlobal(VTableBits &B); 716 717 // Apply the summary resolution for Slot to all virtual calls in SlotInfo. 718 void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); 719 720 // If we were able to eliminate all unsafe uses for a type checked load, 721 // eliminate the associated type tests by replacing them with true. 722 void removeRedundantTypeTests(); 723 724 bool run(); 725 726 // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`. 727 // 728 // Caller guarantees that `ExportSummary` is not nullptr. 729 static ValueInfo lookUpFunctionValueInfo(Function *TheFn, 730 ModuleSummaryIndex *ExportSummary); 731 732 // Returns true if the function definition must be unreachable. 733 // 734 // Note if this helper function returns true, `F` is guaranteed 735 // to be unreachable; if it returns false, `F` might still 736 // be unreachable but not covered by this helper function. 737 // 738 // Implementation-wise, if function definition is present, IR is analyzed; if 739 // not, look up function flags from ExportSummary as a fallback. 740 static bool mustBeUnreachableFunction(Function *const F, 741 ModuleSummaryIndex *ExportSummary); 742 743 // Lower the module using the action and summary passed as command line 744 // arguments. For testing purposes only. 745 static bool 746 runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter, 747 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 748 function_ref<DominatorTree &(Function &)> LookupDomTree); 749 }; 750 751 struct DevirtIndex { 752 ModuleSummaryIndex &ExportSummary; 753 // The set in which to record GUIDs exported from their module by 754 // devirtualization, used by client to ensure they are not internalized. 755 std::set<GlobalValue::GUID> &ExportedGUIDs; 756 // A map in which to record the information necessary to locate the WPD 757 // resolution for local targets in case they are exported by cross module 758 // importing. 759 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap; 760 761 MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots; 762 763 PatternList FunctionsToSkip; 764 765 DevirtIndex( 766 ModuleSummaryIndex &ExportSummary, 767 std::set<GlobalValue::GUID> &ExportedGUIDs, 768 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) 769 : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs), 770 LocalWPDTargetsMap(LocalWPDTargetsMap) { 771 FunctionsToSkip.init(SkipFunctionNames); 772 } 773 774 bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot, 775 const TypeIdCompatibleVtableInfo TIdInfo, 776 uint64_t ByteOffset); 777 778 bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, 779 VTableSlotSummary &SlotSummary, 780 VTableSlotInfo &SlotInfo, 781 WholeProgramDevirtResolution *Res, 782 std::set<ValueInfo> &DevirtTargets); 783 784 void run(); 785 }; 786 } // end anonymous namespace 787 788 PreservedAnalyses WholeProgramDevirtPass::run(Module &M, 789 ModuleAnalysisManager &AM) { 790 auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 791 auto AARGetter = [&](Function &F) -> AAResults & { 792 return FAM.getResult<AAManager>(F); 793 }; 794 auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { 795 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); 796 }; 797 auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { 798 return FAM.getResult<DominatorTreeAnalysis>(F); 799 }; 800 if (UseCommandLine) { 801 if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) 802 return PreservedAnalyses::all(); 803 return PreservedAnalyses::none(); 804 } 805 if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary, 806 ImportSummary) 807 .run()) 808 return PreservedAnalyses::all(); 809 return PreservedAnalyses::none(); 810 } 811 812 // Enable whole program visibility if enabled by client (e.g. linker) or 813 // internal option, and not force disabled. 814 bool llvm::hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) { 815 return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) && 816 !DisableWholeProgramVisibility; 817 } 818 819 static bool 820 typeIDVisibleToRegularObj(StringRef TypeID, 821 function_ref<bool(StringRef)> IsVisibleToRegularObj) { 822 // TypeID for member function pointer type is an internal construct 823 // and won't exist in IsVisibleToRegularObj. The full TypeID 824 // will be present and participate in invalidation. 825 if (TypeID.ends_with(".virtual")) 826 return false; 827 828 // TypeID that doesn't start with Itanium mangling (_ZTS) will be 829 // non-externally visible types which cannot interact with 830 // external native files. See CodeGenModule::CreateMetadataIdentifierImpl. 831 if (!TypeID.consume_front("_ZTS")) 832 return false; 833 834 // TypeID is keyed off the type name symbol (_ZTS). However, the native 835 // object may not contain this symbol if it does not contain a key 836 // function for the base type and thus only contains a reference to the 837 // type info (_ZTI). To catch this case we query using the type info 838 // symbol corresponding to the TypeID. 839 std::string typeInfo = ("_ZTI" + TypeID).str(); 840 return IsVisibleToRegularObj(typeInfo); 841 } 842 843 static bool 844 skipUpdateDueToValidation(GlobalVariable &GV, 845 function_ref<bool(StringRef)> IsVisibleToRegularObj) { 846 SmallVector<MDNode *, 2> Types; 847 GV.getMetadata(LLVMContext::MD_type, Types); 848 849 for (auto Type : Types) 850 if (auto *TypeID = dyn_cast<MDString>(Type->getOperand(1).get())) 851 return typeIDVisibleToRegularObj(TypeID->getString(), 852 IsVisibleToRegularObj); 853 854 return false; 855 } 856 857 /// If whole program visibility asserted, then upgrade all public vcall 858 /// visibility metadata on vtable definitions to linkage unit visibility in 859 /// Module IR (for regular or hybrid LTO). 860 void llvm::updateVCallVisibilityInModule( 861 Module &M, bool WholeProgramVisibilityEnabledInLTO, 862 const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, 863 bool ValidateAllVtablesHaveTypeInfos, 864 function_ref<bool(StringRef)> IsVisibleToRegularObj) { 865 if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) 866 return; 867 for (GlobalVariable &GV : M.globals()) { 868 // Add linkage unit visibility to any variable with type metadata, which are 869 // the vtable definitions. We won't have an existing vcall_visibility 870 // metadata on vtable definitions with public visibility. 871 if (GV.hasMetadata(LLVMContext::MD_type) && 872 GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic && 873 // Don't upgrade the visibility for symbols exported to the dynamic 874 // linker, as we have no information on their eventual use. 875 !DynamicExportSymbols.count(GV.getGUID()) && 876 // With validation enabled, we want to exclude symbols visible to 877 // regular objects. Local symbols will be in this group due to the 878 // current implementation but those with VCallVisibilityTranslationUnit 879 // will have already been marked in clang so are unaffected. 880 !(ValidateAllVtablesHaveTypeInfos && 881 skipUpdateDueToValidation(GV, IsVisibleToRegularObj))) 882 GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit); 883 } 884 } 885 886 void llvm::updatePublicTypeTestCalls(Module &M, 887 bool WholeProgramVisibilityEnabledInLTO) { 888 Function *PublicTypeTestFunc = 889 Intrinsic::getDeclarationIfExists(&M, Intrinsic::public_type_test); 890 if (!PublicTypeTestFunc) 891 return; 892 if (hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) { 893 Function *TypeTestFunc = 894 Intrinsic::getOrInsertDeclaration(&M, Intrinsic::type_test); 895 for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) { 896 auto *CI = cast<CallInst>(U.getUser()); 897 auto *NewCI = CallInst::Create( 898 TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)}, {}, "", 899 CI->getIterator()); 900 CI->replaceAllUsesWith(NewCI); 901 CI->eraseFromParent(); 902 } 903 } else { 904 auto *True = ConstantInt::getTrue(M.getContext()); 905 for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) { 906 auto *CI = cast<CallInst>(U.getUser()); 907 CI->replaceAllUsesWith(True); 908 CI->eraseFromParent(); 909 } 910 } 911 } 912 913 /// Based on typeID string, get all associated vtable GUIDS that are 914 /// visible to regular objects. 915 void llvm::getVisibleToRegularObjVtableGUIDs( 916 ModuleSummaryIndex &Index, 917 DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols, 918 function_ref<bool(StringRef)> IsVisibleToRegularObj) { 919 for (const auto &typeID : Index.typeIdCompatibleVtableMap()) { 920 if (typeIDVisibleToRegularObj(typeID.first, IsVisibleToRegularObj)) 921 for (const TypeIdOffsetVtableInfo &P : typeID.second) 922 VisibleToRegularObjSymbols.insert(P.VTableVI.getGUID()); 923 } 924 } 925 926 /// If whole program visibility asserted, then upgrade all public vcall 927 /// visibility metadata on vtable definition summaries to linkage unit 928 /// visibility in Module summary index (for ThinLTO). 929 void llvm::updateVCallVisibilityInIndex( 930 ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO, 931 const DenseSet<GlobalValue::GUID> &DynamicExportSymbols, 932 const DenseSet<GlobalValue::GUID> &VisibleToRegularObjSymbols) { 933 if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) 934 return; 935 for (auto &P : Index) { 936 // Don't upgrade the visibility for symbols exported to the dynamic 937 // linker, as we have no information on their eventual use. 938 if (DynamicExportSymbols.count(P.first)) 939 continue; 940 for (auto &S : P.second.SummaryList) { 941 auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); 942 if (!GVar || 943 GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) 944 continue; 945 // With validation enabled, we want to exclude symbols visible to regular 946 // objects. Local symbols will be in this group due to the current 947 // implementation but those with VCallVisibilityTranslationUnit will have 948 // already been marked in clang so are unaffected. 949 if (VisibleToRegularObjSymbols.count(P.first)) 950 continue; 951 GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); 952 } 953 } 954 } 955 956 void llvm::runWholeProgramDevirtOnIndex( 957 ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs, 958 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { 959 DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run(); 960 } 961 962 void llvm::updateIndexWPDForExports( 963 ModuleSummaryIndex &Summary, 964 function_ref<bool(StringRef, ValueInfo)> isExported, 965 std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) { 966 for (auto &T : LocalWPDTargetsMap) { 967 auto &VI = T.first; 968 // This was enforced earlier during trySingleImplDevirt. 969 assert(VI.getSummaryList().size() == 1 && 970 "Devirt of local target has more than one copy"); 971 auto &S = VI.getSummaryList()[0]; 972 if (!isExported(S->modulePath(), VI)) 973 continue; 974 975 // It's been exported by a cross module import. 976 for (auto &SlotSummary : T.second) { 977 auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID); 978 assert(TIdSum); 979 auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset); 980 assert(WPDRes != TIdSum->WPDRes.end()); 981 WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal( 982 WPDRes->second.SingleImplName, 983 Summary.getModuleHash(S->modulePath())); 984 } 985 } 986 } 987 988 static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { 989 // Check that summary index contains regular LTO module when performing 990 // export to prevent occasional use of index from pure ThinLTO compilation 991 // (-fno-split-lto-module). This kind of summary index is passed to 992 // DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting. 993 const auto &ModPaths = Summary->modulePaths(); 994 if (ClSummaryAction != PassSummaryAction::Import && 995 !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName())) 996 return createStringError( 997 errc::invalid_argument, 998 "combined summary should contain Regular LTO module"); 999 return ErrorSuccess(); 1000 } 1001 1002 bool DevirtModule::runForTesting( 1003 Module &M, function_ref<AAResults &(Function &)> AARGetter, 1004 function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, 1005 function_ref<DominatorTree &(Function &)> LookupDomTree) { 1006 std::unique_ptr<ModuleSummaryIndex> Summary = 1007 std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false); 1008 1009 // Handle the command-line summary arguments. This code is for testing 1010 // purposes only, so we handle errors directly. 1011 if (!ClReadSummary.empty()) { 1012 ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + 1013 ": "); 1014 auto ReadSummaryFile = 1015 ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); 1016 if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr = 1017 getModuleSummaryIndex(*ReadSummaryFile)) { 1018 Summary = std::move(*SummaryOrErr); 1019 ExitOnErr(checkCombinedSummaryForTesting(Summary.get())); 1020 } else { 1021 // Try YAML if we've failed with bitcode. 1022 consumeError(SummaryOrErr.takeError()); 1023 yaml::Input In(ReadSummaryFile->getBuffer()); 1024 In >> *Summary; 1025 ExitOnErr(errorCodeToError(In.error())); 1026 } 1027 } 1028 1029 bool Changed = 1030 DevirtModule(M, AARGetter, OREGetter, LookupDomTree, 1031 ClSummaryAction == PassSummaryAction::Export ? Summary.get() 1032 : nullptr, 1033 ClSummaryAction == PassSummaryAction::Import ? Summary.get() 1034 : nullptr) 1035 .run(); 1036 1037 if (!ClWriteSummary.empty()) { 1038 ExitOnError ExitOnErr( 1039 "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); 1040 std::error_code EC; 1041 if (StringRef(ClWriteSummary).ends_with(".bc")) { 1042 raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None); 1043 ExitOnErr(errorCodeToError(EC)); 1044 writeIndexToFile(*Summary, OS); 1045 } else { 1046 raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF); 1047 ExitOnErr(errorCodeToError(EC)); 1048 yaml::Output Out(OS); 1049 Out << *Summary; 1050 } 1051 } 1052 1053 return Changed; 1054 } 1055 1056 void DevirtModule::buildTypeIdentifierMap( 1057 std::vector<VTableBits> &Bits, 1058 DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { 1059 DenseMap<GlobalVariable *, VTableBits *> GVToBits; 1060 Bits.reserve(M.global_size()); 1061 SmallVector<MDNode *, 2> Types; 1062 for (GlobalVariable &GV : M.globals()) { 1063 Types.clear(); 1064 GV.getMetadata(LLVMContext::MD_type, Types); 1065 if (GV.isDeclaration() || Types.empty()) 1066 continue; 1067 1068 VTableBits *&BitsPtr = GVToBits[&GV]; 1069 if (!BitsPtr) { 1070 Bits.emplace_back(); 1071 Bits.back().GV = &GV; 1072 Bits.back().ObjectSize = 1073 M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType()); 1074 BitsPtr = &Bits.back(); 1075 } 1076 1077 for (MDNode *Type : Types) { 1078 auto TypeID = Type->getOperand(1).get(); 1079 1080 uint64_t Offset = 1081 cast<ConstantInt>( 1082 cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) 1083 ->getZExtValue(); 1084 1085 TypeIdMap[TypeID].insert({BitsPtr, Offset}); 1086 } 1087 } 1088 } 1089 1090 bool DevirtModule::tryFindVirtualCallTargets( 1091 std::vector<VirtualCallTarget> &TargetsForSlot, 1092 const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset, 1093 ModuleSummaryIndex *ExportSummary) { 1094 for (const TypeMemberInfo &TM : TypeMemberInfos) { 1095 if (!TM.Bits->GV->isConstant()) 1096 return false; 1097 1098 // We cannot perform whole program devirtualization analysis on a vtable 1099 // with public LTO visibility. 1100 if (TM.Bits->GV->getVCallVisibility() == 1101 GlobalObject::VCallVisibilityPublic) 1102 return false; 1103 1104 Function *Fn = nullptr; 1105 Constant *C = nullptr; 1106 std::tie(Fn, C) = 1107 getFunctionAtVTableOffset(TM.Bits->GV, TM.Offset + ByteOffset, M); 1108 1109 if (!Fn) 1110 return false; 1111 1112 if (FunctionsToSkip.match(Fn->getName())) 1113 return false; 1114 1115 // We can disregard __cxa_pure_virtual as a possible call target, as 1116 // calls to pure virtuals are UB. 1117 if (Fn->getName() == "__cxa_pure_virtual") 1118 continue; 1119 1120 // We can disregard unreachable functions as possible call targets, as 1121 // unreachable functions shouldn't be called. 1122 if (mustBeUnreachableFunction(Fn, ExportSummary)) 1123 continue; 1124 1125 // Save the symbol used in the vtable to use as the devirtualization 1126 // target. 1127 auto GV = dyn_cast<GlobalValue>(C); 1128 assert(GV); 1129 TargetsForSlot.push_back({GV, &TM}); 1130 } 1131 1132 // Give up if we couldn't find any targets. 1133 return !TargetsForSlot.empty(); 1134 } 1135 1136 bool DevirtIndex::tryFindVirtualCallTargets( 1137 std::vector<ValueInfo> &TargetsForSlot, 1138 const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) { 1139 for (const TypeIdOffsetVtableInfo &P : TIdInfo) { 1140 // Find a representative copy of the vtable initializer. 1141 // We can have multiple available_externally, linkonce_odr and weak_odr 1142 // vtable initializers. We can also have multiple external vtable 1143 // initializers in the case of comdats, which we cannot check here. 1144 // The linker should give an error in this case. 1145 // 1146 // Also, handle the case of same-named local Vtables with the same path 1147 // and therefore the same GUID. This can happen if there isn't enough 1148 // distinguishing path when compiling the source file. In that case we 1149 // conservatively return false early. 1150 const GlobalVarSummary *VS = nullptr; 1151 bool LocalFound = false; 1152 for (const auto &S : P.VTableVI.getSummaryList()) { 1153 if (GlobalValue::isLocalLinkage(S->linkage())) { 1154 if (LocalFound) 1155 return false; 1156 LocalFound = true; 1157 } 1158 auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject()); 1159 if (!CurVS->vTableFuncs().empty() || 1160 // Previously clang did not attach the necessary type metadata to 1161 // available_externally vtables, in which case there would not 1162 // be any vtable functions listed in the summary and we need 1163 // to treat this case conservatively (in case the bitcode is old). 1164 // However, we will also not have any vtable functions in the 1165 // case of a pure virtual base class. In that case we do want 1166 // to set VS to avoid treating it conservatively. 1167 !GlobalValue::isAvailableExternallyLinkage(S->linkage())) { 1168 VS = CurVS; 1169 // We cannot perform whole program devirtualization analysis on a vtable 1170 // with public LTO visibility. 1171 if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) 1172 return false; 1173 } 1174 } 1175 // There will be no VS if all copies are available_externally having no 1176 // type metadata. In that case we can't safely perform WPD. 1177 if (!VS) 1178 return false; 1179 if (!VS->isLive()) 1180 continue; 1181 for (auto VTP : VS->vTableFuncs()) { 1182 if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset) 1183 continue; 1184 1185 if (mustBeUnreachableFunction(VTP.FuncVI)) 1186 continue; 1187 1188 TargetsForSlot.push_back(VTP.FuncVI); 1189 } 1190 } 1191 1192 // Give up if we couldn't find any targets. 1193 return !TargetsForSlot.empty(); 1194 } 1195 1196 void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, 1197 Constant *TheFn, bool &IsExported) { 1198 // Don't devirtualize function if we're told to skip it 1199 // in -wholeprogramdevirt-skip. 1200 if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName())) 1201 return; 1202 auto Apply = [&](CallSiteInfo &CSInfo) { 1203 for (auto &&VCallSite : CSInfo.CallSites) { 1204 if (!OptimizedCalls.insert(&VCallSite.CB).second) 1205 continue; 1206 1207 // Stop when the number of devirted calls reaches the cutoff. 1208 if (WholeProgramDevirtCutoff.getNumOccurrences() > 0 && 1209 NumDevirtCalls >= WholeProgramDevirtCutoff) 1210 return; 1211 1212 if (RemarksEnabled) 1213 VCallSite.emitRemark("single-impl", 1214 TheFn->stripPointerCasts()->getName(), OREGetter); 1215 NumSingleImpl++; 1216 NumDevirtCalls++; 1217 auto &CB = VCallSite.CB; 1218 assert(!CB.getCalledFunction() && "devirtualizing direct call?"); 1219 IRBuilder<> Builder(&CB); 1220 Value *Callee = 1221 Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType()); 1222 1223 // If trap checking is enabled, add support to compare the virtual 1224 // function pointer to the devirtualized target. In case of a mismatch, 1225 // perform a debug trap. 1226 if (DevirtCheckMode == WPDCheckMode::Trap) { 1227 auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee); 1228 Instruction *ThenTerm = SplitBlockAndInsertIfThen( 1229 Cond, &CB, /*Unreachable=*/false, 1230 MDBuilder(M.getContext()).createUnlikelyBranchWeights()); 1231 Builder.SetInsertPoint(ThenTerm); 1232 Function *TrapFn = 1233 Intrinsic::getOrInsertDeclaration(&M, Intrinsic::debugtrap); 1234 auto *CallTrap = Builder.CreateCall(TrapFn); 1235 CallTrap->setDebugLoc(CB.getDebugLoc()); 1236 } 1237 1238 // If fallback checking is enabled, add support to compare the virtual 1239 // function pointer to the devirtualized target. In case of a mismatch, 1240 // fall back to indirect call. 1241 if (DevirtCheckMode == WPDCheckMode::Fallback) { 1242 MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights(); 1243 // Version the indirect call site. If the called value is equal to the 1244 // given callee, 'NewInst' will be executed, otherwise the original call 1245 // site will be executed. 1246 CallBase &NewInst = versionCallSite(CB, Callee, Weights); 1247 NewInst.setCalledOperand(Callee); 1248 // Since the new call site is direct, we must clear metadata that 1249 // is only appropriate for indirect calls. This includes !prof and 1250 // !callees metadata. 1251 NewInst.setMetadata(LLVMContext::MD_prof, nullptr); 1252 NewInst.setMetadata(LLVMContext::MD_callees, nullptr); 1253 // Additionally, we should remove them from the fallback indirect call, 1254 // so that we don't attempt to perform indirect call promotion later. 1255 CB.setMetadata(LLVMContext::MD_prof, nullptr); 1256 CB.setMetadata(LLVMContext::MD_callees, nullptr); 1257 } 1258 1259 // In either trapping or non-checking mode, devirtualize original call. 1260 else { 1261 // Devirtualize unconditionally. 1262 CB.setCalledOperand(Callee); 1263 // Since the call site is now direct, we must clear metadata that 1264 // is only appropriate for indirect calls. This includes !prof and 1265 // !callees metadata. 1266 CB.setMetadata(LLVMContext::MD_prof, nullptr); 1267 CB.setMetadata(LLVMContext::MD_callees, nullptr); 1268 if (CB.getCalledOperand() && 1269 CB.getOperandBundle(LLVMContext::OB_ptrauth)) { 1270 auto *NewCS = CallBase::removeOperandBundle( 1271 &CB, LLVMContext::OB_ptrauth, CB.getIterator()); 1272 CB.replaceAllUsesWith(NewCS); 1273 // Schedule for deletion at the end of pass run. 1274 CallsWithPtrAuthBundleRemoved.push_back(&CB); 1275 } 1276 } 1277 1278 // This use is no longer unsafe. 1279 if (VCallSite.NumUnsafeUses) 1280 --*VCallSite.NumUnsafeUses; 1281 } 1282 if (CSInfo.isExported()) 1283 IsExported = true; 1284 CSInfo.markDevirt(); 1285 }; 1286 Apply(SlotInfo.CSInfo); 1287 for (auto &P : SlotInfo.ConstCSInfo) 1288 Apply(P.second); 1289 } 1290 1291 static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) { 1292 // We can't add calls if we haven't seen a definition 1293 if (Callee.getSummaryList().empty()) 1294 return false; 1295 1296 // Insert calls into the summary index so that the devirtualized targets 1297 // are eligible for import. 1298 // FIXME: Annotate type tests with hotness. For now, mark these as hot 1299 // to better ensure we have the opportunity to inline them. 1300 bool IsExported = false; 1301 auto &S = Callee.getSummaryList()[0]; 1302 CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* HasTailCall = */ false, 1303 /* RelBF = */ 0); 1304 auto AddCalls = [&](CallSiteInfo &CSInfo) { 1305 for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) { 1306 FS->addCall({Callee, CI}); 1307 IsExported |= S->modulePath() != FS->modulePath(); 1308 } 1309 for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) { 1310 FS->addCall({Callee, CI}); 1311 IsExported |= S->modulePath() != FS->modulePath(); 1312 } 1313 }; 1314 AddCalls(SlotInfo.CSInfo); 1315 for (auto &P : SlotInfo.ConstCSInfo) 1316 AddCalls(P.second); 1317 return IsExported; 1318 } 1319 1320 bool DevirtModule::trySingleImplDevirt( 1321 ModuleSummaryIndex *ExportSummary, 1322 MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 1323 WholeProgramDevirtResolution *Res) { 1324 // See if the program contains a single implementation of this virtual 1325 // function. 1326 auto *TheFn = TargetsForSlot[0].Fn; 1327 for (auto &&Target : TargetsForSlot) 1328 if (TheFn != Target.Fn) 1329 return false; 1330 1331 // If so, update each call site to call that implementation directly. 1332 if (RemarksEnabled || AreStatisticsEnabled()) 1333 TargetsForSlot[0].WasDevirt = true; 1334 1335 bool IsExported = false; 1336 applySingleImplDevirt(SlotInfo, TheFn, IsExported); 1337 if (!IsExported) 1338 return false; 1339 1340 // If the only implementation has local linkage, we must promote to external 1341 // to make it visible to thin LTO objects. We can only get here during the 1342 // ThinLTO export phase. 1343 if (TheFn->hasLocalLinkage()) { 1344 std::string NewName = (TheFn->getName() + ".llvm.merged").str(); 1345 1346 // Since we are renaming the function, any comdats with the same name must 1347 // also be renamed. This is required when targeting COFF, as the comdat name 1348 // must match one of the names of the symbols in the comdat. 1349 if (Comdat *C = TheFn->getComdat()) { 1350 if (C->getName() == TheFn->getName()) { 1351 Comdat *NewC = M.getOrInsertComdat(NewName); 1352 NewC->setSelectionKind(C->getSelectionKind()); 1353 for (GlobalObject &GO : M.global_objects()) 1354 if (GO.getComdat() == C) 1355 GO.setComdat(NewC); 1356 } 1357 } 1358 1359 TheFn->setLinkage(GlobalValue::ExternalLinkage); 1360 TheFn->setVisibility(GlobalValue::HiddenVisibility); 1361 TheFn->setName(NewName); 1362 } 1363 if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID())) 1364 // Any needed promotion of 'TheFn' has already been done during 1365 // LTO unit split, so we can ignore return value of AddCalls. 1366 AddCalls(SlotInfo, TheFnVI); 1367 1368 Res->TheKind = WholeProgramDevirtResolution::SingleImpl; 1369 Res->SingleImplName = std::string(TheFn->getName()); 1370 1371 return true; 1372 } 1373 1374 bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, 1375 VTableSlotSummary &SlotSummary, 1376 VTableSlotInfo &SlotInfo, 1377 WholeProgramDevirtResolution *Res, 1378 std::set<ValueInfo> &DevirtTargets) { 1379 // See if the program contains a single implementation of this virtual 1380 // function. 1381 auto TheFn = TargetsForSlot[0]; 1382 for (auto &&Target : TargetsForSlot) 1383 if (TheFn != Target) 1384 return false; 1385 1386 // Don't devirtualize if we don't have target definition. 1387 auto Size = TheFn.getSummaryList().size(); 1388 if (!Size) 1389 return false; 1390 1391 // Don't devirtualize function if we're told to skip it 1392 // in -wholeprogramdevirt-skip. 1393 if (FunctionsToSkip.match(TheFn.name())) 1394 return false; 1395 1396 // If the summary list contains multiple summaries where at least one is 1397 // a local, give up, as we won't know which (possibly promoted) name to use. 1398 for (const auto &S : TheFn.getSummaryList()) 1399 if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) 1400 return false; 1401 1402 // Collect functions devirtualized at least for one call site for stats. 1403 if (PrintSummaryDevirt || AreStatisticsEnabled()) 1404 DevirtTargets.insert(TheFn); 1405 1406 auto &S = TheFn.getSummaryList()[0]; 1407 bool IsExported = AddCalls(SlotInfo, TheFn); 1408 if (IsExported) 1409 ExportedGUIDs.insert(TheFn.getGUID()); 1410 1411 // Record in summary for use in devirtualization during the ThinLTO import 1412 // step. 1413 Res->TheKind = WholeProgramDevirtResolution::SingleImpl; 1414 if (GlobalValue::isLocalLinkage(S->linkage())) { 1415 if (IsExported) 1416 // If target is a local function and we are exporting it by 1417 // devirtualizing a call in another module, we need to record the 1418 // promoted name. 1419 Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal( 1420 TheFn.name(), ExportSummary.getModuleHash(S->modulePath())); 1421 else { 1422 LocalWPDTargetsMap[TheFn].push_back(SlotSummary); 1423 Res->SingleImplName = std::string(TheFn.name()); 1424 } 1425 } else 1426 Res->SingleImplName = std::string(TheFn.name()); 1427 1428 // Name will be empty if this thin link driven off of serialized combined 1429 // index (e.g. llvm-lto). However, WPD is not supported/invoked for the 1430 // legacy LTO API anyway. 1431 assert(!Res->SingleImplName.empty()); 1432 1433 return true; 1434 } 1435 1436 void DevirtModule::tryICallBranchFunnel( 1437 MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 1438 WholeProgramDevirtResolution *Res, VTableSlot Slot) { 1439 Triple T(M.getTargetTriple()); 1440 if (T.getArch() != Triple::x86_64) 1441 return; 1442 1443 if (TargetsForSlot.size() > ClThreshold) 1444 return; 1445 1446 bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted; 1447 if (!HasNonDevirt) 1448 for (auto &P : SlotInfo.ConstCSInfo) 1449 if (!P.second.AllCallSitesDevirted) { 1450 HasNonDevirt = true; 1451 break; 1452 } 1453 1454 if (!HasNonDevirt) 1455 return; 1456 1457 FunctionType *FT = 1458 FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true); 1459 Function *JT; 1460 if (isa<MDString>(Slot.TypeID)) { 1461 JT = Function::Create(FT, Function::ExternalLinkage, 1462 M.getDataLayout().getProgramAddressSpace(), 1463 getGlobalName(Slot, {}, "branch_funnel"), &M); 1464 JT->setVisibility(GlobalValue::HiddenVisibility); 1465 } else { 1466 JT = Function::Create(FT, Function::InternalLinkage, 1467 M.getDataLayout().getProgramAddressSpace(), 1468 "branch_funnel", &M); 1469 } 1470 JT->addParamAttr(0, Attribute::Nest); 1471 1472 std::vector<Value *> JTArgs; 1473 JTArgs.push_back(JT->arg_begin()); 1474 for (auto &T : TargetsForSlot) { 1475 JTArgs.push_back(getMemberAddr(T.TM)); 1476 JTArgs.push_back(T.Fn); 1477 } 1478 1479 BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr); 1480 Function *Intr = Intrinsic::getOrInsertDeclaration( 1481 &M, llvm::Intrinsic::icall_branch_funnel, {}); 1482 1483 auto *CI = CallInst::Create(Intr, JTArgs, "", BB); 1484 CI->setTailCallKind(CallInst::TCK_MustTail); 1485 ReturnInst::Create(M.getContext(), nullptr, BB); 1486 1487 bool IsExported = false; 1488 applyICallBranchFunnel(SlotInfo, JT, IsExported); 1489 if (IsExported) 1490 Res->TheKind = WholeProgramDevirtResolution::BranchFunnel; 1491 } 1492 1493 void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, 1494 Constant *JT, bool &IsExported) { 1495 auto Apply = [&](CallSiteInfo &CSInfo) { 1496 if (CSInfo.isExported()) 1497 IsExported = true; 1498 if (CSInfo.AllCallSitesDevirted) 1499 return; 1500 1501 std::map<CallBase *, CallBase *> CallBases; 1502 for (auto &&VCallSite : CSInfo.CallSites) { 1503 CallBase &CB = VCallSite.CB; 1504 1505 if (CallBases.find(&CB) != CallBases.end()) { 1506 // When finding devirtualizable calls, it's possible to find the same 1507 // vtable passed to multiple llvm.type.test or llvm.type.checked.load 1508 // calls, which can cause duplicate call sites to be recorded in 1509 // [Const]CallSites. If we've already found one of these 1510 // call instances, just ignore it. It will be replaced later. 1511 continue; 1512 } 1513 1514 // Jump tables are only profitable if the retpoline mitigation is enabled. 1515 Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features"); 1516 if (!FSAttr.isValid() || 1517 !FSAttr.getValueAsString().contains("+retpoline")) 1518 continue; 1519 1520 NumBranchFunnel++; 1521 if (RemarksEnabled) 1522 VCallSite.emitRemark("branch-funnel", 1523 JT->stripPointerCasts()->getName(), OREGetter); 1524 1525 // Pass the address of the vtable in the nest register, which is r10 on 1526 // x86_64. 1527 std::vector<Type *> NewArgs; 1528 NewArgs.push_back(Int8PtrTy); 1529 append_range(NewArgs, CB.getFunctionType()->params()); 1530 FunctionType *NewFT = 1531 FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs, 1532 CB.getFunctionType()->isVarArg()); 1533 IRBuilder<> IRB(&CB); 1534 std::vector<Value *> Args; 1535 Args.push_back(VCallSite.VTable); 1536 llvm::append_range(Args, CB.args()); 1537 1538 CallBase *NewCS = nullptr; 1539 if (isa<CallInst>(CB)) 1540 NewCS = IRB.CreateCall(NewFT, JT, Args); 1541 else 1542 NewCS = 1543 IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(), 1544 cast<InvokeInst>(CB).getUnwindDest(), Args); 1545 NewCS->setCallingConv(CB.getCallingConv()); 1546 1547 AttributeList Attrs = CB.getAttributes(); 1548 std::vector<AttributeSet> NewArgAttrs; 1549 NewArgAttrs.push_back(AttributeSet::get( 1550 M.getContext(), ArrayRef<Attribute>{Attribute::get( 1551 M.getContext(), Attribute::Nest)})); 1552 for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) 1553 NewArgAttrs.push_back(Attrs.getParamAttrs(I)); 1554 NewCS->setAttributes( 1555 AttributeList::get(M.getContext(), Attrs.getFnAttrs(), 1556 Attrs.getRetAttrs(), NewArgAttrs)); 1557 1558 CallBases[&CB] = NewCS; 1559 1560 // This use is no longer unsafe. 1561 if (VCallSite.NumUnsafeUses) 1562 --*VCallSite.NumUnsafeUses; 1563 } 1564 // Don't mark as devirtualized because there may be callers compiled without 1565 // retpoline mitigation, which would mean that they are lowered to 1566 // llvm.type.test and therefore require an llvm.type.test resolution for the 1567 // type identifier. 1568 1569 for (auto &[Old, New] : CallBases) { 1570 Old->replaceAllUsesWith(New); 1571 Old->eraseFromParent(); 1572 } 1573 }; 1574 Apply(SlotInfo.CSInfo); 1575 for (auto &P : SlotInfo.ConstCSInfo) 1576 Apply(P.second); 1577 } 1578 1579 bool DevirtModule::tryEvaluateFunctionsWithArgs( 1580 MutableArrayRef<VirtualCallTarget> TargetsForSlot, 1581 ArrayRef<uint64_t> Args) { 1582 // Evaluate each function and store the result in each target's RetVal 1583 // field. 1584 for (VirtualCallTarget &Target : TargetsForSlot) { 1585 // TODO: Skip for now if the vtable symbol was an alias to a function, 1586 // need to evaluate whether it would be correct to analyze the aliasee 1587 // function for this optimization. 1588 auto Fn = dyn_cast<Function>(Target.Fn); 1589 if (!Fn) 1590 return false; 1591 1592 if (Fn->arg_size() != Args.size() + 1) 1593 return false; 1594 1595 Evaluator Eval(M.getDataLayout(), nullptr); 1596 SmallVector<Constant *, 2> EvalArgs; 1597 EvalArgs.push_back( 1598 Constant::getNullValue(Fn->getFunctionType()->getParamType(0))); 1599 for (unsigned I = 0; I != Args.size(); ++I) { 1600 auto *ArgTy = 1601 dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1)); 1602 if (!ArgTy) 1603 return false; 1604 EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); 1605 } 1606 1607 Constant *RetVal; 1608 if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) || 1609 !isa<ConstantInt>(RetVal)) 1610 return false; 1611 Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue(); 1612 } 1613 return true; 1614 } 1615 1616 void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 1617 uint64_t TheRetVal) { 1618 for (auto Call : CSInfo.CallSites) { 1619 if (!OptimizedCalls.insert(&Call.CB).second) 1620 continue; 1621 NumUniformRetVal++; 1622 Call.replaceAndErase( 1623 "uniform-ret-val", FnName, RemarksEnabled, OREGetter, 1624 ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal)); 1625 } 1626 CSInfo.markDevirt(); 1627 } 1628 1629 bool DevirtModule::tryUniformRetValOpt( 1630 MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, 1631 WholeProgramDevirtResolution::ByArg *Res) { 1632 // Uniform return value optimization. If all functions return the same 1633 // constant, replace all calls with that constant. 1634 uint64_t TheRetVal = TargetsForSlot[0].RetVal; 1635 for (const VirtualCallTarget &Target : TargetsForSlot) 1636 if (Target.RetVal != TheRetVal) 1637 return false; 1638 1639 if (CSInfo.isExported()) { 1640 Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; 1641 Res->Info = TheRetVal; 1642 } 1643 1644 applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); 1645 if (RemarksEnabled || AreStatisticsEnabled()) 1646 for (auto &&Target : TargetsForSlot) 1647 Target.WasDevirt = true; 1648 return true; 1649 } 1650 1651 std::string DevirtModule::getGlobalName(VTableSlot Slot, 1652 ArrayRef<uint64_t> Args, 1653 StringRef Name) { 1654 std::string FullName = "__typeid_"; 1655 raw_string_ostream OS(FullName); 1656 OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; 1657 for (uint64_t Arg : Args) 1658 OS << '_' << Arg; 1659 OS << '_' << Name; 1660 return FullName; 1661 } 1662 1663 bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() { 1664 Triple T(M.getTargetTriple()); 1665 return T.isX86() && T.getObjectFormat() == Triple::ELF; 1666 } 1667 1668 void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 1669 StringRef Name, Constant *C) { 1670 GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, 1671 getGlobalName(Slot, Args, Name), C, &M); 1672 GA->setVisibility(GlobalValue::HiddenVisibility); 1673 } 1674 1675 void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 1676 StringRef Name, uint32_t Const, 1677 uint32_t &Storage) { 1678 if (shouldExportConstantsAsAbsoluteSymbols()) { 1679 exportGlobal( 1680 Slot, Args, Name, 1681 ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy)); 1682 return; 1683 } 1684 1685 Storage = Const; 1686 } 1687 1688 Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, 1689 StringRef Name) { 1690 Constant *C = 1691 M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty); 1692 auto *GV = dyn_cast<GlobalVariable>(C); 1693 if (GV) 1694 GV->setVisibility(GlobalValue::HiddenVisibility); 1695 return C; 1696 } 1697 1698 Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, 1699 StringRef Name, IntegerType *IntTy, 1700 uint32_t Storage) { 1701 if (!shouldExportConstantsAsAbsoluteSymbols()) 1702 return ConstantInt::get(IntTy, Storage); 1703 1704 Constant *C = importGlobal(Slot, Args, Name); 1705 auto *GV = cast<GlobalVariable>(C->stripPointerCasts()); 1706 C = ConstantExpr::getPtrToInt(C, IntTy); 1707 1708 // We only need to set metadata if the global is newly created, in which 1709 // case it would not have hidden visibility. 1710 if (GV->hasMetadata(LLVMContext::MD_absolute_symbol)) 1711 return C; 1712 1713 auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { 1714 auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); 1715 auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); 1716 GV->setMetadata(LLVMContext::MD_absolute_symbol, 1717 MDNode::get(M.getContext(), {MinC, MaxC})); 1718 }; 1719 unsigned AbsWidth = IntTy->getBitWidth(); 1720 if (AbsWidth == IntPtrTy->getBitWidth()) 1721 SetAbsRange(~0ull, ~0ull); // Full set. 1722 else 1723 SetAbsRange(0, 1ull << AbsWidth); 1724 return C; 1725 } 1726 1727 void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, 1728 bool IsOne, 1729 Constant *UniqueMemberAddr) { 1730 for (auto &&Call : CSInfo.CallSites) { 1731 if (!OptimizedCalls.insert(&Call.CB).second) 1732 continue; 1733 IRBuilder<> B(&Call.CB); 1734 Value *Cmp = 1735 B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, 1736 B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType())); 1737 Cmp = B.CreateZExt(Cmp, Call.CB.getType()); 1738 NumUniqueRetVal++; 1739 Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter, 1740 Cmp); 1741 } 1742 CSInfo.markDevirt(); 1743 } 1744 1745 Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { 1746 return ConstantExpr::getGetElementPtr(Int8Ty, M->Bits->GV, 1747 ConstantInt::get(Int64Ty, M->Offset)); 1748 } 1749 1750 bool DevirtModule::tryUniqueRetValOpt( 1751 unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, 1752 CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, 1753 VTableSlot Slot, ArrayRef<uint64_t> Args) { 1754 // IsOne controls whether we look for a 0 or a 1. 1755 auto tryUniqueRetValOptFor = [&](bool IsOne) { 1756 const TypeMemberInfo *UniqueMember = nullptr; 1757 for (const VirtualCallTarget &Target : TargetsForSlot) { 1758 if (Target.RetVal == (IsOne ? 1 : 0)) { 1759 if (UniqueMember) 1760 return false; 1761 UniqueMember = Target.TM; 1762 } 1763 } 1764 1765 // We should have found a unique member or bailed out by now. We already 1766 // checked for a uniform return value in tryUniformRetValOpt. 1767 assert(UniqueMember); 1768 1769 Constant *UniqueMemberAddr = getMemberAddr(UniqueMember); 1770 if (CSInfo.isExported()) { 1771 Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; 1772 Res->Info = IsOne; 1773 1774 exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr); 1775 } 1776 1777 // Replace each call with the comparison. 1778 applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, 1779 UniqueMemberAddr); 1780 1781 // Update devirtualization statistics for targets. 1782 if (RemarksEnabled || AreStatisticsEnabled()) 1783 for (auto &&Target : TargetsForSlot) 1784 Target.WasDevirt = true; 1785 1786 return true; 1787 }; 1788 1789 if (BitWidth == 1) { 1790 if (tryUniqueRetValOptFor(true)) 1791 return true; 1792 if (tryUniqueRetValOptFor(false)) 1793 return true; 1794 } 1795 return false; 1796 } 1797 1798 void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, 1799 Constant *Byte, Constant *Bit) { 1800 for (auto Call : CSInfo.CallSites) { 1801 if (!OptimizedCalls.insert(&Call.CB).second) 1802 continue; 1803 auto *RetType = cast<IntegerType>(Call.CB.getType()); 1804 IRBuilder<> B(&Call.CB); 1805 Value *Addr = B.CreatePtrAdd(Call.VTable, Byte); 1806 if (RetType->getBitWidth() == 1) { 1807 Value *Bits = B.CreateLoad(Int8Ty, Addr); 1808 Value *BitsAndBit = B.CreateAnd(Bits, Bit); 1809 auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); 1810 NumVirtConstProp1Bit++; 1811 Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, 1812 OREGetter, IsBitSet); 1813 } else { 1814 Value *Val = B.CreateLoad(RetType, Addr); 1815 NumVirtConstProp++; 1816 Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, 1817 OREGetter, Val); 1818 } 1819 } 1820 CSInfo.markDevirt(); 1821 } 1822 1823 bool DevirtModule::tryVirtualConstProp( 1824 MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, 1825 WholeProgramDevirtResolution *Res, VTableSlot Slot) { 1826 // TODO: Skip for now if the vtable symbol was an alias to a function, 1827 // need to evaluate whether it would be correct to analyze the aliasee 1828 // function for this optimization. 1829 auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn); 1830 if (!Fn) 1831 return false; 1832 // This only works if the function returns an integer. 1833 auto RetType = dyn_cast<IntegerType>(Fn->getReturnType()); 1834 if (!RetType) 1835 return false; 1836 unsigned BitWidth = RetType->getBitWidth(); 1837 if (BitWidth > 64) 1838 return false; 1839 1840 // Make sure that each function is defined, does not access memory, takes at 1841 // least one argument, does not use its first argument (which we assume is 1842 // 'this'), and has the same return type. 1843 // 1844 // Note that we test whether this copy of the function is readnone, rather 1845 // than testing function attributes, which must hold for any copy of the 1846 // function, even a less optimized version substituted at link time. This is 1847 // sound because the virtual constant propagation optimizations effectively 1848 // inline all implementations of the virtual function into each call site, 1849 // rather than using function attributes to perform local optimization. 1850 for (VirtualCallTarget &Target : TargetsForSlot) { 1851 // TODO: Skip for now if the vtable symbol was an alias to a function, 1852 // need to evaluate whether it would be correct to analyze the aliasee 1853 // function for this optimization. 1854 auto Fn = dyn_cast<Function>(Target.Fn); 1855 if (!Fn) 1856 return false; 1857 1858 if (Fn->isDeclaration() || 1859 !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn)) 1860 .doesNotAccessMemory() || 1861 Fn->arg_empty() || !Fn->arg_begin()->use_empty() || 1862 Fn->getReturnType() != RetType) 1863 return false; 1864 } 1865 1866 for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { 1867 if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) 1868 continue; 1869 1870 WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; 1871 if (Res) 1872 ResByArg = &Res->ResByArg[CSByConstantArg.first]; 1873 1874 if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg)) 1875 continue; 1876 1877 if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, 1878 ResByArg, Slot, CSByConstantArg.first)) 1879 continue; 1880 1881 // Find an allocation offset in bits in all vtables associated with the 1882 // type. 1883 uint64_t AllocBefore = 1884 findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth); 1885 uint64_t AllocAfter = 1886 findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth); 1887 1888 // Calculate the total amount of padding needed to store a value at both 1889 // ends of the object. 1890 uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0; 1891 for (auto &&Target : TargetsForSlot) { 1892 TotalPaddingBefore += std::max<int64_t>( 1893 (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0); 1894 TotalPaddingAfter += std::max<int64_t>( 1895 (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0); 1896 } 1897 1898 // If the amount of padding is too large, give up. 1899 // FIXME: do something smarter here. 1900 if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128) 1901 continue; 1902 1903 // Calculate the offset to the value as a (possibly negative) byte offset 1904 // and (if applicable) a bit offset, and store the values in the targets. 1905 int64_t OffsetByte; 1906 uint64_t OffsetBit; 1907 if (TotalPaddingBefore <= TotalPaddingAfter) 1908 setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte, 1909 OffsetBit); 1910 else 1911 setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, 1912 OffsetBit); 1913 1914 if (RemarksEnabled || AreStatisticsEnabled()) 1915 for (auto &&Target : TargetsForSlot) 1916 Target.WasDevirt = true; 1917 1918 1919 if (CSByConstantArg.second.isExported()) { 1920 ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; 1921 exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte, 1922 ResByArg->Byte); 1923 exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit, 1924 ResByArg->Bit); 1925 } 1926 1927 // Rewrite each call to a load from OffsetByte/OffsetBit. 1928 Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); 1929 Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); 1930 applyVirtualConstProp(CSByConstantArg.second, 1931 TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); 1932 } 1933 return true; 1934 } 1935 1936 void DevirtModule::rebuildGlobal(VTableBits &B) { 1937 if (B.Before.Bytes.empty() && B.After.Bytes.empty()) 1938 return; 1939 1940 // Align the before byte array to the global's minimum alignment so that we 1941 // don't break any alignment requirements on the global. 1942 Align Alignment = M.getDataLayout().getValueOrABITypeAlignment( 1943 B.GV->getAlign(), B.GV->getValueType()); 1944 B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment)); 1945 1946 // Before was stored in reverse order; flip it now. 1947 for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I) 1948 std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]); 1949 1950 // Build an anonymous global containing the before bytes, followed by the 1951 // original initializer, followed by the after bytes. 1952 auto NewInit = ConstantStruct::getAnon( 1953 {ConstantDataArray::get(M.getContext(), B.Before.Bytes), 1954 B.GV->getInitializer(), 1955 ConstantDataArray::get(M.getContext(), B.After.Bytes)}); 1956 auto NewGV = 1957 new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(), 1958 GlobalVariable::PrivateLinkage, NewInit, "", B.GV); 1959 NewGV->setSection(B.GV->getSection()); 1960 NewGV->setComdat(B.GV->getComdat()); 1961 NewGV->setAlignment(B.GV->getAlign()); 1962 1963 // Copy the original vtable's metadata to the anonymous global, adjusting 1964 // offsets as required. 1965 NewGV->copyMetadata(B.GV, B.Before.Bytes.size()); 1966 1967 // Build an alias named after the original global, pointing at the second 1968 // element (the original initializer). 1969 auto Alias = GlobalAlias::create( 1970 B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", 1971 ConstantExpr::getInBoundsGetElementPtr( 1972 NewInit->getType(), NewGV, 1973 ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), 1974 ConstantInt::get(Int32Ty, 1)}), 1975 &M); 1976 Alias->setVisibility(B.GV->getVisibility()); 1977 Alias->takeName(B.GV); 1978 1979 B.GV->replaceAllUsesWith(Alias); 1980 B.GV->eraseFromParent(); 1981 } 1982 1983 bool DevirtModule::areRemarksEnabled() { 1984 const auto &FL = M.getFunctionList(); 1985 for (const Function &Fn : FL) { 1986 if (Fn.empty()) 1987 continue; 1988 auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front()); 1989 return DI.isEnabled(); 1990 } 1991 return false; 1992 } 1993 1994 void DevirtModule::scanTypeTestUsers( 1995 Function *TypeTestFunc, 1996 DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { 1997 // Find all virtual calls via a virtual table pointer %p under an assumption 1998 // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p 1999 // points to a member of the type identifier %md. Group calls by (type ID, 2000 // offset) pair (effectively the identity of the virtual function) and store 2001 // to CallSlots. 2002 for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { 2003 auto *CI = dyn_cast<CallInst>(U.getUser()); 2004 if (!CI) 2005 continue; 2006 2007 // Search for virtual calls based on %p and add them to DevirtCalls. 2008 SmallVector<DevirtCallSite, 1> DevirtCalls; 2009 SmallVector<CallInst *, 1> Assumes; 2010 auto &DT = LookupDomTree(*CI->getFunction()); 2011 findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); 2012 2013 Metadata *TypeId = 2014 cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); 2015 // If we found any, add them to CallSlots. 2016 if (!Assumes.empty()) { 2017 Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); 2018 for (DevirtCallSite Call : DevirtCalls) 2019 CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr); 2020 } 2021 2022 auto RemoveTypeTestAssumes = [&]() { 2023 // We no longer need the assumes or the type test. 2024 for (auto *Assume : Assumes) 2025 Assume->eraseFromParent(); 2026 // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we 2027 // may use the vtable argument later. 2028 if (CI->use_empty()) 2029 CI->eraseFromParent(); 2030 }; 2031 2032 // At this point we could remove all type test assume sequences, as they 2033 // were originally inserted for WPD. However, we can keep these in the 2034 // code stream for later analysis (e.g. to help drive more efficient ICP 2035 // sequences). They will eventually be removed by a second LowerTypeTests 2036 // invocation that cleans them up. In order to do this correctly, the first 2037 // LowerTypeTests invocation needs to know that they have "Unknown" type 2038 // test resolution, so that they aren't treated as Unsat and lowered to 2039 // False, which will break any uses on assumes. Below we remove any type 2040 // test assumes that will not be treated as Unknown by LTT. 2041 2042 // The type test assumes will be treated by LTT as Unsat if the type id is 2043 // not used on a global (in which case it has no entry in the TypeIdMap). 2044 if (!TypeIdMap.count(TypeId)) 2045 RemoveTypeTestAssumes(); 2046 2047 // For ThinLTO importing, we need to remove the type test assumes if this is 2048 // an MDString type id without a corresponding TypeIdSummary. Any 2049 // non-MDString type ids are ignored and treated as Unknown by LTT, so their 2050 // type test assumes can be kept. If the MDString type id is missing a 2051 // TypeIdSummary (e.g. because there was no use on a vcall, preventing the 2052 // exporting phase of WPD from analyzing it), then it would be treated as 2053 // Unsat by LTT and we need to remove its type test assumes here. If not 2054 // used on a vcall we don't need them for later optimization use in any 2055 // case. 2056 else if (ImportSummary && isa<MDString>(TypeId)) { 2057 const TypeIdSummary *TidSummary = 2058 ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString()); 2059 if (!TidSummary) 2060 RemoveTypeTestAssumes(); 2061 else 2062 // If one was created it should not be Unsat, because if we reached here 2063 // the type id was used on a global. 2064 assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat); 2065 } 2066 } 2067 } 2068 2069 void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { 2070 Function *TypeTestFunc = 2071 Intrinsic::getOrInsertDeclaration(&M, Intrinsic::type_test); 2072 2073 for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) { 2074 auto *CI = dyn_cast<CallInst>(U.getUser()); 2075 if (!CI) 2076 continue; 2077 2078 Value *Ptr = CI->getArgOperand(0); 2079 Value *Offset = CI->getArgOperand(1); 2080 Value *TypeIdValue = CI->getArgOperand(2); 2081 Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata(); 2082 2083 SmallVector<DevirtCallSite, 1> DevirtCalls; 2084 SmallVector<Instruction *, 1> LoadedPtrs; 2085 SmallVector<Instruction *, 1> Preds; 2086 bool HasNonCallUses = false; 2087 auto &DT = LookupDomTree(*CI->getFunction()); 2088 findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, 2089 HasNonCallUses, CI, DT); 2090 2091 // Start by generating "pessimistic" code that explicitly loads the function 2092 // pointer from the vtable and performs the type check. If possible, we will 2093 // eliminate the load and the type check later. 2094 2095 // If possible, only generate the load at the point where it is used. 2096 // This helps avoid unnecessary spills. 2097 IRBuilder<> LoadB( 2098 (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI); 2099 2100 Value *LoadedValue = nullptr; 2101 if (TypeCheckedLoadFunc->getIntrinsicID() == 2102 Intrinsic::type_checked_load_relative) { 2103 Value *GEP = LoadB.CreatePtrAdd(Ptr, Offset); 2104 LoadedValue = LoadB.CreateLoad(Int32Ty, GEP); 2105 LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy); 2106 GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy); 2107 LoadedValue = LoadB.CreateAdd(GEP, LoadedValue); 2108 LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy); 2109 } else { 2110 Value *GEP = LoadB.CreatePtrAdd(Ptr, Offset); 2111 LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEP); 2112 } 2113 2114 for (Instruction *LoadedPtr : LoadedPtrs) { 2115 LoadedPtr->replaceAllUsesWith(LoadedValue); 2116 LoadedPtr->eraseFromParent(); 2117 } 2118 2119 // Likewise for the type test. 2120 IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI); 2121 CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue}); 2122 2123 for (Instruction *Pred : Preds) { 2124 Pred->replaceAllUsesWith(TypeTestCall); 2125 Pred->eraseFromParent(); 2126 } 2127 2128 // We have already erased any extractvalue instructions that refer to the 2129 // intrinsic call, but the intrinsic may have other non-extractvalue uses 2130 // (although this is unlikely). In that case, explicitly build a pair and 2131 // RAUW it. 2132 if (!CI->use_empty()) { 2133 Value *Pair = PoisonValue::get(CI->getType()); 2134 IRBuilder<> B(CI); 2135 Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); 2136 Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); 2137 CI->replaceAllUsesWith(Pair); 2138 } 2139 2140 // The number of unsafe uses is initially the number of uses. 2141 auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall]; 2142 NumUnsafeUses = DevirtCalls.size(); 2143 2144 // If the function pointer has a non-call user, we cannot eliminate the type 2145 // check, as one of those users may eventually call the pointer. Increment 2146 // the unsafe use count to make sure it cannot reach zero. 2147 if (HasNonCallUses) 2148 ++NumUnsafeUses; 2149 for (DevirtCallSite Call : DevirtCalls) { 2150 CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, 2151 &NumUnsafeUses); 2152 } 2153 2154 CI->eraseFromParent(); 2155 } 2156 } 2157 2158 void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { 2159 auto *TypeId = dyn_cast<MDString>(Slot.TypeID); 2160 if (!TypeId) 2161 return; 2162 const TypeIdSummary *TidSummary = 2163 ImportSummary->getTypeIdSummary(TypeId->getString()); 2164 if (!TidSummary) 2165 return; 2166 auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); 2167 if (ResI == TidSummary->WPDRes.end()) 2168 return; 2169 const WholeProgramDevirtResolution &Res = ResI->second; 2170 2171 if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { 2172 assert(!Res.SingleImplName.empty()); 2173 // The type of the function in the declaration is irrelevant because every 2174 // call site will cast it to the correct type. 2175 Constant *SingleImpl = 2176 cast<Constant>(M.getOrInsertFunction(Res.SingleImplName, 2177 Type::getVoidTy(M.getContext())) 2178 .getCallee()); 2179 2180 // This is the import phase so we should not be exporting anything. 2181 bool IsExported = false; 2182 applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); 2183 assert(!IsExported); 2184 } 2185 2186 for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { 2187 auto I = Res.ResByArg.find(CSByConstantArg.first); 2188 if (I == Res.ResByArg.end()) 2189 continue; 2190 auto &ResByArg = I->second; 2191 // FIXME: We should figure out what to do about the "function name" argument 2192 // to the apply* functions, as the function names are unavailable during the 2193 // importing phase. For now we just pass the empty string. This does not 2194 // impact correctness because the function names are just used for remarks. 2195 switch (ResByArg.TheKind) { 2196 case WholeProgramDevirtResolution::ByArg::UniformRetVal: 2197 applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); 2198 break; 2199 case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { 2200 Constant *UniqueMemberAddr = 2201 importGlobal(Slot, CSByConstantArg.first, "unique_member"); 2202 applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, 2203 UniqueMemberAddr); 2204 break; 2205 } 2206 case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { 2207 Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte", 2208 Int32Ty, ResByArg.Byte); 2209 Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty, 2210 ResByArg.Bit); 2211 applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); 2212 break; 2213 } 2214 default: 2215 break; 2216 } 2217 } 2218 2219 if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) { 2220 // The type of the function is irrelevant, because it's bitcast at calls 2221 // anyhow. 2222 Constant *JT = cast<Constant>( 2223 M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"), 2224 Type::getVoidTy(M.getContext())) 2225 .getCallee()); 2226 bool IsExported = false; 2227 applyICallBranchFunnel(SlotInfo, JT, IsExported); 2228 assert(!IsExported); 2229 } 2230 } 2231 2232 void DevirtModule::removeRedundantTypeTests() { 2233 auto True = ConstantInt::getTrue(M.getContext()); 2234 for (auto &&U : NumUnsafeUsesForTypeTest) { 2235 if (U.second == 0) { 2236 U.first->replaceAllUsesWith(True); 2237 U.first->eraseFromParent(); 2238 } 2239 } 2240 } 2241 2242 ValueInfo 2243 DevirtModule::lookUpFunctionValueInfo(Function *TheFn, 2244 ModuleSummaryIndex *ExportSummary) { 2245 assert((ExportSummary != nullptr) && 2246 "Caller guarantees ExportSummary is not nullptr"); 2247 2248 const auto TheFnGUID = TheFn->getGUID(); 2249 const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName()); 2250 // Look up ValueInfo with the GUID in the current linkage. 2251 ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID); 2252 // If no entry is found and GUID is different from GUID computed using 2253 // exported name, look up ValueInfo with the exported name unconditionally. 2254 // This is a fallback. 2255 // 2256 // The reason to have a fallback: 2257 // 1. LTO could enable global value internalization via 2258 // `enable-lto-internalization`. 2259 // 2. The GUID in ExportedSummary is computed using exported name. 2260 if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) { 2261 TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName); 2262 } 2263 return TheFnVI; 2264 } 2265 2266 bool DevirtModule::mustBeUnreachableFunction( 2267 Function *const F, ModuleSummaryIndex *ExportSummary) { 2268 if (WholeProgramDevirtKeepUnreachableFunction) 2269 return false; 2270 // First, learn unreachability by analyzing function IR. 2271 if (!F->isDeclaration()) { 2272 // A function must be unreachable if its entry block ends with an 2273 // 'unreachable'. 2274 return isa<UnreachableInst>(F->getEntryBlock().getTerminator()); 2275 } 2276 // Learn unreachability from ExportSummary if ExportSummary is present. 2277 return ExportSummary && 2278 ::mustBeUnreachableFunction( 2279 DevirtModule::lookUpFunctionValueInfo(F, ExportSummary)); 2280 } 2281 2282 bool DevirtModule::run() { 2283 // If only some of the modules were split, we cannot correctly perform 2284 // this transformation. We already checked for the presense of type tests 2285 // with partially split modules during the thin link, and would have emitted 2286 // an error if any were found, so here we can simply return. 2287 if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || 2288 (ImportSummary && ImportSummary->partiallySplitLTOUnits())) 2289 return false; 2290 2291 Function *TypeTestFunc = 2292 Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_test); 2293 Function *TypeCheckedLoadFunc = 2294 Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_checked_load); 2295 Function *TypeCheckedLoadRelativeFunc = Intrinsic::getDeclarationIfExists( 2296 &M, Intrinsic::type_checked_load_relative); 2297 Function *AssumeFunc = 2298 Intrinsic::getDeclarationIfExists(&M, Intrinsic::assume); 2299 2300 // Normally if there are no users of the devirtualization intrinsics in the 2301 // module, this pass has nothing to do. But if we are exporting, we also need 2302 // to handle any users that appear only in the function summaries. 2303 if (!ExportSummary && 2304 (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || 2305 AssumeFunc->use_empty()) && 2306 (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) && 2307 (!TypeCheckedLoadRelativeFunc || 2308 TypeCheckedLoadRelativeFunc->use_empty())) 2309 return false; 2310 2311 // Rebuild type metadata into a map for easy lookup. 2312 std::vector<VTableBits> Bits; 2313 DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; 2314 buildTypeIdentifierMap(Bits, TypeIdMap); 2315 2316 if (TypeTestFunc && AssumeFunc) 2317 scanTypeTestUsers(TypeTestFunc, TypeIdMap); 2318 2319 if (TypeCheckedLoadFunc) 2320 scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); 2321 2322 if (TypeCheckedLoadRelativeFunc) 2323 scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc); 2324 2325 if (ImportSummary) { 2326 for (auto &S : CallSlots) 2327 importResolution(S.first, S.second); 2328 2329 removeRedundantTypeTests(); 2330 2331 // We have lowered or deleted the type intrinsics, so we will no longer have 2332 // enough information to reason about the liveness of virtual function 2333 // pointers in GlobalDCE. 2334 for (GlobalVariable &GV : M.globals()) 2335 GV.eraseMetadata(LLVMContext::MD_vcall_visibility); 2336 2337 // The rest of the code is only necessary when exporting or during regular 2338 // LTO, so we are done. 2339 return true; 2340 } 2341 2342 if (TypeIdMap.empty()) 2343 return true; 2344 2345 // Collect information from summary about which calls to try to devirtualize. 2346 if (ExportSummary) { 2347 DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; 2348 for (auto &P : TypeIdMap) { 2349 if (auto *TypeId = dyn_cast<MDString>(P.first)) 2350 MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( 2351 TypeId); 2352 } 2353 2354 for (auto &P : *ExportSummary) { 2355 for (auto &S : P.second.SummaryList) { 2356 auto *FS = dyn_cast<FunctionSummary>(S.get()); 2357 if (!FS) 2358 continue; 2359 // FIXME: Only add live functions. 2360 for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { 2361 for (Metadata *MD : MetadataByGUID[VF.GUID]) { 2362 CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS); 2363 } 2364 } 2365 for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { 2366 for (Metadata *MD : MetadataByGUID[VF.GUID]) { 2367 CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); 2368 } 2369 } 2370 for (const FunctionSummary::ConstVCall &VC : 2371 FS->type_test_assume_const_vcalls()) { 2372 for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 2373 CallSlots[{MD, VC.VFunc.Offset}] 2374 .ConstCSInfo[VC.Args] 2375 .addSummaryTypeTestAssumeUser(FS); 2376 } 2377 } 2378 for (const FunctionSummary::ConstVCall &VC : 2379 FS->type_checked_load_const_vcalls()) { 2380 for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { 2381 CallSlots[{MD, VC.VFunc.Offset}] 2382 .ConstCSInfo[VC.Args] 2383 .addSummaryTypeCheckedLoadUser(FS); 2384 } 2385 } 2386 } 2387 } 2388 } 2389 2390 // For each (type, offset) pair: 2391 bool DidVirtualConstProp = false; 2392 std::map<std::string, GlobalValue *> DevirtTargets; 2393 for (auto &S : CallSlots) { 2394 // Search each of the members of the type identifier for the virtual 2395 // function implementation at offset S.first.ByteOffset, and add to 2396 // TargetsForSlot. 2397 std::vector<VirtualCallTarget> TargetsForSlot; 2398 WholeProgramDevirtResolution *Res = nullptr; 2399 const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID]; 2400 if (ExportSummary && isa<MDString>(S.first.TypeID) && 2401 TypeMemberInfos.size()) 2402 // For any type id used on a global's type metadata, create the type id 2403 // summary resolution regardless of whether we can devirtualize, so that 2404 // lower type tests knows the type id is not Unsat. If it was not used on 2405 // a global's type metadata, the TypeIdMap entry set will be empty, and 2406 // we don't want to create an entry (with the default Unknown type 2407 // resolution), which can prevent detection of the Unsat. 2408 Res = &ExportSummary 2409 ->getOrInsertTypeIdSummary( 2410 cast<MDString>(S.first.TypeID)->getString()) 2411 .WPDRes[S.first.ByteOffset]; 2412 if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos, 2413 S.first.ByteOffset, ExportSummary)) { 2414 2415 if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) { 2416 DidVirtualConstProp |= 2417 tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); 2418 2419 tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first); 2420 } 2421 2422 // Collect functions devirtualized at least for one call site for stats. 2423 if (RemarksEnabled || AreStatisticsEnabled()) 2424 for (const auto &T : TargetsForSlot) 2425 if (T.WasDevirt) 2426 DevirtTargets[std::string(T.Fn->getName())] = T.Fn; 2427 } 2428 2429 // CFI-specific: if we are exporting and any llvm.type.checked.load 2430 // intrinsics were *not* devirtualized, we need to add the resulting 2431 // llvm.type.test intrinsics to the function summaries so that the 2432 // LowerTypeTests pass will export them. 2433 if (ExportSummary && isa<MDString>(S.first.TypeID)) { 2434 auto GUID = 2435 GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); 2436 for (auto *FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) 2437 FS->addTypeTest(GUID); 2438 for (auto &CCS : S.second.ConstCSInfo) 2439 for (auto *FS : CCS.second.SummaryTypeCheckedLoadUsers) 2440 FS->addTypeTest(GUID); 2441 } 2442 } 2443 2444 if (RemarksEnabled) { 2445 // Generate remarks for each devirtualized function. 2446 for (const auto &DT : DevirtTargets) { 2447 GlobalValue *GV = DT.second; 2448 auto F = dyn_cast<Function>(GV); 2449 if (!F) { 2450 auto A = dyn_cast<GlobalAlias>(GV); 2451 assert(A && isa<Function>(A->getAliasee())); 2452 F = dyn_cast<Function>(A->getAliasee()); 2453 assert(F); 2454 } 2455 2456 using namespace ore; 2457 OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) 2458 << "devirtualized " 2459 << NV("FunctionName", DT.first)); 2460 } 2461 } 2462 2463 NumDevirtTargets += DevirtTargets.size(); 2464 2465 removeRedundantTypeTests(); 2466 2467 // Rebuild each global we touched as part of virtual constant propagation to 2468 // include the before and after bytes. 2469 if (DidVirtualConstProp) 2470 for (VTableBits &B : Bits) 2471 rebuildGlobal(B); 2472 2473 // We have lowered or deleted the type intrinsics, so we will no longer have 2474 // enough information to reason about the liveness of virtual function 2475 // pointers in GlobalDCE. 2476 for (GlobalVariable &GV : M.globals()) 2477 GV.eraseMetadata(LLVMContext::MD_vcall_visibility); 2478 2479 for (auto *CI : CallsWithPtrAuthBundleRemoved) 2480 CI->eraseFromParent(); 2481 2482 return true; 2483 } 2484 2485 void DevirtIndex::run() { 2486 if (ExportSummary.typeIdCompatibleVtableMap().empty()) 2487 return; 2488 2489 DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; 2490 for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) { 2491 NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first); 2492 // Create the type id summary resolution regardlness of whether we can 2493 // devirtualize, so that lower type tests knows the type id is used on 2494 // a global and not Unsat. We do this here rather than in the loop over the 2495 // CallSlots, since that handling will only see type tests that directly 2496 // feed assumes, and we would miss any that aren't currently handled by WPD 2497 // (such as type tests that feed assumes via phis). 2498 ExportSummary.getOrInsertTypeIdSummary(P.first); 2499 } 2500 2501 // Collect information from summary about which calls to try to devirtualize. 2502 for (auto &P : ExportSummary) { 2503 for (auto &S : P.second.SummaryList) { 2504 auto *FS = dyn_cast<FunctionSummary>(S.get()); 2505 if (!FS) 2506 continue; 2507 // FIXME: Only add live functions. 2508 for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { 2509 for (StringRef Name : NameByGUID[VF.GUID]) { 2510 CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS); 2511 } 2512 } 2513 for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { 2514 for (StringRef Name : NameByGUID[VF.GUID]) { 2515 CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); 2516 } 2517 } 2518 for (const FunctionSummary::ConstVCall &VC : 2519 FS->type_test_assume_const_vcalls()) { 2520 for (StringRef Name : NameByGUID[VC.VFunc.GUID]) { 2521 CallSlots[{Name, VC.VFunc.Offset}] 2522 .ConstCSInfo[VC.Args] 2523 .addSummaryTypeTestAssumeUser(FS); 2524 } 2525 } 2526 for (const FunctionSummary::ConstVCall &VC : 2527 FS->type_checked_load_const_vcalls()) { 2528 for (StringRef Name : NameByGUID[VC.VFunc.GUID]) { 2529 CallSlots[{Name, VC.VFunc.Offset}] 2530 .ConstCSInfo[VC.Args] 2531 .addSummaryTypeCheckedLoadUser(FS); 2532 } 2533 } 2534 } 2535 } 2536 2537 std::set<ValueInfo> DevirtTargets; 2538 // For each (type, offset) pair: 2539 for (auto &S : CallSlots) { 2540 // Search each of the members of the type identifier for the virtual 2541 // function implementation at offset S.first.ByteOffset, and add to 2542 // TargetsForSlot. 2543 std::vector<ValueInfo> TargetsForSlot; 2544 auto TidSummary = ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID); 2545 assert(TidSummary); 2546 // The type id summary would have been created while building the NameByGUID 2547 // map earlier. 2548 WholeProgramDevirtResolution *Res = 2549 &ExportSummary.getTypeIdSummary(S.first.TypeID) 2550 ->WPDRes[S.first.ByteOffset]; 2551 if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary, 2552 S.first.ByteOffset)) { 2553 2554 if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res, 2555 DevirtTargets)) 2556 continue; 2557 } 2558 } 2559 2560 // Optionally have the thin link print message for each devirtualized 2561 // function. 2562 if (PrintSummaryDevirt) 2563 for (const auto &DT : DevirtTargets) 2564 errs() << "Devirtualized call to " << DT << "\n"; 2565 2566 NumDevirtTargets += DevirtTargets.size(); 2567 } 2568