1*73471bf0Spatrick //===- FunctionSpecialization.cpp - Function Specialization ---------------===// 2*73471bf0Spatrick // 3*73471bf0Spatrick // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*73471bf0Spatrick // See https://llvm.org/LICENSE.txt for license information. 5*73471bf0Spatrick // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*73471bf0Spatrick // 7*73471bf0Spatrick //===----------------------------------------------------------------------===// 8*73471bf0Spatrick // 9*73471bf0Spatrick // This specialises functions with constant parameters (e.g. functions, 10*73471bf0Spatrick // globals). Constant parameters like function pointers and constant globals 11*73471bf0Spatrick // are propagated to the callee by specializing the function. 12*73471bf0Spatrick // 13*73471bf0Spatrick // Current limitations: 14*73471bf0Spatrick // - It does not handle specialization of recursive functions, 15*73471bf0Spatrick // - It does not yet handle integer ranges. 16*73471bf0Spatrick // - Only 1 argument per function is specialised, 17*73471bf0Spatrick // - The cost-model could be further looked into, 18*73471bf0Spatrick // - We are not yet caching analysis results. 19*73471bf0Spatrick // 20*73471bf0Spatrick // Ideas: 21*73471bf0Spatrick // - With a function specialization attribute for arguments, we could have 22*73471bf0Spatrick // a direct way to steer function specialization, avoiding the cost-model, 23*73471bf0Spatrick // and thus control compile-times / code-size. 24*73471bf0Spatrick // 25*73471bf0Spatrick //===----------------------------------------------------------------------===// 26*73471bf0Spatrick 27*73471bf0Spatrick #include "llvm/ADT/Statistic.h" 28*73471bf0Spatrick #include "llvm/Analysis/AssumptionCache.h" 29*73471bf0Spatrick #include "llvm/Analysis/CodeMetrics.h" 30*73471bf0Spatrick #include "llvm/Analysis/DomTreeUpdater.h" 31*73471bf0Spatrick #include "llvm/Analysis/InlineCost.h" 32*73471bf0Spatrick #include "llvm/Analysis/LoopInfo.h" 33*73471bf0Spatrick #include "llvm/Analysis/TargetLibraryInfo.h" 34*73471bf0Spatrick #include "llvm/Analysis/TargetTransformInfo.h" 35*73471bf0Spatrick #include "llvm/Transforms/Scalar/SCCP.h" 36*73471bf0Spatrick #include "llvm/Transforms/Utils/Cloning.h" 37*73471bf0Spatrick #include "llvm/Transforms/Utils/SizeOpts.h" 38*73471bf0Spatrick #include <cmath> 39*73471bf0Spatrick 40*73471bf0Spatrick using namespace llvm; 41*73471bf0Spatrick 42*73471bf0Spatrick #define DEBUG_TYPE "function-specialization" 43*73471bf0Spatrick 44*73471bf0Spatrick STATISTIC(NumFuncSpecialized, "Number of functions specialized"); 45*73471bf0Spatrick 46*73471bf0Spatrick static cl::opt<bool> ForceFunctionSpecialization( 47*73471bf0Spatrick "force-function-specialization", cl::init(false), cl::Hidden, 48*73471bf0Spatrick cl::desc("Force function specialization for every call site with a " 49*73471bf0Spatrick "constant argument")); 50*73471bf0Spatrick 51*73471bf0Spatrick static cl::opt<unsigned> FuncSpecializationMaxIters( 52*73471bf0Spatrick "func-specialization-max-iters", cl::Hidden, 53*73471bf0Spatrick cl::desc("The maximum number of iterations function specialization is run"), 54*73471bf0Spatrick cl::init(1)); 55*73471bf0Spatrick 56*73471bf0Spatrick static cl::opt<unsigned> MaxConstantsThreshold( 57*73471bf0Spatrick "func-specialization-max-constants", cl::Hidden, 58*73471bf0Spatrick cl::desc("The maximum number of clones allowed for a single function " 59*73471bf0Spatrick "specialization"), 60*73471bf0Spatrick cl::init(3)); 61*73471bf0Spatrick 62*73471bf0Spatrick static cl::opt<unsigned> 63*73471bf0Spatrick AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden, 64*73471bf0Spatrick cl::desc("Average loop iteration count cost"), 65*73471bf0Spatrick cl::init(10)); 66*73471bf0Spatrick 67*73471bf0Spatrick static cl::opt<bool> EnableSpecializationForLiteralConstant( 68*73471bf0Spatrick "function-specialization-for-literal-constant", cl::init(false), cl::Hidden, 69*73471bf0Spatrick cl::desc("Make function specialization available for literal constant.")); 70*73471bf0Spatrick 71*73471bf0Spatrick // Helper to check if \p LV is either overdefined or a constant int. 72*73471bf0Spatrick static bool isOverdefined(const ValueLatticeElement &LV) { 73*73471bf0Spatrick return !LV.isUnknownOrUndef() && !LV.isConstant(); 74*73471bf0Spatrick } 75*73471bf0Spatrick 76*73471bf0Spatrick class FunctionSpecializer { 77*73471bf0Spatrick 78*73471bf0Spatrick /// The IPSCCP Solver. 79*73471bf0Spatrick SCCPSolver &Solver; 80*73471bf0Spatrick 81*73471bf0Spatrick /// Analyses used to help determine if a function should be specialized. 82*73471bf0Spatrick std::function<AssumptionCache &(Function &)> GetAC; 83*73471bf0Spatrick std::function<TargetTransformInfo &(Function &)> GetTTI; 84*73471bf0Spatrick std::function<TargetLibraryInfo &(Function &)> GetTLI; 85*73471bf0Spatrick 86*73471bf0Spatrick SmallPtrSet<Function *, 2> SpecializedFuncs; 87*73471bf0Spatrick 88*73471bf0Spatrick public: 89*73471bf0Spatrick FunctionSpecializer(SCCPSolver &Solver, 90*73471bf0Spatrick std::function<AssumptionCache &(Function &)> GetAC, 91*73471bf0Spatrick std::function<TargetTransformInfo &(Function &)> GetTTI, 92*73471bf0Spatrick std::function<TargetLibraryInfo &(Function &)> GetTLI) 93*73471bf0Spatrick : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} 94*73471bf0Spatrick 95*73471bf0Spatrick /// Attempt to specialize functions in the module to enable constant 96*73471bf0Spatrick /// propagation across function boundaries. 97*73471bf0Spatrick /// 98*73471bf0Spatrick /// \returns true if at least one function is specialized. 99*73471bf0Spatrick bool 100*73471bf0Spatrick specializeFunctions(SmallVectorImpl<Function *> &FuncDecls, 101*73471bf0Spatrick SmallVectorImpl<Function *> &CurrentSpecializations) { 102*73471bf0Spatrick 103*73471bf0Spatrick // Attempt to specialize the argument-tracked functions. 104*73471bf0Spatrick bool Changed = false; 105*73471bf0Spatrick for (auto *F : FuncDecls) { 106*73471bf0Spatrick if (specializeFunction(F, CurrentSpecializations)) { 107*73471bf0Spatrick Changed = true; 108*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Can specialize this func.\n"); 109*73471bf0Spatrick } else { 110*73471bf0Spatrick LLVM_DEBUG( 111*73471bf0Spatrick dbgs() << "FnSpecialization: Cannot specialize this func.\n"); 112*73471bf0Spatrick } 113*73471bf0Spatrick } 114*73471bf0Spatrick 115*73471bf0Spatrick for (auto *SpecializedFunc : CurrentSpecializations) { 116*73471bf0Spatrick SpecializedFuncs.insert(SpecializedFunc); 117*73471bf0Spatrick 118*73471bf0Spatrick // TODO: If we want to support specializing specialized functions, 119*73471bf0Spatrick // initialize here the state of the newly created functions, marking 120*73471bf0Spatrick // them argument-tracked and executable. 121*73471bf0Spatrick 122*73471bf0Spatrick // Replace the function arguments for the specialized functions. 123*73471bf0Spatrick for (Argument &Arg : SpecializedFunc->args()) 124*73471bf0Spatrick if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg)) 125*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: " 126*73471bf0Spatrick << Arg.getName() << "\n"); 127*73471bf0Spatrick } 128*73471bf0Spatrick 129*73471bf0Spatrick NumFuncSpecialized += NbFunctionsSpecialized; 130*73471bf0Spatrick return Changed; 131*73471bf0Spatrick } 132*73471bf0Spatrick 133*73471bf0Spatrick bool tryToReplaceWithConstant(Value *V) { 134*73471bf0Spatrick if (!V->getType()->isSingleValueType() || isa<CallBase>(V) || 135*73471bf0Spatrick V->user_empty()) 136*73471bf0Spatrick return false; 137*73471bf0Spatrick 138*73471bf0Spatrick const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); 139*73471bf0Spatrick if (isOverdefined(IV)) 140*73471bf0Spatrick return false; 141*73471bf0Spatrick auto *Const = IV.isConstant() ? Solver.getConstant(IV) 142*73471bf0Spatrick : UndefValue::get(V->getType()); 143*73471bf0Spatrick V->replaceAllUsesWith(Const); 144*73471bf0Spatrick 145*73471bf0Spatrick // TODO: Update the solver here if we want to specialize specialized 146*73471bf0Spatrick // functions. 147*73471bf0Spatrick return true; 148*73471bf0Spatrick } 149*73471bf0Spatrick 150*73471bf0Spatrick private: 151*73471bf0Spatrick // The number of functions specialised, used for collecting statistics and 152*73471bf0Spatrick // also in the cost model. 153*73471bf0Spatrick unsigned NbFunctionsSpecialized = 0; 154*73471bf0Spatrick 155*73471bf0Spatrick /// This function decides whether to specialize function \p F based on the 156*73471bf0Spatrick /// known constant values its arguments can take on. Specialization is 157*73471bf0Spatrick /// performed on the first interesting argument. Specializations based on 158*73471bf0Spatrick /// additional arguments will be evaluated on following iterations of the 159*73471bf0Spatrick /// main IPSCCP solve loop. \returns true if the function is specialized and 160*73471bf0Spatrick /// false otherwise. 161*73471bf0Spatrick bool specializeFunction(Function *F, 162*73471bf0Spatrick SmallVectorImpl<Function *> &Specializations) { 163*73471bf0Spatrick 164*73471bf0Spatrick // Do not specialize the cloned function again. 165*73471bf0Spatrick if (SpecializedFuncs.contains(F)) { 166*73471bf0Spatrick return false; 167*73471bf0Spatrick } 168*73471bf0Spatrick 169*73471bf0Spatrick // If we're optimizing the function for size, we shouldn't specialize it. 170*73471bf0Spatrick if (F->hasOptSize() || 171*73471bf0Spatrick shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass)) 172*73471bf0Spatrick return false; 173*73471bf0Spatrick 174*73471bf0Spatrick // Exit if the function is not executable. There's no point in specializing 175*73471bf0Spatrick // a dead function. 176*73471bf0Spatrick if (!Solver.isBlockExecutable(&F->getEntryBlock())) 177*73471bf0Spatrick return false; 178*73471bf0Spatrick 179*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName() 180*73471bf0Spatrick << "\n"); 181*73471bf0Spatrick // Determine if we should specialize the function based on the values the 182*73471bf0Spatrick // argument can take on. If specialization is not profitable, we continue 183*73471bf0Spatrick // on to the next argument. 184*73471bf0Spatrick for (Argument &A : F->args()) { 185*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing arg: " << A.getName() 186*73471bf0Spatrick << "\n"); 187*73471bf0Spatrick // True if this will be a partial specialization. We will need to keep 188*73471bf0Spatrick // the original function around in addition to the added specializations. 189*73471bf0Spatrick bool IsPartial = true; 190*73471bf0Spatrick 191*73471bf0Spatrick // Determine if this argument is interesting. If we know the argument can 192*73471bf0Spatrick // take on any constant values, they are collected in Constants. If the 193*73471bf0Spatrick // argument can only ever equal a constant value in Constants, the 194*73471bf0Spatrick // function will be completely specialized, and the IsPartial flag will 195*73471bf0Spatrick // be set to false by isArgumentInteresting (that function only adds 196*73471bf0Spatrick // values to the Constants list that are deemed profitable). 197*73471bf0Spatrick SmallVector<Constant *, 4> Constants; 198*73471bf0Spatrick if (!isArgumentInteresting(&A, Constants, IsPartial)) { 199*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n"); 200*73471bf0Spatrick continue; 201*73471bf0Spatrick } 202*73471bf0Spatrick 203*73471bf0Spatrick assert(!Constants.empty() && "No constants on which to specialize"); 204*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is interesting!\n" 205*73471bf0Spatrick << "FnSpecialization: Specializing '" << F->getName() 206*73471bf0Spatrick << "' on argument: " << A << "\n" 207*73471bf0Spatrick << "FnSpecialization: Constants are:\n\n"; 208*73471bf0Spatrick for (unsigned I = 0; I < Constants.size(); ++I) dbgs() 209*73471bf0Spatrick << *Constants[I] << "\n"; 210*73471bf0Spatrick dbgs() << "FnSpecialization: End of constants\n\n"); 211*73471bf0Spatrick 212*73471bf0Spatrick // Create a version of the function in which the argument is marked 213*73471bf0Spatrick // constant with the given value. 214*73471bf0Spatrick for (auto *C : Constants) { 215*73471bf0Spatrick // Clone the function. We leave the ValueToValueMap empty to allow 216*73471bf0Spatrick // IPSCCP to propagate the constant arguments. 217*73471bf0Spatrick ValueToValueMapTy EmptyMap; 218*73471bf0Spatrick Function *Clone = CloneFunction(F, EmptyMap); 219*73471bf0Spatrick Argument *ClonedArg = Clone->arg_begin() + A.getArgNo(); 220*73471bf0Spatrick 221*73471bf0Spatrick // Rewrite calls to the function so that they call the clone instead. 222*73471bf0Spatrick rewriteCallSites(F, Clone, *ClonedArg, C); 223*73471bf0Spatrick 224*73471bf0Spatrick // Initialize the lattice state of the arguments of the function clone, 225*73471bf0Spatrick // marking the argument on which we specialized the function constant 226*73471bf0Spatrick // with the given value. 227*73471bf0Spatrick Solver.markArgInFuncSpecialization(F, ClonedArg, C); 228*73471bf0Spatrick 229*73471bf0Spatrick // Mark all the specialized functions 230*73471bf0Spatrick Specializations.push_back(Clone); 231*73471bf0Spatrick NbFunctionsSpecialized++; 232*73471bf0Spatrick } 233*73471bf0Spatrick 234*73471bf0Spatrick // TODO: if we want to support specialize specialized functions, and if 235*73471bf0Spatrick // the function has been completely specialized, the original function is 236*73471bf0Spatrick // no longer needed, so we would need to mark it unreachable here. 237*73471bf0Spatrick 238*73471bf0Spatrick // FIXME: Only one argument per function. 239*73471bf0Spatrick return true; 240*73471bf0Spatrick } 241*73471bf0Spatrick 242*73471bf0Spatrick return false; 243*73471bf0Spatrick } 244*73471bf0Spatrick 245*73471bf0Spatrick /// Compute the cost of specializing function \p F. 246*73471bf0Spatrick InstructionCost getSpecializationCost(Function *F) { 247*73471bf0Spatrick // Compute the code metrics for the function. 248*73471bf0Spatrick SmallPtrSet<const Value *, 32> EphValues; 249*73471bf0Spatrick CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); 250*73471bf0Spatrick CodeMetrics Metrics; 251*73471bf0Spatrick for (BasicBlock &BB : *F) 252*73471bf0Spatrick Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); 253*73471bf0Spatrick 254*73471bf0Spatrick // If the code metrics reveal that we shouldn't duplicate the function, we 255*73471bf0Spatrick // shouldn't specialize it. Set the specialization cost to Invalid. 256*73471bf0Spatrick if (Metrics.notDuplicatable) { 257*73471bf0Spatrick InstructionCost C{}; 258*73471bf0Spatrick C.setInvalid(); 259*73471bf0Spatrick return C; 260*73471bf0Spatrick } 261*73471bf0Spatrick 262*73471bf0Spatrick // Otherwise, set the specialization cost to be the cost of all the 263*73471bf0Spatrick // instructions in the function and penalty for specializing more functions. 264*73471bf0Spatrick unsigned Penalty = NbFunctionsSpecialized + 1; 265*73471bf0Spatrick return Metrics.NumInsts * InlineConstants::InstrCost * Penalty; 266*73471bf0Spatrick } 267*73471bf0Spatrick 268*73471bf0Spatrick InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI, 269*73471bf0Spatrick LoopInfo &LI) { 270*73471bf0Spatrick auto *I = dyn_cast_or_null<Instruction>(U); 271*73471bf0Spatrick // If not an instruction we do not know how to evaluate. 272*73471bf0Spatrick // Keep minimum possible cost for now so that it doesnt affect 273*73471bf0Spatrick // specialization. 274*73471bf0Spatrick if (!I) 275*73471bf0Spatrick return std::numeric_limits<unsigned>::min(); 276*73471bf0Spatrick 277*73471bf0Spatrick auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency); 278*73471bf0Spatrick 279*73471bf0Spatrick // Traverse recursively if there are more uses. 280*73471bf0Spatrick // TODO: Any other instructions to be added here? 281*73471bf0Spatrick if (I->mayReadFromMemory() || I->isCast()) 282*73471bf0Spatrick for (auto *User : I->users()) 283*73471bf0Spatrick Cost += getUserBonus(User, TTI, LI); 284*73471bf0Spatrick 285*73471bf0Spatrick // Increase the cost if it is inside the loop. 286*73471bf0Spatrick auto LoopDepth = LI.getLoopDepth(I->getParent()); 287*73471bf0Spatrick Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth); 288*73471bf0Spatrick return Cost; 289*73471bf0Spatrick } 290*73471bf0Spatrick 291*73471bf0Spatrick /// Compute a bonus for replacing argument \p A with constant \p C. 292*73471bf0Spatrick InstructionCost getSpecializationBonus(Argument *A, Constant *C) { 293*73471bf0Spatrick Function *F = A->getParent(); 294*73471bf0Spatrick DominatorTree DT(*F); 295*73471bf0Spatrick LoopInfo LI(DT); 296*73471bf0Spatrick auto &TTI = (GetTTI)(*F); 297*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for: " << *A 298*73471bf0Spatrick << "\n"); 299*73471bf0Spatrick 300*73471bf0Spatrick InstructionCost TotalCost = 0; 301*73471bf0Spatrick for (auto *U : A->users()) { 302*73471bf0Spatrick TotalCost += getUserBonus(U, TTI, LI); 303*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; 304*73471bf0Spatrick TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); 305*73471bf0Spatrick } 306*73471bf0Spatrick 307*73471bf0Spatrick // The below heuristic is only concerned with exposing inlining 308*73471bf0Spatrick // opportunities via indirect call promotion. If the argument is not a 309*73471bf0Spatrick // function pointer, give up. 310*73471bf0Spatrick if (!isa<PointerType>(A->getType()) || 311*73471bf0Spatrick !isa<FunctionType>(A->getType()->getPointerElementType())) 312*73471bf0Spatrick return TotalCost; 313*73471bf0Spatrick 314*73471bf0Spatrick // Since the argument is a function pointer, its incoming constant values 315*73471bf0Spatrick // should be functions or constant expressions. The code below attempts to 316*73471bf0Spatrick // look through cast expressions to find the function that will be called. 317*73471bf0Spatrick Value *CalledValue = C; 318*73471bf0Spatrick while (isa<ConstantExpr>(CalledValue) && 319*73471bf0Spatrick cast<ConstantExpr>(CalledValue)->isCast()) 320*73471bf0Spatrick CalledValue = cast<User>(CalledValue)->getOperand(0); 321*73471bf0Spatrick Function *CalledFunction = dyn_cast<Function>(CalledValue); 322*73471bf0Spatrick if (!CalledFunction) 323*73471bf0Spatrick return TotalCost; 324*73471bf0Spatrick 325*73471bf0Spatrick // Get TTI for the called function (used for the inline cost). 326*73471bf0Spatrick auto &CalleeTTI = (GetTTI)(*CalledFunction); 327*73471bf0Spatrick 328*73471bf0Spatrick // Look at all the call sites whose called value is the argument. 329*73471bf0Spatrick // Specializing the function on the argument would allow these indirect 330*73471bf0Spatrick // calls to be promoted to direct calls. If the indirect call promotion 331*73471bf0Spatrick // would likely enable the called function to be inlined, specializing is a 332*73471bf0Spatrick // good idea. 333*73471bf0Spatrick int Bonus = 0; 334*73471bf0Spatrick for (User *U : A->users()) { 335*73471bf0Spatrick if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) 336*73471bf0Spatrick continue; 337*73471bf0Spatrick auto *CS = cast<CallBase>(U); 338*73471bf0Spatrick if (CS->getCalledOperand() != A) 339*73471bf0Spatrick continue; 340*73471bf0Spatrick 341*73471bf0Spatrick // Get the cost of inlining the called function at this call site. Note 342*73471bf0Spatrick // that this is only an estimate. The called function may eventually 343*73471bf0Spatrick // change in a way that leads to it not being inlined here, even though 344*73471bf0Spatrick // inlining looks profitable now. For example, one of its called 345*73471bf0Spatrick // functions may be inlined into it, making the called function too large 346*73471bf0Spatrick // to be inlined into this call site. 347*73471bf0Spatrick // 348*73471bf0Spatrick // We apply a boost for performing indirect call promotion by increasing 349*73471bf0Spatrick // the default threshold by the threshold for indirect calls. 350*73471bf0Spatrick auto Params = getInlineParams(); 351*73471bf0Spatrick Params.DefaultThreshold += InlineConstants::IndirectCallThreshold; 352*73471bf0Spatrick InlineCost IC = 353*73471bf0Spatrick getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI); 354*73471bf0Spatrick 355*73471bf0Spatrick // We clamp the bonus for this call to be between zero and the default 356*73471bf0Spatrick // threshold. 357*73471bf0Spatrick if (IC.isAlways()) 358*73471bf0Spatrick Bonus += Params.DefaultThreshold; 359*73471bf0Spatrick else if (IC.isVariable() && IC.getCostDelta() > 0) 360*73471bf0Spatrick Bonus += IC.getCostDelta(); 361*73471bf0Spatrick } 362*73471bf0Spatrick 363*73471bf0Spatrick return TotalCost + Bonus; 364*73471bf0Spatrick } 365*73471bf0Spatrick 366*73471bf0Spatrick /// Determine if we should specialize a function based on the incoming values 367*73471bf0Spatrick /// of the given argument. 368*73471bf0Spatrick /// 369*73471bf0Spatrick /// This function implements the goal-directed heuristic. It determines if 370*73471bf0Spatrick /// specializing the function based on the incoming values of argument \p A 371*73471bf0Spatrick /// would result in any significant optimization opportunities. If 372*73471bf0Spatrick /// optimization opportunities exist, the constant values of \p A on which to 373*73471bf0Spatrick /// specialize the function are collected in \p Constants. If the values in 374*73471bf0Spatrick /// \p Constants represent the complete set of values that \p A can take on, 375*73471bf0Spatrick /// the function will be completely specialized, and the \p IsPartial flag is 376*73471bf0Spatrick /// set to false. 377*73471bf0Spatrick /// 378*73471bf0Spatrick /// \returns true if the function should be specialized on the given 379*73471bf0Spatrick /// argument. 380*73471bf0Spatrick bool isArgumentInteresting(Argument *A, 381*73471bf0Spatrick SmallVectorImpl<Constant *> &Constants, 382*73471bf0Spatrick bool &IsPartial) { 383*73471bf0Spatrick Function *F = A->getParent(); 384*73471bf0Spatrick 385*73471bf0Spatrick // For now, don't attempt to specialize functions based on the values of 386*73471bf0Spatrick // composite types. 387*73471bf0Spatrick if (!A->getType()->isSingleValueType() || A->user_empty()) 388*73471bf0Spatrick return false; 389*73471bf0Spatrick 390*73471bf0Spatrick // If the argument isn't overdefined, there's nothing to do. It should 391*73471bf0Spatrick // already be constant. 392*73471bf0Spatrick if (!Solver.getLatticeValueFor(A).isOverdefined()) { 393*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: nothing to do, arg is already " 394*73471bf0Spatrick << "constant?\n"); 395*73471bf0Spatrick return false; 396*73471bf0Spatrick } 397*73471bf0Spatrick 398*73471bf0Spatrick // Collect the constant values that the argument can take on. If the 399*73471bf0Spatrick // argument can't take on any constant values, we aren't going to 400*73471bf0Spatrick // specialize the function. While it's possible to specialize the function 401*73471bf0Spatrick // based on non-constant arguments, there's likely not much benefit to 402*73471bf0Spatrick // constant propagation in doing so. 403*73471bf0Spatrick // 404*73471bf0Spatrick // TODO 1: currently it won't specialize if there are over the threshold of 405*73471bf0Spatrick // calls using the same argument, e.g foo(a) x 4 and foo(b) x 1, but it 406*73471bf0Spatrick // might be beneficial to take the occurrences into account in the cost 407*73471bf0Spatrick // model, so we would need to find the unique constants. 408*73471bf0Spatrick // 409*73471bf0Spatrick // TODO 2: this currently does not support constants, i.e. integer ranges. 410*73471bf0Spatrick // 411*73471bf0Spatrick SmallVector<Constant *, 4> PossibleConstants; 412*73471bf0Spatrick bool AllConstant = getPossibleConstants(A, PossibleConstants); 413*73471bf0Spatrick if (PossibleConstants.empty()) { 414*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n"); 415*73471bf0Spatrick return false; 416*73471bf0Spatrick } 417*73471bf0Spatrick if (PossibleConstants.size() > MaxConstantsThreshold) { 418*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: number of constants found exceed " 419*73471bf0Spatrick << "the maximum number of constants threshold.\n"); 420*73471bf0Spatrick return false; 421*73471bf0Spatrick } 422*73471bf0Spatrick 423*73471bf0Spatrick // Determine if it would be profitable to create a specialization of the 424*73471bf0Spatrick // function where the argument takes on the given constant value. If so, 425*73471bf0Spatrick // add the constant to Constants. 426*73471bf0Spatrick auto FnSpecCost = getSpecializationCost(F); 427*73471bf0Spatrick if (!FnSpecCost.isValid()) { 428*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Invalid specialisation cost.\n"); 429*73471bf0Spatrick return false; 430*73471bf0Spatrick } 431*73471bf0Spatrick 432*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: func specialisation cost: "; 433*73471bf0Spatrick FnSpecCost.print(dbgs()); dbgs() << "\n"); 434*73471bf0Spatrick 435*73471bf0Spatrick for (auto *C : PossibleConstants) { 436*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Constant: " << *C << "\n"); 437*73471bf0Spatrick if (ForceFunctionSpecialization) { 438*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Forced!\n"); 439*73471bf0Spatrick Constants.push_back(C); 440*73471bf0Spatrick continue; 441*73471bf0Spatrick } 442*73471bf0Spatrick if (getSpecializationBonus(A, C) > FnSpecCost) { 443*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: profitable!\n"); 444*73471bf0Spatrick Constants.push_back(C); 445*73471bf0Spatrick } else { 446*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: not profitable\n"); 447*73471bf0Spatrick } 448*73471bf0Spatrick } 449*73471bf0Spatrick 450*73471bf0Spatrick // None of the constant values the argument can take on were deemed good 451*73471bf0Spatrick // candidates on which to specialize the function. 452*73471bf0Spatrick if (Constants.empty()) 453*73471bf0Spatrick return false; 454*73471bf0Spatrick 455*73471bf0Spatrick // This will be a partial specialization if some of the constants were 456*73471bf0Spatrick // rejected due to their profitability. 457*73471bf0Spatrick IsPartial = !AllConstant || PossibleConstants.size() != Constants.size(); 458*73471bf0Spatrick 459*73471bf0Spatrick return true; 460*73471bf0Spatrick } 461*73471bf0Spatrick 462*73471bf0Spatrick /// Collect in \p Constants all the constant values that argument \p A can 463*73471bf0Spatrick /// take on. 464*73471bf0Spatrick /// 465*73471bf0Spatrick /// \returns true if all of the values the argument can take on are constant 466*73471bf0Spatrick /// (e.g., the argument's parent function cannot be called with an 467*73471bf0Spatrick /// overdefined value). 468*73471bf0Spatrick bool getPossibleConstants(Argument *A, 469*73471bf0Spatrick SmallVectorImpl<Constant *> &Constants) { 470*73471bf0Spatrick Function *F = A->getParent(); 471*73471bf0Spatrick bool AllConstant = true; 472*73471bf0Spatrick 473*73471bf0Spatrick // Iterate over all the call sites of the argument's parent function. 474*73471bf0Spatrick for (User *U : F->users()) { 475*73471bf0Spatrick if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) 476*73471bf0Spatrick continue; 477*73471bf0Spatrick auto &CS = *cast<CallBase>(U); 478*73471bf0Spatrick 479*73471bf0Spatrick // If the parent of the call site will never be executed, we don't need 480*73471bf0Spatrick // to worry about the passed value. 481*73471bf0Spatrick if (!Solver.isBlockExecutable(CS.getParent())) 482*73471bf0Spatrick continue; 483*73471bf0Spatrick 484*73471bf0Spatrick auto *V = CS.getArgOperand(A->getArgNo()); 485*73471bf0Spatrick // TrackValueOfGlobalVariable only tracks scalar global variables. 486*73471bf0Spatrick if (auto *GV = dyn_cast<GlobalVariable>(V)) { 487*73471bf0Spatrick if (!GV->getValueType()->isSingleValueType()) { 488*73471bf0Spatrick return false; 489*73471bf0Spatrick } 490*73471bf0Spatrick } 491*73471bf0Spatrick 492*73471bf0Spatrick if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() || 493*73471bf0Spatrick EnableSpecializationForLiteralConstant)) 494*73471bf0Spatrick Constants.push_back(cast<Constant>(V)); 495*73471bf0Spatrick else 496*73471bf0Spatrick AllConstant = false; 497*73471bf0Spatrick } 498*73471bf0Spatrick 499*73471bf0Spatrick // If the argument can only take on constant values, AllConstant will be 500*73471bf0Spatrick // true. 501*73471bf0Spatrick return AllConstant; 502*73471bf0Spatrick } 503*73471bf0Spatrick 504*73471bf0Spatrick /// Rewrite calls to function \p F to call function \p Clone instead. 505*73471bf0Spatrick /// 506*73471bf0Spatrick /// This function modifies calls to function \p F whose argument at index \p 507*73471bf0Spatrick /// ArgNo is equal to constant \p C. The calls are rewritten to call function 508*73471bf0Spatrick /// \p Clone instead. 509*73471bf0Spatrick void rewriteCallSites(Function *F, Function *Clone, Argument &Arg, 510*73471bf0Spatrick Constant *C) { 511*73471bf0Spatrick unsigned ArgNo = Arg.getArgNo(); 512*73471bf0Spatrick SmallVector<CallBase *, 4> CallSitesToRewrite; 513*73471bf0Spatrick for (auto *U : F->users()) { 514*73471bf0Spatrick if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) 515*73471bf0Spatrick continue; 516*73471bf0Spatrick auto &CS = *cast<CallBase>(U); 517*73471bf0Spatrick if (!CS.getCalledFunction() || CS.getCalledFunction() != F) 518*73471bf0Spatrick continue; 519*73471bf0Spatrick CallSitesToRewrite.push_back(&CS); 520*73471bf0Spatrick } 521*73471bf0Spatrick for (auto *CS : CallSitesToRewrite) { 522*73471bf0Spatrick if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) || 523*73471bf0Spatrick CS->getArgOperand(ArgNo) == C) { 524*73471bf0Spatrick CS->setCalledFunction(Clone); 525*73471bf0Spatrick Solver.markOverdefined(CS); 526*73471bf0Spatrick } 527*73471bf0Spatrick } 528*73471bf0Spatrick } 529*73471bf0Spatrick }; 530*73471bf0Spatrick 531*73471bf0Spatrick /// Function to clean up the left over intrinsics from SCCP util. 532*73471bf0Spatrick static void cleanup(Module &M) { 533*73471bf0Spatrick for (Function &F : M) { 534*73471bf0Spatrick for (BasicBlock &BB : F) { 535*73471bf0Spatrick for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { 536*73471bf0Spatrick Instruction *Inst = &*BI++; 537*73471bf0Spatrick if (auto *II = dyn_cast<IntrinsicInst>(Inst)) { 538*73471bf0Spatrick if (II->getIntrinsicID() == Intrinsic::ssa_copy) { 539*73471bf0Spatrick Value *Op = II->getOperand(0); 540*73471bf0Spatrick Inst->replaceAllUsesWith(Op); 541*73471bf0Spatrick Inst->eraseFromParent(); 542*73471bf0Spatrick } 543*73471bf0Spatrick } 544*73471bf0Spatrick } 545*73471bf0Spatrick } 546*73471bf0Spatrick } 547*73471bf0Spatrick } 548*73471bf0Spatrick 549*73471bf0Spatrick bool llvm::runFunctionSpecialization( 550*73471bf0Spatrick Module &M, const DataLayout &DL, 551*73471bf0Spatrick std::function<TargetLibraryInfo &(Function &)> GetTLI, 552*73471bf0Spatrick std::function<TargetTransformInfo &(Function &)> GetTTI, 553*73471bf0Spatrick std::function<AssumptionCache &(Function &)> GetAC, 554*73471bf0Spatrick function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) { 555*73471bf0Spatrick SCCPSolver Solver(DL, GetTLI, M.getContext()); 556*73471bf0Spatrick FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI); 557*73471bf0Spatrick bool Changed = false; 558*73471bf0Spatrick 559*73471bf0Spatrick // Loop over all functions, marking arguments to those with their addresses 560*73471bf0Spatrick // taken or that are external as overdefined. 561*73471bf0Spatrick for (Function &F : M) { 562*73471bf0Spatrick if (F.isDeclaration()) 563*73471bf0Spatrick continue; 564*73471bf0Spatrick if (F.hasFnAttribute(Attribute::NoDuplicate)) 565*73471bf0Spatrick continue; 566*73471bf0Spatrick 567*73471bf0Spatrick LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName() 568*73471bf0Spatrick << "\n"); 569*73471bf0Spatrick Solver.addAnalysis(F, GetAnalysis(F)); 570*73471bf0Spatrick 571*73471bf0Spatrick // Determine if we can track the function's arguments. If so, add the 572*73471bf0Spatrick // function to the solver's set of argument-tracked functions. 573*73471bf0Spatrick if (canTrackArgumentsInterprocedurally(&F)) { 574*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n"); 575*73471bf0Spatrick Solver.addArgumentTrackedFunction(&F); 576*73471bf0Spatrick continue; 577*73471bf0Spatrick } else { 578*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n" 579*73471bf0Spatrick << "FnSpecialization: Doesn't have local linkage, or " 580*73471bf0Spatrick << "has its address taken\n"); 581*73471bf0Spatrick } 582*73471bf0Spatrick 583*73471bf0Spatrick // Assume the function is called. 584*73471bf0Spatrick Solver.markBlockExecutable(&F.front()); 585*73471bf0Spatrick 586*73471bf0Spatrick // Assume nothing about the incoming arguments. 587*73471bf0Spatrick for (Argument &AI : F.args()) 588*73471bf0Spatrick Solver.markOverdefined(&AI); 589*73471bf0Spatrick } 590*73471bf0Spatrick 591*73471bf0Spatrick // Determine if we can track any of the module's global variables. If so, add 592*73471bf0Spatrick // the global variables we can track to the solver's set of tracked global 593*73471bf0Spatrick // variables. 594*73471bf0Spatrick for (GlobalVariable &G : M.globals()) { 595*73471bf0Spatrick G.removeDeadConstantUsers(); 596*73471bf0Spatrick if (canTrackGlobalVariableInterprocedurally(&G)) 597*73471bf0Spatrick Solver.trackValueOfGlobalVariable(&G); 598*73471bf0Spatrick } 599*73471bf0Spatrick 600*73471bf0Spatrick // Solve for constants. 601*73471bf0Spatrick auto RunSCCPSolver = [&](auto &WorkList) { 602*73471bf0Spatrick bool ResolvedUndefs = true; 603*73471bf0Spatrick 604*73471bf0Spatrick while (ResolvedUndefs) { 605*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n"); 606*73471bf0Spatrick Solver.solve(); 607*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n"); 608*73471bf0Spatrick ResolvedUndefs = false; 609*73471bf0Spatrick for (Function *F : WorkList) 610*73471bf0Spatrick if (Solver.resolvedUndefsIn(*F)) 611*73471bf0Spatrick ResolvedUndefs = true; 612*73471bf0Spatrick } 613*73471bf0Spatrick 614*73471bf0Spatrick for (auto *F : WorkList) { 615*73471bf0Spatrick for (BasicBlock &BB : *F) { 616*73471bf0Spatrick if (!Solver.isBlockExecutable(&BB)) 617*73471bf0Spatrick continue; 618*73471bf0Spatrick for (auto &I : make_early_inc_range(BB)) 619*73471bf0Spatrick FS.tryToReplaceWithConstant(&I); 620*73471bf0Spatrick } 621*73471bf0Spatrick } 622*73471bf0Spatrick }; 623*73471bf0Spatrick 624*73471bf0Spatrick auto &TrackedFuncs = Solver.getArgumentTrackedFunctions(); 625*73471bf0Spatrick SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(), 626*73471bf0Spatrick TrackedFuncs.end()); 627*73471bf0Spatrick #ifndef NDEBUG 628*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n"); 629*73471bf0Spatrick for (auto *F : FuncDecls) 630*73471bf0Spatrick LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n"); 631*73471bf0Spatrick #endif 632*73471bf0Spatrick 633*73471bf0Spatrick // Initially resolve the constants in all the argument tracked functions. 634*73471bf0Spatrick RunSCCPSolver(FuncDecls); 635*73471bf0Spatrick 636*73471bf0Spatrick SmallVector<Function *, 2> CurrentSpecializations; 637*73471bf0Spatrick unsigned I = 0; 638*73471bf0Spatrick while (FuncSpecializationMaxIters != I++ && 639*73471bf0Spatrick FS.specializeFunctions(FuncDecls, CurrentSpecializations)) { 640*73471bf0Spatrick // TODO: run the solver here for the specialized functions only if we want 641*73471bf0Spatrick // to specialize recursively. 642*73471bf0Spatrick 643*73471bf0Spatrick CurrentSpecializations.clear(); 644*73471bf0Spatrick Changed = true; 645*73471bf0Spatrick } 646*73471bf0Spatrick 647*73471bf0Spatrick // Clean up the IR by removing ssa_copy intrinsics. 648*73471bf0Spatrick cleanup(M); 649*73471bf0Spatrick return Changed; 650*73471bf0Spatrick } 651