//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass promotes "by reference" arguments to be "by value" arguments.  In
// practice, this means looking for internal functions that have pointer
// arguments.  If it can prove, through the use of alias analysis, that an
// argument is *only* loaded, then it can pass the value into the function
// instead of the address of the value.  This can cause recursive simplification
// of code and lead to the elimination of allocas (especially in C++ template
// code like the STL).
//
// This pass also handles aggregate arguments that are passed into a function,
// scalarizing them if the elements of the aggregate are only loaded.  Note that
// by default it refuses to scalarize aggregates which would require passing in
// more than three operands to the function, because passing thousands of
// operands for a large array or structure is unprofitable! This limit can be
// configured or disabled, however.
//
// Note that this transformation could also be done for arguments that are
// only stored to (returning the value instead), but this is not currently
// implemented.  This case would be best handled when and if LLVM begins
// supporting multiple return values from functions.
//
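// As an illustrative sketch (hypothetical IR, not from an actual test), a
// function such as:
//
//   define internal i32 @callee(ptr %p) {
//     %v = load i32, ptr %p
//     ret i32 %v
//   }
//   ...
//   %r = call i32 @callee(ptr %a)
//
// is rewritten so that the caller performs the load and passes the value
// (the pass names such values with a ".val" suffix):
//
//   define internal i32 @callee(i32 %p.0.val) {
//     ret i32 %p.0.val
//   }
//   ...
//   %a.val = load i32, ptr %a
//   %r = call i32 @callee(i32 %a.val)
//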
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/ArgumentPromotion.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using namespace llvm;

#define DEBUG_TYPE "argpromotion"

STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");

namespace {

struct ArgPart {
  Type *Ty;
  Align Alignment;
  /// A representative guaranteed-executed load or store instruction for use by
  /// metadata transfer.
  Instruction *MustExecInstr;
};

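/// A byte offset into the argument, paired with a description of the part
/// accessed at that offset.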
using OffsetAndArgPart = std::pair<int64_t, ArgPart>;

} // end anonymous namespace

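/// Create a pointer value that is offset from \p Ptr by \p Offset bytes. Note
/// that \p ResElemTy does not affect the computed address.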
static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL,
                            Value *Ptr, Type *ResElemTy, int64_t Offset) {
  if (Offset != 0) {
    APInt APOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset,
                   /*isSigned=*/true);
    Ptr = IRB.CreatePtrAdd(Ptr, IRB.getInt(APOffset));
  }
  return Ptr;
}

/// DoPromotion - This method actually performs the promotion of the specified
/// arguments, and returns the new function.  At this point, we know that it's
/// safe to do so.
static Function *
doPromotion(Function *F, FunctionAnalysisManager &FAM,
            const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>>
                &ArgsToPromote) {
  // Start by computing a new prototype for the function, which is the same as
  // the old function, but has modified arguments.
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type *> Params;

  // Keep track of the parameter attributes for the arguments that we are
  // *not* promoting. For the ones that we do promote, the parameter
  // attributes are lost.
  SmallVector<AttributeSet, 8> ArgAttrVec;
  // Mapping from old to new argument indices. -1 for promoted or removed
  // arguments.
  SmallVector<unsigned> NewArgIndices;
  AttributeList PAL = F->getAttributes();
  OptimizationRemarkEmitter ORE(F);

  // First, determine the new argument list
  unsigned ArgNo = 0, NewArgNo = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgNo) {
    if (!ArgsToPromote.count(&*I)) {
      // Unchanged argument
      Params.push_back(I->getType());
      ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo));
      NewArgIndices.push_back(NewArgNo++);
    } else if (I->use_empty()) {
      // Dead argument (such arguments are always marked as promotable)
      ++NumArgumentsDead;
      ORE.emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "ArgumentRemoved", F)
               << "eliminating argument " << ore::NV("ArgName", I->getName())
               << "(" << ore::NV("ArgIndex", ArgNo) << ")";
      });

      NewArgIndices.push_back((unsigned)-1);
    } else {
      const auto &ArgParts = ArgsToPromote.find(&*I)->second;
      for (const auto &Pair : ArgParts) {
        Params.push_back(Pair.second.Ty);
        ArgAttrVec.push_back(AttributeSet());
      }
      ++NumArgumentsPromoted;
      ORE.emit([&]() {
        return OptimizationRemark(DEBUG_TYPE, "ArgumentPromoted", F)
               << "promoting argument " << ore::NV("ArgName", I->getName())
               << "(" << ore::NV("ArgIndex", ArgNo) << ")"
               << " to pass by value";
      });

      NewArgIndices.push_back((unsigned)-1);
      NewArgNo += ArgParts.size();
    }
  }

  Type *RetTy = FTy->getReturnType();

  // Construct the new function type using the new arguments.
  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());

  // Create the new function body and insert it into the module.
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
                                  F->getName());
  NF->copyAttributesFrom(F);
  NF->copyMetadata(F, 0);
  NF->setIsNewDbgInfoFormat(F->IsNewDbgInfoFormat);

  // The new function will have the !dbg metadata copied from the original
  // function. The original function might not be deleted, and dbg metadata
  // needs to be unique, so we drop it from the original.
  F->setSubprogram(nullptr);

  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
                    << "From: " << *F);

  uint64_t LargestVectorWidth = 0;
  for (auto *I : Params)
    if (auto *VT = dyn_cast<llvm::VectorType>(I))
      LargestVectorWidth = std::max(
          LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinValue());

  // Recompute the parameter attributes list based on the new arguments for
  // the function.
  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(),
                                       PAL.getRetAttrs(), ArgAttrVec));

  // Remap argument indices in allocsize attribute.
  if (auto AllocSize = NF->getAttributes().getFnAttrs().getAllocSizeArgs()) {
    unsigned Arg1 = NewArgIndices[AllocSize->first];
    assert(Arg1 != (unsigned)-1 && "allocsize cannot be promoted argument");
    std::optional<unsigned> Arg2;
    if (AllocSize->second) {
      Arg2 = NewArgIndices[*AllocSize->second];
      assert(Arg2 != (unsigned)-1 && "allocsize cannot be promoted argument");
    }
    NF->addFnAttr(Attribute::getWithAllocSizeArgs(F->getContext(), Arg1, Arg2));
  }

  AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth);
  ArgAttrVec.clear();

  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  NF->takeName(F);

  // Loop over all the callers of the function, transforming the call sites to
  // pass in the loaded pointers.
  SmallVector<Value *, 16> Args;
  const DataLayout &DL = F->getDataLayout();
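  // Pointer operands of call sites whose corresponding callee argument is
  // dead; they are deleted (if trivially dead) after all call sites have
  // been rewritten.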
  SmallVector<WeakTrackingVH, 16> DeadArgs;

  while (!F->use_empty()) {
    CallBase &CB = cast<CallBase>(*F->user_back());
    assert(CB.getCalledFunction() == F);
    const AttributeList &CallPAL = CB.getAttributes();
    IRBuilder<NoFolder> IRB(&CB);

    // Loop over the operands, inserting GEPs and loads in the caller as
    // appropriate.
    auto *AI = CB.arg_begin();
    ArgNo = 0;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
         ++I, ++AI, ++ArgNo) {
      if (!ArgsToPromote.count(&*I)) {
        Args.push_back(*AI); // Unmodified argument
        ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
      } else if (!I->use_empty()) {
        Value *V = *AI;
        const auto &ArgParts = ArgsToPromote.find(&*I)->second;
        for (const auto &Pair : ArgParts) {
          LoadInst *LI = IRB.CreateAlignedLoad(
              Pair.second.Ty,
              createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first),
              Pair.second.Alignment, V->getName() + ".val");
          if (Pair.second.MustExecInstr) {
            LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata());
            LI->copyMetadata(*Pair.second.MustExecInstr,
                             {LLVMContext::MD_dereferenceable,
                              LLVMContext::MD_dereferenceable_or_null,
                              LLVMContext::MD_noundef,
                              LLVMContext::MD_nontemporal});
            // Only transfer poison-generating metadata if we also have
            // !noundef.
            // TODO: Without !noundef, we could merge this metadata across
            // all promoted loads.
            if (LI->hasMetadata(LLVMContext::MD_noundef))
              LI->copyMetadata(*Pair.second.MustExecInstr,
                               Metadata::PoisonGeneratingIDs);
          }
          Args.push_back(LI);
          ArgAttrVec.push_back(AttributeSet());
        }
      } else {
        assert(ArgsToPromote.count(&*I) && I->use_empty());
        DeadArgs.emplace_back(AI->get());
      }
    }

    // Push any varargs arguments on the list.
    for (; AI != CB.arg_end(); ++AI, ++ArgNo) {
      Args.push_back(*AI);
      ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo));
    }

    SmallVector<OperandBundleDef, 1> OpBundles;
    CB.getOperandBundlesAsDefs(OpBundles);

    CallBase *NewCS = nullptr;
    if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) {
      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                                 Args, OpBundles, "", CB.getIterator());
    } else {
      auto *NewCall =
          CallInst::Create(NF, Args, OpBundles, "", CB.getIterator());
      NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind());
      NewCS = NewCall;
    }
    NewCS->setCallingConv(CB.getCallingConv());
    NewCS->setAttributes(AttributeList::get(F->getContext(),
                                            CallPAL.getFnAttrs(),
                                            CallPAL.getRetAttrs(), ArgAttrVec));
    NewCS->copyMetadata(CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg});
    Args.clear();
    ArgAttrVec.clear();

    AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(),
                                                  LargestVectorWidth);

    if (!CB.use_empty()) {
      CB.replaceAllUsesWith(NewCS);
      NewCS->takeName(&CB);
    }

    // Finally, remove the old call from the program, reducing the use-count of
    // F.
    CB.eraseFromParent();
  }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadArgs);

  // Since we have now created the new function, splice the body of the old
  // function right into the new function, leaving the old rotting hulk of the
  // function empty.
  NF->splice(NF->begin(), F);
  // We will collect all the newly created allocas to promote them into
  // registers after the following loop.
  SmallVector<AllocaInst *, 4> Allocas;

  // Loop over the argument list, transferring uses of the old arguments over
  // to the new arguments, along with their names.
  Function::arg_iterator I2 = NF->arg_begin();
  for (Argument &Arg : F->args()) {
    if (!ArgsToPromote.count(&Arg)) {
      // If this is an unmodified argument, move the name and users over to the
      // new version.
      Arg.replaceAllUsesWith(&*I2);
      I2->takeName(&Arg);
      ++I2;
      continue;
    }

    // There potentially are metadata uses for things like llvm.dbg.value.
    // Replace them with poison, after handling the other regular uses.
    auto RauwPoisonMetadata = make_scope_exit(
        [&]() { Arg.replaceAllUsesWith(PoisonValue::get(Arg.getType())); });

    if (Arg.use_empty())
      continue;

    // Otherwise, if we promoted this argument, we have to create an alloca in
    // the callee for every promotable part and store each of the new incoming
    // arguments into the corresponding alloca, which gives the old code (in
    // particular the store instructions, if they are allowed) a chance to
    // work as before.
    assert(Arg.getType()->isPointerTy() &&
           "Only arguments with a pointer type are promotable");

    IRBuilder<NoFolder> IRB(&NF->begin()->front());

    // Add only the promoted parts, i.e. the parts recorded in ArgsToPromote.
    SmallDenseMap<int64_t, AllocaInst *> OffsetToAlloca;
    for (const auto &Pair : ArgsToPromote.find(&Arg)->second) {
      int64_t Offset = Pair.first;
      const ArgPart &Part = Pair.second;

      Argument *NewArg = I2++;
      NewArg->setName(Arg.getName() + "." + Twine(Offset) + ".val");

      AllocaInst *NewAlloca = IRB.CreateAlloca(
          Part.Ty, nullptr, Arg.getName() + "." + Twine(Offset) + ".allc");
      NewAlloca->setAlignment(Pair.second.Alignment);
      IRB.CreateAlignedStore(NewArg, NewAlloca, Pair.second.Alignment);

      // Remember the alloca so that the users can be retargeted to it.
      OffsetToAlloca.insert({Offset, NewAlloca});
    }

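    // Map a pointer that is a constant offset from Arg back to the alloca
    // created for that offset.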
    auto GetAlloca = [&](Value *Ptr) {
      APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
      Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                   /* AllowNonInbounds */ true);
      assert(Ptr == &Arg && "Not constant offset from arg?");
      return OffsetToAlloca.lookup(Offset.getSExtValue());
    };

    // Clean up the dead instructions (the GEPs and bitcasts between the
    // original argument and its end users, i.e. the loads and stores), and
    // retarget every user to the newly created alloca.
    SmallVector<Value *, 16> Worklist;
    SmallVector<Instruction *, 16> DeadInsts;
    append_range(Worklist, Arg.users());
    while (!Worklist.empty()) {
      Value *V = Worklist.pop_back_val();
      if (isa<GetElementPtrInst>(V)) {
        DeadInsts.push_back(cast<Instruction>(V));
        append_range(Worklist, V->users());
        continue;
      }

      if (auto *LI = dyn_cast<LoadInst>(V)) {
        Value *Ptr = LI->getPointerOperand();
        LI->setOperand(LoadInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      if (auto *SI = dyn_cast<StoreInst>(V)) {
        assert(!SI->isVolatile() && "Volatile operations can't be promoted.");
        Value *Ptr = SI->getPointerOperand();
        SI->setOperand(StoreInst::getPointerOperandIndex(), GetAlloca(Ptr));
        continue;
      }

      llvm_unreachable("Unexpected user");
    }

    for (Instruction *I : DeadInsts) {
      I->replaceAllUsesWith(PoisonValue::get(I->getType()));
      I->eraseFromParent();
    }

    // Collect the allocas for promotion
    for (const auto &Pair : OffsetToAlloca) {
      assert(isAllocaPromotable(Pair.second) &&
             "By design, only promotable allocas should be produced.");
      Allocas.push_back(Pair.second);
    }
  }

  LLVM_DEBUG(dbgs() << "ARG PROMOTION: " << Allocas.size()
                    << " alloca(s) are promotable by Mem2Reg\n");

  if (!Allocas.empty()) {
    // We can now promote the allocas into registers; our earlier checks have
    // ensured that PromoteMemToReg() will succeed.
    auto &DT = FAM.getResult<DominatorTreeAnalysis>(*NF);
    auto &AC = FAM.getResult<AssumptionAnalysis>(*NF);
    PromoteMemToReg(Allocas, DT, &AC);
  }

  return NF;
}

/// Return true if we can prove that all callers pass in a valid pointer for
/// the specified function argument.
static bool allCallersPassValidPointerForArgument(
    Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls,
    Align NeededAlign, uint64_t NeededDerefBytes) {
  Function *Callee = Arg->getParent();
  const DataLayout &DL = Callee->getDataLayout();
  APInt Bytes(64, NeededDerefBytes);

  // Check if the argument itself is marked dereferenceable and aligned.
  if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL))
    return true;

  // Look at all call sites of the function. At this point we know we only
  // have direct calls.
  return all_of(Callee->users(), [&](User *U) {
    CallBase &CB = cast<CallBase>(*U);
    // In case of functions with recursive calls, this check
    // (isDereferenceableAndAlignedPointer) will fail when it tries to look at
    // the first caller of this function. The caller may or may not have a
    // load; in case it doesn't load the pointer being passed, this check will
    // fail. So, it's safe to skip the check in case we know that we are
    // dealing with a recursive call. For example, consider the IR below.
    //
    // def fun(ptr %a) {
    //   ...
    //   %loadres = load i32, ptr %a, align 4
    //   %res = call i32 @fun(ptr %a)
    //   ...
    // }
    //
    // def bar(ptr %x) {
    //   ...
    //   %resbar = call i32 @fun(ptr %x)
    //   ...
    // }
    //
    // Since we record processed recursive calls, we check whether the current
    // CallBase has been processed before. If so, it is a recursive call, and
    // we can skip the check just for this call, so we return true.
    if (RecursiveCalls.contains(&CB))
      return true;

    return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()),
                                              NeededAlign, Bytes, DL);
  });
}

// Try to prove that all Calls to F do not modify the memory pointed to by Arg,
// using alias analysis local to each caller of F.
static bool isArgUnmodifiedByAllCalls(Argument *Arg,
                                      FunctionAnalysisManager &FAM) {
  for (User *U : Arg->getParent()->users()) {

    auto *Call = cast<CallBase>(U);

    MemoryLocation Loc =
        MemoryLocation::getForArgument(Call, Arg->getArgNo(), nullptr);

    AAResults &AAR = FAM.getResult<AAManager>(*Call->getFunction());
    // Bail as soon as we find a Call where Arg may be modified.
    if (isModSet(AAR.getModRefInfo(Call, Loc)))
      return false;
  }

  // All Users are Calls which do not modify the Arg.
  return true;
}

/// Determine whether this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
                         unsigned MaxElements, bool IsRecursive,
                         SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec,
                         FunctionAnalysisManager &FAM) {
  // Quick exit for unused arguments
  if (Arg->use_empty())
    return true;

  // We can only promote this argument if all of the uses are loads at known
  // offsets.
  //
  // Promoting the argument causes it to be loaded in the caller
  // unconditionally. This is only safe if we can prove that either the load
  // would have happened in the callee anyway (i.e., there is a load in the
  // entry block), or the pointer passed in at every call site is guaranteed
  // to be valid.
  // In the former case, invalid loads can happen, but would have happened
  // anyway; in the latter case, invalid loads won't happen. This prevents us
  // from introducing an invalid load that wouldn't have happened in the
  // original code.
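  //
  // As a sketch (hypothetical IR), a load in the entry block is promotable
  // on its own, while a conditional load additionally needs every call site
  // to pass a pointer that is known dereferenceable and sufficiently aligned:
  //
  //   define internal i32 @entry_load(ptr %p) {
  //     %v = load i32, ptr %p    ; executes whenever the callee does
  //     ret i32 %v
  //   }
  //
  //   define internal i32 @cond_load(ptr %p, i1 %c) {
  //     br i1 %c, label %yes, label %no
  //   yes:
  //     %v = load i32, ptr %p    ; only promotable with caller guarantees
  //     ret i32 %v
  //   no:
  //     ret i32 0
  //   }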

  SmallDenseMap<int64_t, ArgPart, 4> ArgParts;
  Align NeededAlign(1);
  uint64_t NeededDerefBytes = 0;

  // If this is a byval argument, we also allow store instructions. We only
  // handle arguments with a specified alignment this way; if the alignment
  // is unspecified, the actual alignment of the argument is target-specific.
  bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign();

  // An end user of a pointer argument is a load or store instruction. The
  // lambda returns std::nullopt if the load or store is not based on the
  // argument, true if we can promote the instruction, and false otherwise.
  auto HandleEndUser = [&](auto *I, Type *Ty,
                           bool GuaranteedToExecute) -> std::optional<bool> {
    // Don't promote volatile or atomic instructions.
    if (!I->isSimple())
      return false;

    Value *Ptr = I->getPointerOperand();
    APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0);
    Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset,
                                                 /* AllowNonInbounds */ true);
    if (Ptr != Arg)
      return std::nullopt;

    if (Offset.getSignificantBits() >= 64)
      return false;

    TypeSize Size = DL.getTypeStoreSize(Ty);
    // Don't try to promote scalable types.
    if (Size.isScalable())
      return false;

    // If this is a recursive function and one of the types is a pointer,
    // then promoting it might lead to recursive promotion.
    if (IsRecursive && Ty->isPointerTy())
      return false;

    int64_t Off = Offset.getSExtValue();
    auto Pair = ArgParts.try_emplace(
        Off, ArgPart{Ty, I->getAlign(), GuaranteedToExecute ? I : nullptr});
    ArgPart &Part = Pair.first->second;
    bool OffsetNotSeenBefore = Pair.second;

    // We limit promotion to only promoting up to a fixed number of elements of
    // the aggregate.
    if (MaxElements > 0 && ArgParts.size() > MaxElements) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "more than " << MaxElements << " parts\n");
      return false;
    }

    // For now, we only support loading/storing one specific type at a given
    // offset.
    if (Part.Ty != Ty) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "accessed as both " << *Part.Ty << " and " << *Ty
                        << " at offset " << Off << "\n");
      return false;
    }

    // If this instruction is not guaranteed to execute, and we haven't seen a
    // load or store at this offset before (or it had lower alignment), then we
    // need to remember that requirement.
    // Note that skipping instructions of previously seen offsets is only
    // correct because we only allow a single type for a given offset, which
    // also means that the number of accessed bytes will be the same.
    if (!GuaranteedToExecute &&
        (OffsetNotSeenBefore || Part.Alignment < I->getAlign())) {
      // We won't be able to prove dereferenceability for negative offsets.
      if (Off < 0)
        return false;

      // If the offset is not aligned, an aligned base pointer won't help.
      if (!isAligned(I->getAlign(), Off))
        return false;

      NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue());
      NeededAlign = std::max(NeededAlign, I->getAlign());
    }

    Part.Alignment = std::max(Part.Alignment, I->getAlign());
    return true;
  };

  // Look for loads and stores that are guaranteed to execute on entry.
  for (Instruction &I : Arg->getParent()->getEntryBlock()) {
    std::optional<bool> Res{};
    if (LoadInst *LI = dyn_cast<LoadInst>(&I))
      Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true);
    else if (StoreInst *SI = dyn_cast<StoreInst>(&I))
      Res = HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ true);
    if (Res && !*Res)
      return false;

    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
      break;
  }

  // Now look at all loads of the argument. Remember the load instructions
  // for the aliasing check below.
  SmallVector<const Use *, 16> Worklist;
  SmallPtrSet<const Use *, 16> Visited;
  SmallVector<LoadInst *, 16> Loads;
  SmallPtrSet<CallBase *, 4> RecursiveCalls;
  auto AppendUses = [&](const Value *V) {
    for (const Use &U : V->uses())
      if (Visited.insert(&U).second)
        Worklist.push_back(&U);
  };
  AppendUses(Arg);
  while (!Worklist.empty()) {
    const Use *U = Worklist.pop_back_val();
    Value *V = U->getUser();

    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
      if (!GEP->hasAllConstantIndices())
        return false;
      AppendUses(V);
      continue;
    }

    if (auto *LI = dyn_cast<LoadInst>(V)) {
      if (!*HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ false))
        return false;
      Loads.push_back(LI);
      continue;
    }

    // Stores are allowed for byval arguments
    auto *SI = dyn_cast<StoreInst>(V);
    if (AreStoresAllowed && SI &&
        U->getOperandNo() == StoreInst::getPointerOperandIndex()) {
      if (!*HandleEndUser(SI, SI->getValueOperand()->getType(),
                          /* GuaranteedToExecute */ false))
        return false;
      continue;
      // Only stores TO the argument are allowed; all other stores are
      // unknown users.
    }

    auto *CB = dyn_cast<CallBase>(V);
    Value *PtrArg = U->get();
    if (CB && CB->getCalledFunction() == CB->getFunction()) {
      if (PtrArg != Arg) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "pointer offset is not equal to zero\n");
        return false;
      }

      unsigned int ArgNo = Arg->getArgNo();
      if (U->getOperandNo() != ArgNo) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "arg position is different in callee\n");
        return false;
      }

      // We limit promotion to only promoting up to a fixed number of elements
      // of the aggregate.
      if (MaxElements > 0 && ArgParts.size() > MaxElements) {
        LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                          << "more than " << MaxElements << " parts\n");
        return false;
      }

      RecursiveCalls.insert(CB);
      continue;
    }
    // Unknown user.
    LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                      << "unknown user " << *V << "\n");
    return false;
  }

  if (NeededDerefBytes || NeededAlign > 1) {
    // Try to prove the required dereferenceability and alignment.
    if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign,
                                               NeededDerefBytes)) {
      LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                        << "not dereferenceable or aligned\n");
      return false;
    }
  }

  if (ArgParts.empty())
    return true; // No users, this is a dead argument.

  // Sort parts by offset.
  append_range(ArgPartsVec, ArgParts);
  sort(ArgPartsVec, llvm::less_first());

  // Make sure the parts are non-overlapping.
  int64_t Offset = ArgPartsVec[0].first;
  for (const auto &Pair : ArgPartsVec) {
    if (Pair.first < Offset)
      return false; // Overlap with previous part.

    Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty);
  }

  // If store instructions are allowed, the path from the entry of the
  // function to each load may not be free of instructions that potentially
  // invalidate the load; this is an admissible situation.
  if (AreStoresAllowed)
    return true;

  // Okay, now we know that the argument is only used by load instructions, and
  // it is safe to unconditionally perform all of them.

  // If we can determine that no call to the Function modifies the memory region
  // accessed through Arg, through alias analysis using actual arguments in the
  // callers, we know that it is guaranteed to be safe to promote the argument.
  if (isArgUnmodifiedByAllCalls(Arg, FAM))
    return true;

  // Otherwise, use alias analysis to check if the pointer is guaranteed to not
  // be modified from entry of the function to each of the load instructions.
  for (LoadInst *Load : Loads) {
    // Check to see if the load is invalidated from the start of the block to
    // the load itself.
    BasicBlock *BB = Load->getParent();

    MemoryLocation Loc = MemoryLocation::get(Load);
    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
      return false; // Pointer is invalidated!

    // Now check every path from the entry block to the load for transparency.
    // To do this, we perform a depth first search on the inverse CFG from the
    // loading block.
    for (BasicBlock *P : predecessors(BB)) {
      for (BasicBlock *TranspBB : inverse_depth_first(P))
        if (AAR.canBasicBlockModify(*TranspBB, Loc))
          return false;
    }
  }

  // If the path from the entry of the function to each load is free of
  // instructions that potentially invalidate the load, we can make the
  // transformation!
  return true;
}

/// Check if callers and callee agree on how promoted arguments would be
/// passed.
static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
                                  const TargetTransformInfo &TTI) {
  return all_of(F.uses(), [&](const Use &U) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    if (!CB)
      return false;

    const Function *Caller = CB->getCaller();
    const Function *Callee = CB->getCalledFunction();
    return TTI.areTypesABICompatible(Caller, Callee, Types);
  });
}

/// PromoteArguments - This method checks the specified function to see if there
/// are any promotable arguments and if it is safe to promote the function (for
/// example, all callers are direct).  If safe to promote some arguments, it
/// calls the DoPromotion method.
static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
                                  unsigned MaxElements, bool IsRecursive) {
  // Don't perform argument promotion for naked functions; otherwise we can
  // end up removing parameters that are seemingly 'not used', as they are
  // only referred to in the assembly.
  if (F->hasFnAttribute(Attribute::Naked))
    return nullptr;

  // Make sure that it is local to this module.
  if (!F->hasLocalLinkage())
    return nullptr;

  // Don't promote arguments for variadic functions. Adding, removing, or
  // changing non-pack parameters can change the classification of pack
  // parameters. Frontends encode that classification at the call site in the
  // IR, while in the callee the classification is determined dynamically based
  // on the number of registers consumed so far.
  if (F->isVarArg())
    return nullptr;

  // Don't transform functions that receive inallocas, as the transformation may
  // not be safe depending on calling convention.
  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
    return nullptr;

  // First check: see if there are any pointer arguments!  If not, quick exit.
  SmallVector<Argument *, 16> PointerArgs;
  for (Argument &I : F->args())
    if (I.getType()->isPointerTy())
      PointerArgs.push_back(&I);
  if (PointerArgs.empty())
    return nullptr;

  // Second check: make sure that all callers are direct callers.  We can't
  // transform functions that have indirect callers.  Also see if the function
  // is self-recursive.
  for (Use &U : F->uses()) {
    CallBase *CB = dyn_cast<CallBase>(U.getUser());
    // Must be a direct call.
    if (CB == nullptr || !CB->isCallee(&U) ||
        CB->getFunctionType() != F->getFunctionType())
      return nullptr;

    // Can't change signature of musttail callee
    if (CB->isMustTailCall())
      return nullptr;

    if (CB->getFunction() == F)
      IsRecursive = true;
  }

  // Can't change signature of musttail caller
  // FIXME: Support promoting whole chain of musttail functions
  for (BasicBlock &BB : *F)
    if (BB.getTerminatingMustTailCall())
      return nullptr;

  const DataLayout &DL = F->getDataLayout();
  auto &AAR = FAM.getResult<AAManager>(*F);
  const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);

  // Check to see which arguments are promotable.  If an argument is promotable,
  // add it to ArgsToPromote.
  DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote;
  unsigned NumArgsAfterPromote = F->getFunctionType()->getNumParams();
  for (Argument *PtrArg : PointerArgs) {
    // Replace sret attribute with noalias. This reduces register pressure by
    // avoiding a register copy.
    if (PtrArg->hasStructRetAttr()) {
      unsigned ArgNo = PtrArg->getArgNo();
      F->removeParamAttr(ArgNo, Attribute::StructRet);
      F->addParamAttr(ArgNo, Attribute::NoAlias);
      for (Use &U : F->uses()) {
        CallBase &CB = cast<CallBase>(*U.getUser());
        CB.removeParamAttr(ArgNo, Attribute::StructRet);
        CB.addParamAttr(ArgNo, Attribute::NoAlias);
      }
    }

    // Check whether we can promote the pointer to its value.
    SmallVector<OffsetAndArgPart, 4> ArgParts;

    if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts,
                     FAM)) {
      SmallVector<Type *, 4> Types;
      for (const auto &Pair : ArgParts)
        Types.push_back(Pair.second.Ty);

      if (areTypesABICompatible(Types, *F, TTI)) {
        NumArgsAfterPromote += ArgParts.size() - 1;
        ArgsToPromote.insert({PtrArg, std::move(ArgParts)});
      }
    }
  }

  // No promotable pointer arguments.
  if (ArgsToPromote.empty())
    return nullptr;

  if (NumArgsAfterPromote > TTI.getMaxNumArgs())
    return nullptr;

  return doPromotion(F, FAM, ArgsToPromote);
}

PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
                                             CGSCCAnalysisManager &AM,
                                             LazyCallGraph &CG,
                                             CGSCCUpdateResult &UR) {
  bool Changed = false, LocalChange;

  // Iterate until we stop promoting from this SCC.
  do {
    LocalChange = false;

    FunctionAnalysisManager &FAM =
        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

    bool IsRecursive = C.size() > 1;
    for (LazyCallGraph::Node &N : C) {
      Function &OldF = N.getFunction();
      Function *NewF = promoteArguments(&OldF, FAM, MaxElements, IsRecursive);
      if (!NewF)
        continue;
      LocalChange = true;

      // Directly substitute the functions in the call graph. Note that this
      // requires the old function to be completely dead and completely
      // replaced by the new function. It does no call graph updates, it merely
      // swaps out the particular function mapped to a particular node in the
      // graph.
      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
      FAM.clear(OldF, OldF.getName());
      OldF.eraseFromParent();

      PreservedAnalyses FuncPA;
      FuncPA.preserveSet<CFGAnalyses>();
      for (auto *U : NewF->users()) {
        auto *UserF = cast<CallBase>(U)->getFunction();
        FAM.invalidate(*UserF, FuncPA);
      }
    }

    Changed |= LocalChange;
  } while (LocalChange);

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  // We've cleared out analyses for deleted functions.
  PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
  // We've manually invalidated analyses for functions we've modified.
  PA.preserveSet<AllAnalysesOn<Function>>();
  return PA;
}