xref: /freebsd-src/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp (revision 0eae32dcef82f6f06de6419a0d623d7def0cc8f6)
1 //===- FunctionSpecialization.cpp - Function Specialization ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This specialises functions with constant parameters (e.g. functions,
10 // globals). Constant parameters like function pointers and constant globals
11 // are propagated to the callee by specializing the function.
12 //
13 // Current limitations:
14 // - It does not yet handle integer ranges.
15 // - Only 1 argument per function is specialised,
16 // - The cost-model could be further looked into,
17 // - We are not yet caching analysis results.
18 //
19 // Ideas:
20 // - With a function specialization attribute for arguments, we could have
21 //   a direct way to steer function specialization, avoiding the cost-model,
22 //   and thus control compile-times / code-size.
23 //
24 // Todos:
25 // - Specializing recursive functions relies on running the transformation a
26 //   number of times, which is controlled by option
27 //   `func-specialization-max-iters`. Thus, increasing this value and the
28 //   number of iterations, will linearly increase the number of times recursive
29 //   functions get specialized, see also the discussion in
30 //   https://reviews.llvm.org/D106426 for details. Perhaps there is a
31 //   compile-time friendlier way to control/limit the number of specialisations
32 //   for recursive functions.
// - Don't transform the function if no function specialization actually
//   happens.
35 //
36 //===----------------------------------------------------------------------===//
37 
38 #include "llvm/ADT/Statistic.h"
39 #include "llvm/Analysis/AssumptionCache.h"
40 #include "llvm/Analysis/CodeMetrics.h"
41 #include "llvm/Analysis/DomTreeUpdater.h"
42 #include "llvm/Analysis/InlineCost.h"
43 #include "llvm/Analysis/LoopInfo.h"
44 #include "llvm/Analysis/TargetLibraryInfo.h"
45 #include "llvm/Analysis/TargetTransformInfo.h"
46 #include "llvm/Transforms/Scalar/SCCP.h"
47 #include "llvm/Transforms/Utils/Cloning.h"
48 #include "llvm/Transforms/Utils/SizeOpts.h"
49 #include <cmath>
50 
51 using namespace llvm;
52 
53 #define DEBUG_TYPE "function-specialization"
54 
55 STATISTIC(NumFuncSpecialized, "Number of functions specialized");
56 
57 static cl::opt<bool> ForceFunctionSpecialization(
58     "force-function-specialization", cl::init(false), cl::Hidden,
59     cl::desc("Force function specialization for every call site with a "
60              "constant argument"));
61 
62 static cl::opt<unsigned> FuncSpecializationMaxIters(
63     "func-specialization-max-iters", cl::Hidden,
64     cl::desc("The maximum number of iterations function specialization is run"),
65     cl::init(1));
66 
67 static cl::opt<unsigned> MaxClonesThreshold(
68     "func-specialization-max-clones", cl::Hidden,
69     cl::desc("The maximum number of clones allowed for a single function "
70              "specialization"),
71     cl::init(3));
72 
73 static cl::opt<unsigned> SmallFunctionThreshold(
74     "func-specialization-size-threshold", cl::Hidden,
75     cl::desc("Don't specialize functions that have less than this theshold "
76              "number of instructions"),
77     cl::init(100));
78 
79 static cl::opt<unsigned>
80     AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
81                           cl::desc("Average loop iteration count cost"),
82                           cl::init(10));
83 
84 static cl::opt<bool> SpecializeOnAddresses(
85     "func-specialization-on-address", cl::init(false), cl::Hidden,
86     cl::desc("Enable function specialization on the address of global values"));
87 
88 // TODO: This needs checking to see the impact on compile-times, which is why
89 // this is off by default for now.
90 static cl::opt<bool> EnableSpecializationForLiteralConstant(
91     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
92     cl::desc("Enable specialization of functions that take a literal constant "
93              "as an argument."));
94 
95 namespace {
96 // Bookkeeping struct to pass data from the analysis and profitability phase
97 // to the actual transform helper functions.
98 struct ArgInfo {
99   Function *Fn;         // The function to perform specialisation on.
100   Argument *Arg;        // The Formal argument being analysed.
101   Constant *Const;      // A corresponding actual constant argument.
102   InstructionCost Gain; // Profitability: Gain = Bonus - Cost.
103 
104   // Flag if this will be a partial specialization, in which case we will need
105   // to keep the original function around in addition to the added
106   // specializations.
107   bool Partial = false;
108 
109   ArgInfo(Function *F, Argument *A, Constant *C, InstructionCost G)
110       : Fn(F), Arg(A), Const(C), Gain(G){};
111 };
112 } // Anonymous namespace
113 
// Shorthands for the function and constant lists that the helpers in this
// file take by reference and append to.
using FuncList = SmallVectorImpl<Function *>;
using ConstList = SmallVectorImpl<Constant *>;
116 
117 // Helper to check if \p LV is either a constant or a constant
118 // range with a single element. This should cover exactly the same cases as the
119 // old ValueLatticeElement::isConstant() and is intended to be used in the
120 // transition to ValueLatticeElement.
121 static bool isConstant(const ValueLatticeElement &LV) {
122   return LV.isConstant() ||
123          (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
124 }
125 
126 // Helper to check if \p LV is either overdefined or a constant int.
127 static bool isOverdefined(const ValueLatticeElement &LV) {
128   return !LV.isUnknownOrUndef() && !isConstant(LV);
129 }
130 
131 static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
132   Value *StoreValue = nullptr;
133   for (auto *User : Alloca->users()) {
134     // We can't use llvm::isAllocaPromotable() as that would fail because of
135     // the usage in the CallInst, which is what we check here.
136     if (User == Call)
137       continue;
138     if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
139       if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
140         return nullptr;
141       continue;
142     }
143 
144     if (auto *Store = dyn_cast<StoreInst>(User)) {
145       // This is a duplicate store, bail out.
146       if (StoreValue || Store->isVolatile())
147         return nullptr;
148       StoreValue = Store->getValueOperand();
149       continue;
150     }
151     // Bail if there is any other unknown usage.
152     return nullptr;
153   }
154   return dyn_cast_or_null<Constant>(StoreValue);
155 }
156 
157 // A constant stack value is an AllocaInst that has a single constant
158 // value stored to it. Return this constant if such an alloca stack value
159 // is a function argument.
160 static Constant *getConstantStackValue(CallInst *Call, Value *Val,
161                                        SCCPSolver &Solver) {
162   if (!Val)
163     return nullptr;
164   Val = Val->stripPointerCasts();
165   if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
166     return ConstVal;
167   auto *Alloca = dyn_cast<AllocaInst>(Val);
168   if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
169     return nullptr;
170   return getPromotableAlloca(Alloca, Call);
171 }
172 
173 // To support specializing recursive functions, it is important to propagate
174 // constant arguments because after a first iteration of specialisation, a
175 // reduced example may look like this:
176 //
177 //     define internal void @RecursiveFn(i32* arg1) {
178 //       %temp = alloca i32, align 4
179 //       store i32 2 i32* %temp, align 4
180 //       call void @RecursiveFn.1(i32* nonnull %temp)
181 //       ret void
182 //     }
183 //
184 // Before a next iteration, we need to propagate the constant like so
185 // which allows further specialization in next iterations.
186 //
187 //     @funcspec.arg = internal constant i32 2
188 //
189 //     define internal void @someFunc(i32* arg1) {
190 //       call void @otherFunc(i32* nonnull @funcspec.arg)
191 //       ret void
192 //     }
193 //
194 static void constantArgPropagation(FuncList &WorkList,
195                                    Module &M, SCCPSolver &Solver) {
196   // Iterate over the argument tracked functions see if there
197   // are any new constant values for the call instruction via
198   // stack variables.
199   for (auto *F : WorkList) {
200     // TODO: Generalize for any read only arguments.
201     if (F->arg_size() != 1)
202       continue;
203 
204     auto &Arg = *F->arg_begin();
205     if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
206       continue;
207 
208     for (auto *User : F->users()) {
209       auto *Call = dyn_cast<CallInst>(User);
210       if (!Call)
211         break;
212       auto *ArgOp = Call->getArgOperand(0);
213       auto *ArgOpType = ArgOp->getType();
214       auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
215       if (!ConstVal)
216         break;
217 
218       Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
219                                      GlobalValue::InternalLinkage, ConstVal,
220                                      "funcspec.arg");
221 
222       if (ArgOpType != ConstVal->getType())
223         GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
224 
225       Call->setArgOperand(0, GV);
226 
227       // Add the changed CallInst to Solver Worklist
228       Solver.visitCall(*Call);
229     }
230   }
231 }
232 
233 // ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
234 // interfere with the constantArgPropagation optimization.
235 static void removeSSACopy(Function &F) {
236   for (BasicBlock &BB : F) {
237     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
238       auto *II = dyn_cast<IntrinsicInst>(&Inst);
239       if (!II)
240         continue;
241       if (II->getIntrinsicID() != Intrinsic::ssa_copy)
242         continue;
243       Inst.replaceAllUsesWith(II->getOperand(0));
244       Inst.eraseFromParent();
245     }
246   }
247 }
248 
249 static void removeSSACopy(Module &M) {
250   for (Function &F : M)
251     removeSSACopy(F);
252 }
253 
254 namespace {
255 class FunctionSpecializer {
256 
257   /// The IPSCCP Solver.
258   SCCPSolver &Solver;
259 
260   /// Analyses used to help determine if a function should be specialized.
261   std::function<AssumptionCache &(Function &)> GetAC;
262   std::function<TargetTransformInfo &(Function &)> GetTTI;
263   std::function<TargetLibraryInfo &(Function &)> GetTLI;
264 
265   SmallPtrSet<Function *, 2> SpecializedFuncs;
266 
267 public:
  /// Create a specializer that operates on \p Solver (held by reference) and
  /// keeps copies of the \p GetAC, \p GetTTI and \p GetTLI callbacks, which
  /// are used to obtain the corresponding analysis for a given function.
  FunctionSpecializer(SCCPSolver &Solver,
                      std::function<AssumptionCache &(Function &)> GetAC,
                      std::function<TargetTransformInfo &(Function &)> GetTTI,
                      std::function<TargetLibraryInfo &(Function &)> GetTLI)
      : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {}
273 
274   /// Attempt to specialize functions in the module to enable constant
275   /// propagation across function boundaries.
276   ///
277   /// \returns true if at least one function is specialized.
278   bool
279   specializeFunctions(FuncList &FuncDecls,
280                       FuncList &CurrentSpecializations) {
281     bool Changed = false;
282     for (auto *F : FuncDecls) {
283       if (!isCandidateFunction(F, CurrentSpecializations))
284         continue;
285 
286       auto Cost = getSpecializationCost(F);
287       if (!Cost.isValid()) {
288         LLVM_DEBUG(
289             dbgs() << "FnSpecialization: Invalid specialisation cost.\n");
290         continue;
291       }
292 
293       auto ConstArgs = calculateGains(F, Cost);
294       if (ConstArgs.empty()) {
295         LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n");
296         continue;
297       }
298 
299       for (auto &CA : ConstArgs) {
300         specializeFunction(CA, CurrentSpecializations);
301         Changed = true;
302       }
303     }
304 
305     updateSpecializedFuncs(FuncDecls, CurrentSpecializations);
306     NumFuncSpecialized += NbFunctionsSpecialized;
307     return Changed;
308   }
309 
310   bool tryToReplaceWithConstant(Value *V) {
311     if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
312         V->user_empty())
313       return false;
314 
315     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
316     if (isOverdefined(IV))
317       return false;
318     auto *Const =
319         isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
320     V->replaceAllUsesWith(Const);
321 
322     for (auto *U : Const->users())
323       if (auto *I = dyn_cast<Instruction>(U))
324         if (Solver.isBlockExecutable(I->getParent()))
325           Solver.visit(I);
326 
327     // Remove the instruction from Block and Solver.
328     if (auto *I = dyn_cast<Instruction>(V)) {
329       if (I->isSafeToRemove()) {
330         I->eraseFromParent();
331         Solver.removeLatticeValueFor(I);
332       }
333     }
334     return true;
335   }
336 
337 private:
338   // The number of functions specialised, used for collecting statistics and
339   // also in the cost model.
340   unsigned NbFunctionsSpecialized = 0;
341 
342   /// Clone the function \p F and remove the ssa_copy intrinsics added by
343   /// the SCCPSolver in the cloned version.
344   Function *cloneCandidateFunction(Function *F) {
345     ValueToValueMapTy EmptyMap;
346     Function *Clone = CloneFunction(F, EmptyMap);
347     removeSSACopy(*Clone);
348     return Clone;
349   }
350 
351   /// This function decides whether it's worthwhile to specialize function \p F
352   /// based on the known constant values its arguments can take on, i.e. it
353   /// calculates a gain and returns a list of actual arguments that are deemed
354   /// profitable to specialize. Specialization is performed on the first
355   /// interesting argument. Specializations based on additional arguments will
356   /// be evaluated on following iterations of the main IPSCCP solve loop.
357   SmallVector<ArgInfo> calculateGains(Function *F, InstructionCost Cost) {
358     SmallVector<ArgInfo> Worklist;
359     // Determine if we should specialize the function based on the values the
360     // argument can take on. If specialization is not profitable, we continue
361     // on to the next argument.
362     for (Argument &FormalArg : F->args()) {
363       LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing arg: "
364                         << FormalArg.getName() << "\n");
365       // Determine if this argument is interesting. If we know the argument can
366       // take on any constant values, they are collected in Constants. If the
367       // argument can only ever equal a constant value in Constants, the
368       // function will be completely specialized, and the IsPartial flag will
369       // be set to false by isArgumentInteresting (that function only adds
370       // values to the Constants list that are deemed profitable).
371       bool IsPartial = true;
372       SmallVector<Constant *> ActualConstArg;
373       if (!isArgumentInteresting(&FormalArg, ActualConstArg, IsPartial)) {
374         LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n");
375         continue;
376       }
377 
378       for (auto *ActualArg : ActualConstArg) {
379         InstructionCost Gain =
380             ForceFunctionSpecialization
381                 ? 1
382                 : getSpecializationBonus(&FormalArg, ActualArg) - Cost;
383 
384         if (Gain <= 0)
385           continue;
386         Worklist.push_back({F, &FormalArg, ActualArg, Gain});
387       }
388 
389       if (Worklist.empty())
390         continue;
391 
392       // Sort the candidates in descending order.
393       llvm::stable_sort(Worklist, [](const ArgInfo &L, const ArgInfo &R) {
394         return L.Gain > R.Gain;
395       });
396 
397       // Truncate the worklist to 'MaxClonesThreshold' candidates if
398       // necessary.
399       if (Worklist.size() > MaxClonesThreshold) {
400         LLVM_DEBUG(dbgs() << "FnSpecialization: number of candidates exceed "
401                     << "the maximum number of clones threshold.\n"
402                     << "Truncating worklist to " << MaxClonesThreshold
403                     << " candidates.\n");
404         Worklist.erase(Worklist.begin() + MaxClonesThreshold,
405                        Worklist.end());
406       }
407 
408       if (IsPartial || Worklist.size() < ActualConstArg.size())
409         for (auto &ActualArg : Worklist)
410           ActualArg.Partial = true;
411 
412       LLVM_DEBUG(dbgs() << "Sorted list of candidates by gain:\n";
413                  for (auto &C
414                       : Worklist) {
415                    dbgs() << "- Function = " << C.Fn->getName() << ", ";
416                    dbgs() << "FormalArg = " << C.Arg->getName() << ", ";
417                    dbgs() << "ActualArg = " << C.Const->getName() << ", ";
418                    dbgs() << "Gain = " << C.Gain << "\n";
419                  });
420 
421       // FIXME: Only one argument per function.
422       break;
423     }
424     return Worklist;
425   }
426 
427   bool isCandidateFunction(Function *F, FuncList &Specializations) {
428     // Do not specialize the cloned function again.
429     if (SpecializedFuncs.contains(F))
430       return false;
431 
432     // If we're optimizing the function for size, we shouldn't specialize it.
433     if (F->hasOptSize() ||
434         shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
435       return false;
436 
437     // Exit if the function is not executable. There's no point in specializing
438     // a dead function.
439     if (!Solver.isBlockExecutable(&F->getEntryBlock()))
440       return false;
441 
442     // It wastes time to specialize a function which would get inlined finally.
443     if (F->hasFnAttribute(Attribute::AlwaysInline))
444       return false;
445 
446     LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
447                       << "\n");
448     return true;
449   }
450 
451   void specializeFunction(ArgInfo &AI, FuncList &Specializations) {
452     Function *Clone = cloneCandidateFunction(AI.Fn);
453     Argument *ClonedArg = Clone->getArg(AI.Arg->getArgNo());
454 
455     // Rewrite calls to the function so that they call the clone instead.
456     rewriteCallSites(AI.Fn, Clone, *ClonedArg, AI.Const);
457 
458     // Initialize the lattice state of the arguments of the function clone,
459     // marking the argument on which we specialized the function constant
460     // with the given value.
461     Solver.markArgInFuncSpecialization(AI.Fn, ClonedArg, AI.Const);
462 
463     // Mark all the specialized functions
464     Specializations.push_back(Clone);
465     NbFunctionsSpecialized++;
466 
467     // If the function has been completely specialized, the original function
468     // is no longer needed. Mark it unreachable.
469     if (!AI.Partial)
470       Solver.markFunctionUnreachable(AI.Fn);
471   }
472 
473   /// Compute and return the cost of specializing function \p F.
474   InstructionCost getSpecializationCost(Function *F) {
475     // Compute the code metrics for the function.
476     SmallPtrSet<const Value *, 32> EphValues;
477     CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
478     CodeMetrics Metrics;
479     for (BasicBlock &BB : *F)
480       Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);
481 
482     // If the code metrics reveal that we shouldn't duplicate the function, we
483     // shouldn't specialize it. Set the specialization cost to Invalid.
484     // Or if the lines of codes implies that this function is easy to get
485     // inlined so that we shouldn't specialize it.
486     if (Metrics.notDuplicatable ||
487         (!ForceFunctionSpecialization &&
488          Metrics.NumInsts < SmallFunctionThreshold)) {
489       InstructionCost C{};
490       C.setInvalid();
491       return C;
492     }
493 
494     // Otherwise, set the specialization cost to be the cost of all the
495     // instructions in the function and penalty for specializing more functions.
496     unsigned Penalty = NbFunctionsSpecialized + 1;
497     return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
498   }
499 
500   InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
501                                LoopInfo &LI) {
502     auto *I = dyn_cast_or_null<Instruction>(U);
503     // If not an instruction we do not know how to evaluate.
504     // Keep minimum possible cost for now so that it doesnt affect
505     // specialization.
506     if (!I)
507       return std::numeric_limits<unsigned>::min();
508 
509     auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);
510 
511     // Traverse recursively if there are more uses.
512     // TODO: Any other instructions to be added here?
513     if (I->mayReadFromMemory() || I->isCast())
514       for (auto *User : I->users())
515         Cost += getUserBonus(User, TTI, LI);
516 
517     // Increase the cost if it is inside the loop.
518     auto LoopDepth = LI.getLoopDepth(I->getParent());
519     Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
520     return Cost;
521   }
522 
523   /// Compute a bonus for replacing argument \p A with constant \p C.
524   InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
525     Function *F = A->getParent();
526     DominatorTree DT(*F);
527     LoopInfo LI(DT);
528     auto &TTI = (GetTTI)(*F);
529     LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for: " << *A
530                       << "\n");
531 
532     InstructionCost TotalCost = 0;
533     for (auto *U : A->users()) {
534       TotalCost += getUserBonus(U, TTI, LI);
535       LLVM_DEBUG(dbgs() << "FnSpecialization: User cost ";
536                  TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
537     }
538 
539     // The below heuristic is only concerned with exposing inlining
540     // opportunities via indirect call promotion. If the argument is not a
541     // function pointer, give up.
542     if (!isa<PointerType>(A->getType()) ||
543         !isa<FunctionType>(A->getType()->getPointerElementType()))
544       return TotalCost;
545 
546     // Since the argument is a function pointer, its incoming constant values
547     // should be functions or constant expressions. The code below attempts to
548     // look through cast expressions to find the function that will be called.
549     Value *CalledValue = C;
550     while (isa<ConstantExpr>(CalledValue) &&
551            cast<ConstantExpr>(CalledValue)->isCast())
552       CalledValue = cast<User>(CalledValue)->getOperand(0);
553     Function *CalledFunction = dyn_cast<Function>(CalledValue);
554     if (!CalledFunction)
555       return TotalCost;
556 
557     // Get TTI for the called function (used for the inline cost).
558     auto &CalleeTTI = (GetTTI)(*CalledFunction);
559 
560     // Look at all the call sites whose called value is the argument.
561     // Specializing the function on the argument would allow these indirect
562     // calls to be promoted to direct calls. If the indirect call promotion
563     // would likely enable the called function to be inlined, specializing is a
564     // good idea.
565     int Bonus = 0;
566     for (User *U : A->users()) {
567       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
568         continue;
569       auto *CS = cast<CallBase>(U);
570       if (CS->getCalledOperand() != A)
571         continue;
572 
573       // Get the cost of inlining the called function at this call site. Note
574       // that this is only an estimate. The called function may eventually
575       // change in a way that leads to it not being inlined here, even though
576       // inlining looks profitable now. For example, one of its called
577       // functions may be inlined into it, making the called function too large
578       // to be inlined into this call site.
579       //
580       // We apply a boost for performing indirect call promotion by increasing
581       // the default threshold by the threshold for indirect calls.
582       auto Params = getInlineParams();
583       Params.DefaultThreshold += InlineConstants::IndirectCallThreshold;
584       InlineCost IC =
585           getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI);
586 
587       // We clamp the bonus for this call to be between zero and the default
588       // threshold.
589       if (IC.isAlways())
590         Bonus += Params.DefaultThreshold;
591       else if (IC.isVariable() && IC.getCostDelta() > 0)
592         Bonus += IC.getCostDelta();
593     }
594 
595     return TotalCost + Bonus;
596   }
597 
598   /// Determine if we should specialize a function based on the incoming values
599   /// of the given argument.
600   ///
601   /// This function implements the goal-directed heuristic. It determines if
602   /// specializing the function based on the incoming values of argument \p A
603   /// would result in any significant optimization opportunities. If
604   /// optimization opportunities exist, the constant values of \p A on which to
605   /// specialize the function are collected in \p Constants. If the values in
606   /// \p Constants represent the complete set of values that \p A can take on,
607   /// the function will be completely specialized, and the \p IsPartial flag is
608   /// set to false.
609   ///
610   /// \returns true if the function should be specialized on the given
611   /// argument.
612   bool isArgumentInteresting(Argument *A, ConstList &Constants,
613                              bool &IsPartial) {
614     // For now, don't attempt to specialize functions based on the values of
615     // composite types.
616     if (!A->getType()->isSingleValueType() || A->user_empty())
617       return false;
618 
619     // If the argument isn't overdefined, there's nothing to do. It should
620     // already be constant.
621     if (!Solver.getLatticeValueFor(A).isOverdefined()) {
622       LLVM_DEBUG(dbgs() << "FnSpecialization: nothing to do, arg is already "
623                         << "constant?\n");
624       return false;
625     }
626 
627     // Collect the constant values that the argument can take on. If the
628     // argument can't take on any constant values, we aren't going to
629     // specialize the function. While it's possible to specialize the function
630     // based on non-constant arguments, there's likely not much benefit to
631     // constant propagation in doing so.
632     //
633     // TODO 1: currently it won't specialize if there are over the threshold of
634     // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it
635     // might be beneficial to take the occurrences into account in the cost
636     // model, so we would need to find the unique constants.
637     //
638     // TODO 2: this currently does not support constants, i.e. integer ranges.
639     //
640     IsPartial = !getPossibleConstants(A, Constants);
641     LLVM_DEBUG(dbgs() << "FnSpecialization: interesting arg: " << *A << "\n");
642     return true;
643   }
644 
645   /// Collect in \p Constants all the constant values that argument \p A can
646   /// take on.
647   ///
648   /// \returns true if all of the values the argument can take on are constant
649   /// (e.g., the argument's parent function cannot be called with an
650   /// overdefined value).
651   bool getPossibleConstants(Argument *A, ConstList &Constants) {
652     Function *F = A->getParent();
653     bool AllConstant = true;
654 
655     // Iterate over all the call sites of the argument's parent function.
656     for (User *U : F->users()) {
657       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
658         continue;
659       auto &CS = *cast<CallBase>(U);
660       // If the call site has attribute minsize set, that callsite won't be
661       // specialized.
662       if (CS.hasFnAttr(Attribute::MinSize)) {
663         AllConstant = false;
664         continue;
665       }
666 
667       // If the parent of the call site will never be executed, we don't need
668       // to worry about the passed value.
669       if (!Solver.isBlockExecutable(CS.getParent()))
670         continue;
671 
672       auto *V = CS.getArgOperand(A->getArgNo());
673       if (isa<PoisonValue>(V))
674         return false;
675 
676       // For now, constant expressions are fine but only if they are function
677       // calls.
678       if (auto *CE = dyn_cast<ConstantExpr>(V))
679         if (!isa<Function>(CE->getOperand(0)))
680           return false;
681 
682       // TrackValueOfGlobalVariable only tracks scalar global variables.
683       if (auto *GV = dyn_cast<GlobalVariable>(V)) {
684         // Check if we want to specialize on the address of non-constant
685         // global values.
686         if (!GV->isConstant())
687           if (!SpecializeOnAddresses)
688             return false;
689 
690         if (!GV->getValueType()->isSingleValueType())
691           return false;
692       }
693 
694       if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
695                                EnableSpecializationForLiteralConstant))
696         Constants.push_back(cast<Constant>(V));
697       else
698         AllConstant = false;
699     }
700 
701     // If the argument can only take on constant values, AllConstant will be
702     // true.
703     return AllConstant;
704   }
705 
706   /// Rewrite calls to function \p F to call function \p Clone instead.
707   ///
708   /// This function modifies calls to function \p F whose argument at index \p
709   /// ArgNo is equal to constant \p C. The calls are rewritten to call function
710   /// \p Clone instead.
711   ///
712   /// Callsites that have been marked with the MinSize function attribute won't
713   /// be specialized and rewritten.
714   void rewriteCallSites(Function *F, Function *Clone, Argument &Arg,
715                         Constant *C) {
716     unsigned ArgNo = Arg.getArgNo();
717     SmallVector<CallBase *, 4> CallSitesToRewrite;
718     for (auto *U : F->users()) {
719       if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
720         continue;
721       auto &CS = *cast<CallBase>(U);
722       if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
723         continue;
724       CallSitesToRewrite.push_back(&CS);
725     }
726     for (auto *CS : CallSitesToRewrite) {
727       if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) ||
728           CS->getArgOperand(ArgNo) == C) {
729         CS->setCalledFunction(Clone);
730         Solver.markOverdefined(CS);
731       }
732     }
733   }
734 
735   void updateSpecializedFuncs(FuncList &FuncDecls,
736                               FuncList &CurrentSpecializations) {
737     for (auto *SpecializedFunc : CurrentSpecializations) {
738       SpecializedFuncs.insert(SpecializedFunc);
739 
740       // Initialize the state of the newly created functions, marking them
741       // argument-tracked and executable.
742       if (SpecializedFunc->hasExactDefinition() &&
743           !SpecializedFunc->hasFnAttribute(Attribute::Naked))
744         Solver.addTrackedFunction(SpecializedFunc);
745 
746       Solver.addArgumentTrackedFunction(SpecializedFunc);
747       FuncDecls.push_back(SpecializedFunc);
748       Solver.markBlockExecutable(&SpecializedFunc->front());
749 
750       // Replace the function arguments for the specialized functions.
751       for (Argument &Arg : SpecializedFunc->args())
752         if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
753           LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
754                             << Arg.getName() << "\n");
755     }
756   }
757 };
758 } // namespace
759 
760 bool llvm::runFunctionSpecialization(
761     Module &M, const DataLayout &DL,
762     std::function<TargetLibraryInfo &(Function &)> GetTLI,
763     std::function<TargetTransformInfo &(Function &)> GetTTI,
764     std::function<AssumptionCache &(Function &)> GetAC,
765     function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
766   SCCPSolver Solver(DL, GetTLI, M.getContext());
767   FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
768   bool Changed = false;
769 
770   // Loop over all functions, marking arguments to those with their addresses
771   // taken or that are external as overdefined.
772   for (Function &F : M) {
773     if (F.isDeclaration())
774       continue;
775     if (F.hasFnAttribute(Attribute::NoDuplicate))
776       continue;
777 
778     LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
779                       << "\n");
780     Solver.addAnalysis(F, GetAnalysis(F));
781 
782     // Determine if we can track the function's arguments. If so, add the
783     // function to the solver's set of argument-tracked functions.
784     if (canTrackArgumentsInterprocedurally(&F)) {
785       LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
786       Solver.addArgumentTrackedFunction(&F);
787       continue;
788     } else {
789       LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
790                         << "FnSpecialization: Doesn't have local linkage, or "
791                         << "has its address taken\n");
792     }
793 
794     // Assume the function is called.
795     Solver.markBlockExecutable(&F.front());
796 
797     // Assume nothing about the incoming arguments.
798     for (Argument &AI : F.args())
799       Solver.markOverdefined(&AI);
800   }
801 
802   // Determine if we can track any of the module's global variables. If so, add
803   // the global variables we can track to the solver's set of tracked global
804   // variables.
805   for (GlobalVariable &G : M.globals()) {
806     G.removeDeadConstantUsers();
807     if (canTrackGlobalVariableInterprocedurally(&G))
808       Solver.trackValueOfGlobalVariable(&G);
809   }
810 
811   auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
812   SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
813                                         TrackedFuncs.end());
814 
815   // No tracked functions, so nothing to do: don't run the solver and remove
816   // the ssa_copy intrinsics that may have been introduced.
817   if (TrackedFuncs.empty()) {
818     removeSSACopy(M);
819     return false;
820   }
821 
822   // Solve for constants.
823   auto RunSCCPSolver = [&](auto &WorkList) {
824     bool ResolvedUndefs = true;
825 
826     while (ResolvedUndefs) {
827       // Not running the solver unnecessary is checked in regression test
828       // nothing-to-do.ll, so if this debug message is changed, this regression
829       // test needs updating too.
830       LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");
831 
832       Solver.solve();
833       LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
834       ResolvedUndefs = false;
835       for (Function *F : WorkList)
836         if (Solver.resolvedUndefsIn(*F))
837           ResolvedUndefs = true;
838     }
839 
840     for (auto *F : WorkList) {
841       for (BasicBlock &BB : *F) {
842         if (!Solver.isBlockExecutable(&BB))
843           continue;
844         // FIXME: The solver may make changes to the function here, so set
845         // Changed, even if later function specialization does not trigger.
846         for (auto &I : make_early_inc_range(BB))
847           Changed |= FS.tryToReplaceWithConstant(&I);
848       }
849     }
850   };
851 
852 #ifndef NDEBUG
853   LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
854   for (auto *F : FuncDecls)
855     LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
856 #endif
857 
858   // Initially resolve the constants in all the argument tracked functions.
859   RunSCCPSolver(FuncDecls);
860 
861   SmallVector<Function *, 2> CurrentSpecializations;
862   unsigned I = 0;
863   while (FuncSpecializationMaxIters != I++ &&
864          FS.specializeFunctions(FuncDecls, CurrentSpecializations)) {
865 
866     // Run the solver for the specialized functions.
867     RunSCCPSolver(CurrentSpecializations);
868 
869     // Replace some unresolved constant arguments.
870     constantArgPropagation(FuncDecls, M, Solver);
871 
872     CurrentSpecializations.clear();
873     Changed = true;
874   }
875 
876   // Clean up the IR by removing ssa_copy intrinsics.
877   removeSSACopy(M);
878   return Changed;
879 }
880