xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp (revision 1c025fb02d0fa15b76ca816d8414d532a687ebeb)
1 //===- AMDGPUSplitModule.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Implements a module splitting algorithm designed to support the
10 /// FullLTO --lto-partitions option for parallel codegen. This is completely
11 /// different from the common SplitModule pass, as this system is designed with
12 /// AMDGPU in mind.
13 ///
14 /// The basic idea of this module splitting implementation is the same as
15 /// SplitModule: load-balance the module's functions across a set of N
16 /// partitions to allow parallel codegen. However, it does it very
17 /// differently than the target-agnostic variant:
18 ///   - The module has "split roots", which are kernels in the vast
///     majority of cases.
20 ///   - Each root has a set of dependencies, and when a root and its
21 ///     dependencies is considered "big", we try to put it in a partition where
22 ///     most dependencies are already imported, to avoid duplicating large
23 ///     amounts of code.
24 ///   - There's special care for indirect calls in order to ensure
25 ///     AMDGPUResourceUsageAnalysis can work correctly.
26 ///
27 /// This file also includes a more elaborate logging system to enable
28 /// users to easily generate logs that (if desired) do not include any value
29 /// names, in order to not leak information about the source file.
30 /// Such logs are very helpful to understand and fix potential issues with
31 /// module splitting.
32 
33 #include "AMDGPUSplitModule.h"
34 #include "AMDGPUTargetMachine.h"
35 #include "Utils/AMDGPUBaseInfo.h"
36 #include "llvm/ADT/DenseMap.h"
37 #include "llvm/ADT/SmallVector.h"
38 #include "llvm/ADT/StringExtras.h"
39 #include "llvm/ADT/StringRef.h"
40 #include "llvm/Analysis/CallGraph.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/IR/Function.h"
43 #include "llvm/IR/Instruction.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/IR/User.h"
46 #include "llvm/IR/Value.h"
47 #include "llvm/Support/Casting.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/FileSystem.h"
50 #include "llvm/Support/Path.h"
51 #include "llvm/Support/Process.h"
52 #include "llvm/Support/SHA256.h"
53 #include "llvm/Support/Threading.h"
54 #include "llvm/Support/raw_ostream.h"
55 #include "llvm/Transforms/Utils/Cloning.h"
56 #include <algorithm>
57 #include <cassert>
58 #include <iterator>
59 #include <memory>
60 #include <utility>
61 #include <vector>
62 
63 using namespace llvm;
64 
65 #define DEBUG_TYPE "amdgpu-split-module"
66 
67 namespace {
68 
69 static cl::opt<float> LargeFnFactor(
70     "amdgpu-module-splitting-large-function-threshold", cl::init(2.0f),
71     cl::Hidden,
72     cl::desc(
73         "consider a function as large and needing special treatment when the "
74         "cost of importing it into a partition"
75         "exceeds the average cost of a partition by this factor; e;g. 2.0 "
76         "means if the function and its dependencies is 2 times bigger than "
77         "an average partition; 0 disables large functions handling entirely"));
78 
79 static cl::opt<float> LargeFnOverlapForMerge(
80     "amdgpu-module-splitting-large-function-merge-overlap", cl::init(0.8f),
81     cl::Hidden,
82     cl::desc(
83         "defines how much overlap between two large function's dependencies "
84         "is needed to put them in the same partition"));
85 
86 static cl::opt<bool> NoExternalizeGlobals(
87     "amdgpu-module-splitting-no-externalize-globals", cl::Hidden,
88     cl::desc("disables externalization of global variable with local linkage; "
89              "may cause globals to be duplicated which increases binary size"));
90 
91 static cl::opt<std::string>
92     LogDirOpt("amdgpu-module-splitting-log-dir", cl::Hidden,
93               cl::desc("output directory for AMDGPU module splitting logs"));
94 
95 static cl::opt<bool>
96     LogPrivate("amdgpu-module-splitting-log-private", cl::Hidden,
97                cl::desc("hash value names before printing them in the AMDGPU "
98                         "module splitting logs"));
99 
// Integer cost unit used by all splitting heuristics; borrowed from
// InstructionCost so plain integer arithmetic can be done on it.
using CostType = InstructionCost::CostType;
// Index of a partition, in the range [0, NumParts).
using PartitionID = unsigned;
// Callback used to fetch the TargetTransformInfo of a given function.
using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>;
103 
104 static bool isEntryPoint(const Function *F) {
105   return AMDGPU::isEntryFunctionCC(F->getCallingConv());
106 }
107 
108 static std::string getName(const Value &V) {
109   static bool HideNames;
110 
111   static llvm::once_flag HideNameInitFlag;
112   llvm::call_once(HideNameInitFlag, [&]() {
113     if (LogPrivate.getNumOccurrences())
114       HideNames = LogPrivate;
115     else {
116       const auto EV = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_PRIVATE");
117       HideNames = (EV.value_or("0") != "0");
118     }
119   });
120 
121   if (!HideNames)
122     return V.getName().str();
123   return toHex(SHA256::hash(arrayRefFromStringRef(V.getName())),
124                /*LowerCase=*/true);
125 }
126 
127 /// Main logging helper.
128 ///
129 /// Logging can be configured by the following environment variable.
130 ///   AMD_SPLIT_MODULE_LOG_DIR=<filepath>
131 ///     If set, uses <filepath> as the directory to write logfiles to
132 ///     each time module splitting is used.
133 ///   AMD_SPLIT_MODULE_LOG_PRIVATE
134 ///     If set to anything other than zero, all names are hidden.
135 ///
136 /// Both environment variables have corresponding CL options which
/// take priority over them.
138 ///
139 /// Any output printed to the log files is also printed to dbgs() when -debug is
140 /// used and LLVM_DEBUG is defined.
141 ///
142 /// This approach has a small disadvantage over LLVM_DEBUG though: logging logic
143 /// cannot be removed from the code (by building without debug). This probably
144 /// has a small performance cost because if some computation/formatting is
/// needed for logging purposes, it may be done every time only to be ignored
146 /// by the logger.
147 ///
148 /// As this pass only runs once and is not doing anything computationally
149 /// expensive, this is likely a reasonable trade-off.
150 ///
151 /// If some computation should really be avoided when unused, users of the class
152 /// can check whether any logging will occur by using the bool operator.
153 ///
154 /// \code
155 ///   if (SML) {
156 ///     // Executes only if logging to a file or if -debug is available and
157 ///     used.
158 ///   }
159 /// \endcode
160 class SplitModuleLogger {
161 public:
162   SplitModuleLogger(const Module &M) {
163     std::string LogDir = LogDirOpt;
164     if (LogDir.empty())
165       LogDir = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_DIR").value_or("");
166 
167     // No log dir specified means we don't need to log to a file.
168     // We may still log to dbgs(), though.
169     if (LogDir.empty())
170       return;
171 
172     // If a log directory is specified, create a new file with a unique name in
173     // that directory.
174     int Fd;
175     SmallString<0> PathTemplate;
176     SmallString<0> RealPath;
177     sys::path::append(PathTemplate, LogDir, "Module-%%-%%-%%-%%-%%-%%-%%.txt");
178     if (auto Err =
179             sys::fs::createUniqueFile(PathTemplate.str(), Fd, RealPath)) {
180       report_fatal_error("Failed to create log file at '" + Twine(LogDir) +
181                              "': " + Err.message(),
182                          /*CrashDiag=*/false);
183     }
184 
185     FileOS = std::make_unique<raw_fd_ostream>(Fd, /*shouldClose=*/true);
186   }
187 
188   bool hasLogFile() const { return FileOS != nullptr; }
189 
190   raw_ostream &logfile() {
191     assert(FileOS && "no logfile!");
192     return *FileOS;
193   }
194 
195   /// \returns true if this SML will log anything either to a file or dbgs().
196   /// Can be used to avoid expensive computations that are ignored when logging
197   /// is disabled.
198   operator bool() const {
199     return hasLogFile() || (DebugFlag && isCurrentDebugType(DEBUG_TYPE));
200   }
201 
202 private:
203   std::unique_ptr<raw_fd_ostream> FileOS;
204 };
205 
206 template <typename Ty>
207 static SplitModuleLogger &operator<<(SplitModuleLogger &SML, const Ty &Val) {
208   static_assert(
209       !std::is_same_v<Ty, Value>,
210       "do not print values to logs directly, use handleName instead!");
211   LLVM_DEBUG(dbgs() << Val);
212   if (SML.hasLogFile())
213     SML.logfile() << Val;
214   return SML;
215 }
216 
/// Calculate the cost of each function in \p M
/// \param SML Log Helper
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to analyze.
/// \param CostMap[out] Resulting Function -> Cost map.
/// \return The module's total cost.
static CostType
calculateFunctionCosts(SplitModuleLogger &SML, GetTTIFn GetTTI, Module &M,
                       DenseMap<const Function *, CostType> &CostMap) {
  CostType ModuleCost = 0;
  // Accumulated cost of entry points only; used for the log breakdown below.
  CostType KernelCost = 0;

  for (auto &Fn : M) {
    // Declarations have no body, thus no cost.
    if (Fn.isDeclaration())
      continue;

    CostType FnCost = 0;
    const auto &TTI = GetTTI(Fn);
    for (const auto &BB : Fn) {
      for (const auto &I : BB) {
        // TCK_CodeSize: we care about emitted code size here, not runtime
        // performance.
        auto Cost =
            TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
        assert(Cost != InstructionCost::getMax());
        // Assume expensive if we can't tell the cost of an instruction.
        CostType CostVal =
            Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive);
        assert((FnCost + CostVal) >= FnCost && "Overflow!");
        FnCost += CostVal;
      }
    }

    // A definition contains at least a terminator, so its cost is never zero.
    assert(FnCost != 0);

    CostMap[&Fn] = FnCost;
    assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
    ModuleCost += FnCost;

    if (isEntryPoint(&Fn))
      KernelCost += FnCost;
  }

  // Log how the total cost splits between kernels and non-kernel functions.
  CostType FnCost = (ModuleCost - KernelCost);
  SML << "=> Total Module Cost: " << ModuleCost << '\n'
      << "  => KernelCost: " << KernelCost << " ("
      << format("%0.2f", (float(KernelCost) / ModuleCost) * 100) << "%)\n"
      << "  => FnsCost: " << FnCost << " ("
      << format("%0.2f", (float(FnCost) / ModuleCost) * 100) << "%)\n";

  return ModuleCost;
}
267 
268 static bool canBeIndirectlyCalled(const Function &F) {
269   if (F.isDeclaration() || isEntryPoint(&F))
270     return false;
271   return !F.hasLocalLinkage() ||
272          F.hasAddressTaken(/*PutOffender=*/nullptr,
273                            /*IgnoreCallbackUses=*/false,
274                            /*IgnoreAssumeLikeCalls=*/true,
275                            /*IgnoreLLVMUsed=*/true,
276                            /*IgnoreARCAttachedCall=*/false,
277                            /*IgnoreCastedDirectCall=*/true);
278 }
279 
280 /// When a function or any of its callees performs an indirect call, this
281 /// takes over \ref addAllDependencies and adds all potentially callable
282 /// functions to \p Fns so they can be counted as dependencies of the function.
283 ///
284 /// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
285 /// presence of an indirect call, the function's resource usage is the same as
286 /// the most expensive function in the module.
287 /// \param M    The module.
288 /// \param Fns[out] Resulting list of functions.
289 static void addAllIndirectCallDependencies(const Module &M,
290                                            DenseSet<const Function *> &Fns) {
291   for (const auto &Fn : M) {
292     if (canBeIndirectlyCalled(Fn))
293       Fns.insert(&Fn);
294   }
295 }
296 
/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
/// callee until all reachable functions have been gathered.
///
/// \param SML Log Helper
/// \param CG Call graph for \p Fn's module.
/// \param Fn Current function to look at.
/// \param Fns[out] Resulting list of functions.
/// \param OnlyDirect Whether to only consider direct callees.
/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
/// point, either in \p Fn or in one of the function it calls. When that
/// happens, we fall back to adding all callable functions inside \p Fn's module
/// to \p Fns.
static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
                               const Function &Fn,
                               DenseSet<const Function *> &Fns, bool OnlyDirect,
                               bool &HadIndirectCall) {
  assert(!Fn.isDeclaration());

  const Module &M = *Fn.getParent();
  // Iterative depth-first traversal of the call graph rooted at Fn.
  SmallVector<const Function *> WorkList({&Fn});
  while (!WorkList.empty()) {
    const auto &CurFn = *WorkList.pop_back_val();
    assert(!CurFn.isDeclaration());

    // Scan for an indirect call. If such a call is found, we have to
    // conservatively assume this can call all non-entrypoint functions in the
    // module.

    for (auto &CGEntry : *CG[&CurFn]) {
      auto *CGNode = CGEntry.second;
      auto *Callee = CGNode->getFunction();
      if (!Callee) {
        if (OnlyDirect)
          continue;

        // Functions have an edge towards CallsExternalNode if they're external
        // declarations, or if they do an indirect call. As we only process
        // definitions here, we know this means the function has an indirect
        // call. We then have to conservatively assume this can call all
        // non-entrypoint functions in the module.
        if (CGNode != CG.getCallsExternalNode())
          continue; // this is another function-less node we don't care about.

        SML << "Indirect call detected in " << getName(CurFn)
            << " - treating all non-entrypoint functions as "
               "potential dependencies\n";

        // TODO: Print an ORE as well ?
        addAllIndirectCallDependencies(M, Fns);
        HadIndirectCall = true;
        continue;
      }

      // Declarations have no body to codegen, so they are not dependencies.
      if (Callee->isDeclaration())
        continue;

      // Only visit each definition once; Fns doubles as the "seen" set.
      auto [It, Inserted] = Fns.insert(Callee);
      if (Inserted)
        WorkList.push_back(Callee);
    }
  }
}
359 
/// Contains information about a function and its dependencies.
/// This is a splitting root. The splitting algorithm works by
/// assigning these to partitions.
struct FunctionWithDependencies {
  /// Collects \p Fn's reachable callees through \ref addAllDependencies and
  /// pre-computes the total cost of importing \p Fn into a partition.
  FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
                           const DenseMap<const Function *, CostType> &FnCosts,
                           const Function *Fn)
      : Fn(Fn) {
    // When Fn is not a kernel, we don't need to collect indirect callees.
    // Resource usage analysis is only performed on kernels, and we collect
    // indirect callees for resource usage analysis.
    addAllDependencies(SML, CG, *Fn, Dependencies,
                       /*OnlyDirect*/ !isEntryPoint(Fn), HasIndirectCall);
    TotalCost = FnCosts.at(Fn);
    for (const auto *Dep : Dependencies) {
      TotalCost += FnCosts.at(Dep);

      // We cannot duplicate functions with external linkage, or functions that
      // may be overridden at runtime.
      HasNonDuplicatableDependecy |=
          (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
    }
  }

  /// The root function itself.
  const Function *Fn = nullptr;
  /// All function definitions transitively reachable from \p Fn.
  DenseSet<const Function *> Dependencies;
  /// Whether \p Fn or any of its \ref Dependencies contains an indirect call.
  bool HasIndirectCall = false;
  /// Whether any of \p Fn's dependencies cannot be duplicated.
  bool HasNonDuplicatableDependecy = false;

  /// Cost of \p Fn plus the cost of all of its \ref Dependencies.
  CostType TotalCost = 0;

  /// \returns true if this function and its dependencies can be considered
  /// large according to \p Threshold.
  bool isLarge(CostType Threshold) const {
    return TotalCost > Threshold && !Dependencies.empty();
  }
};
399 
400 /// Calculates how much overlap there is between \p A and \p B.
401 /// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A
402 /// and B have no shared elements. Kernels do not count in overlap calculation.
403 static float calculateOverlap(const DenseSet<const Function *> &A,
404                               const DenseSet<const Function *> &B) {
405   DenseSet<const Function *> Total;
406   for (const auto *F : A) {
407     if (!isEntryPoint(F))
408       Total.insert(F);
409   }
410 
411   if (Total.empty())
412     return 0.0f;
413 
414   unsigned NumCommon = 0;
415   for (const auto *F : B) {
416     if (isEntryPoint(F))
417       continue;
418 
419     auto [It, Inserted] = Total.insert(F);
420     if (!Inserted)
421       ++NumCommon;
422   }
423 
424   return static_cast<float>(NumCommon) / Total.size();
425 }
426 
427 /// Performs all of the partitioning work on \p M.
428 /// \param SML Log Helper
429 /// \param M Module to partition.
430 /// \param NumParts Number of partitions to create.
431 /// \param ModuleCost Total cost of all functions in \p M.
432 /// \param FnCosts Map of Function -> Cost
433 /// \param WorkList Functions and their dependencies to process in order.
434 /// \returns The created partitions (a vector of size \p NumParts )
435 static std::vector<DenseSet<const Function *>>
436 doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts,
437                CostType ModuleCost,
438                const DenseMap<const Function *, CostType> &FnCosts,
439                const SmallVector<FunctionWithDependencies> &WorkList) {
440 
441   SML << "\n--Partitioning Starts--\n";
442 
443   // Calculate a "large function threshold". When more than one function's total
444   // import cost exceeds this value, we will try to assign it to an existing
445   // partition to reduce the amount of duplication needed.
446   //
447   // e.g. let two functions X and Y have a import cost of ~10% of the module, we
448   // assign X to a partition as usual, but when we get to Y, we check if it's
449   // worth also putting it in Y's partition.
450   const CostType LargeFnThreshold =
451       LargeFnFactor ? CostType(((ModuleCost / NumParts) * LargeFnFactor))
452                     : std::numeric_limits<CostType>::max();
453 
454   std::vector<DenseSet<const Function *>> Partitions;
455   Partitions.resize(NumParts);
456 
457   // Assign functions to partitions, and try to keep the partitions more or
458   // less balanced. We do that through a priority queue sorted in reverse, so we
459   // can always look at the partition with the least content.
460   //
461   // There are some cases where we will be deliberately unbalanced though.
462   //  - Large functions: we try to merge with existing partitions to reduce code
463   //  duplication.
464   //  - Functions with indirect or external calls always go in the first
465   //  partition (P0).
466   auto ComparePartitions = [](const std::pair<PartitionID, CostType> &a,
467                               const std::pair<PartitionID, CostType> &b) {
468     // When two partitions have the same cost, assign to the one with the
469     // biggest ID first. This allows us to put things in P0 last, because P0 may
470     // have other stuff added later.
471     if (a.second == b.second)
472       return a.first < b.first;
473     return a.second > b.second;
474   };
475 
476   // We can't use priority_queue here because we need to be able to access any
477   // element. This makes this a bit inefficient as we need to sort it again
478   // everytime we change it, but it's a very small array anyway (likely under 64
479   // partitions) so it's a cheap operation.
480   std::vector<std::pair<PartitionID, CostType>> BalancingQueue;
481   for (unsigned I = 0; I < NumParts; ++I)
482     BalancingQueue.push_back(std::make_pair(I, 0));
483 
484   // Helper function to handle assigning a function to a partition. This takes
485   // care of updating the balancing queue.
486   const auto AssignToPartition = [&](PartitionID PID,
487                                      const FunctionWithDependencies &FWD) {
488     auto &FnsInPart = Partitions[PID];
489     FnsInPart.insert(FWD.Fn);
490     FnsInPart.insert(FWD.Dependencies.begin(), FWD.Dependencies.end());
491 
492     SML << "assign " << getName(*FWD.Fn) << " to P" << PID << "\n  ->  ";
493     if (!FWD.Dependencies.empty()) {
494       SML << FWD.Dependencies.size() << " dependencies added\n";
495     };
496 
497     // Update the balancing queue. we scan backwards because in the common case
498     // the partition is at the end.
499     for (auto &[QueuePID, Cost] : reverse(BalancingQueue)) {
500       if (QueuePID == PID) {
501         CostType NewCost = 0;
502         for (auto *Fn : Partitions[PID])
503           NewCost += FnCosts.at(Fn);
504 
505         SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost;
506         if (Cost) {
507           SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100)
508               << "% increase)";
509         }
510         SML << '\n';
511 
512         Cost = NewCost;
513       }
514     }
515 
516     sort(BalancingQueue, ComparePartitions);
517   };
518 
519   for (auto &CurFn : WorkList) {
520     // When a function has indirect calls, it must stay in the first partition
521     // alongside every reachable non-entry function. This is a nightmare case
522     // for splitting as it severely limits what we can do.
523     if (CurFn.HasIndirectCall) {
524       SML << "Function with indirect call(s): " << getName(*CurFn.Fn)
525           << " defaulting to P0\n";
526       AssignToPartition(0, CurFn);
527       continue;
528     }
529 
530     // When a function has non duplicatable dependencies, we have to keep it in
531     // the first partition as well. This is a conservative approach, a
532     // finer-grained approach could keep track of which dependencies are
533     // non-duplicatable exactly and just make sure they're grouped together.
534     if (CurFn.HasNonDuplicatableDependecy) {
535       SML << "Function with externally visible dependency "
536           << getName(*CurFn.Fn) << " defaulting to P0\n";
537       AssignToPartition(0, CurFn);
538       continue;
539     }
540 
541     // Be smart with large functions to avoid duplicating their dependencies.
542     if (CurFn.isLarge(LargeFnThreshold)) {
543       assert(LargeFnOverlapForMerge >= 0.0f && LargeFnOverlapForMerge <= 1.0f);
544       SML << "Large Function: " << getName(*CurFn.Fn)
545           << " - looking for partition with at least "
546           << format("%0.2f", LargeFnOverlapForMerge * 100) << "% overlap\n";
547 
548       bool Assigned = false;
549       for (const auto &[PID, Fns] : enumerate(Partitions)) {
550         float Overlap = calculateOverlap(CurFn.Dependencies, Fns);
551         SML << "  => " << format("%0.2f", Overlap * 100) << "% overlap with P"
552             << PID << '\n';
553         if (Overlap > LargeFnOverlapForMerge) {
554           SML << "  selecting P" << PID << '\n';
555           AssignToPartition(PID, CurFn);
556           Assigned = true;
557         }
558       }
559 
560       if (Assigned)
561         continue;
562     }
563 
564     // Normal "load-balancing", assign to partition with least pressure.
565     auto [PID, CurCost] = BalancingQueue.back();
566     AssignToPartition(PID, CurFn);
567   }
568 
569   if (SML) {
570     for (const auto &[Idx, Part] : enumerate(Partitions)) {
571       CostType Cost = 0;
572       for (auto *Fn : Part)
573         Cost += FnCosts.at(Fn);
574       SML << "P" << Idx << " has a total cost of " << Cost << " ("
575           << format("%0.2f", (float(Cost) / ModuleCost) * 100)
576           << "% of source module)\n";
577     }
578 
579     SML << "--Partitioning Done--\n\n";
580   }
581 
582   // Check no functions were missed.
583 #ifndef NDEBUG
584   DenseSet<const Function *> AllFunctions;
585   for (const auto &Part : Partitions)
586     AllFunctions.insert(Part.begin(), Part.end());
587 
588   for (auto &Fn : M) {
589     if (!Fn.isDeclaration() && !AllFunctions.contains(&Fn)) {
590       assert(AllFunctions.contains(&Fn) && "Missed a function?!");
591     }
592   }
593 #endif
594 
595   return Partitions;
596 }
597 
598 static void externalize(GlobalValue &GV) {
599   if (GV.hasLocalLinkage()) {
600     GV.setLinkage(GlobalValue::ExternalLinkage);
601     GV.setVisibility(GlobalValue::HiddenVisibility);
602   }
603 
604   // Unnamed entities must be named consistently between modules. setName will
605   // give a distinct name to each such entity.
606   if (!GV.hasName())
607     GV.setName("__llvmsplit_unnamed");
608 }
609 
610 static bool hasDirectCaller(const Function &Fn) {
611   for (auto &U : Fn.uses()) {
612     if (auto *CB = dyn_cast<CallBase>(U.getUser()); CB && CB->isCallee(&U))
613       return true;
614   }
615   return false;
616 }
617 
/// Drives the whole splitting process for \p M: externalization, cost
/// calculation, root gathering, partitioning, and finally cloning one
/// sub-module per partition.
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to split (left unmodified apart from externalization).
/// \param N Number of partitions to create.
/// \param ModuleCallback Invoked once for each created partition, in order.
static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned N,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {

  SplitModuleLogger SML(M);

  CallGraph CG(M);

  // Externalize functions whose address is taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits a "undefined hidden symbol" linker
  // error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  for (auto &Fn : M) {
    if (Fn.hasAddressTaken()) {
      if (Fn.hasLocalLinkage()) {
        // NOTE(review): this logs the raw name even when log privacy is
        // requested — confirm whether getName(Fn) should be used here.
        SML << "[externalize] " << Fn.getName()
            << " because its address is taken\n";
      }
      externalize(Fn);
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turns helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        SML << "[externalize] GV " << GV.getName() << '\n';
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  DenseMap<const Function *, CostType> FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(SML, GetTTI, M, FnCosts);

  // First, gather every kernel into the worklist.
  SmallVector<FunctionWithDependencies> WorkList;
  for (auto &Fn : M) {
    if (isEntryPoint(&Fn) && !Fn.isDeclaration())
      WorkList.emplace_back(SML, CG, FnCosts, &Fn);
  }

  // Then, find missing functions that need to be considered as additional
  // roots. These can't be called in theory, but in practice we still have to
  // handle them to avoid linker errors.
  {
    DenseSet<const Function *> SeenFunctions;
    for (const auto &FWD : WorkList) {
      SeenFunctions.insert(FWD.Fn);
      SeenFunctions.insert(FWD.Dependencies.begin(), FWD.Dependencies.end());
    }

    for (auto &Fn : M) {
      // If this function is not part of any kernel's dependencies and isn't
      // directly called, consider it as a root.
      if (!Fn.isDeclaration() && !isEntryPoint(&Fn) &&
          !SeenFunctions.count(&Fn) && !hasDirectCaller(Fn)) {
        WorkList.emplace_back(SML, CG, FnCosts, &Fn);
      }
    }
  }

  // Sort the worklist so the most expensive roots are seen first.
  sort(WorkList, [&](auto &A, auto &B) {
    // Sort by total cost, and if the total cost is identical, sort
    // alphabetically.
    if (A.TotalCost == B.TotalCost)
      return A.Fn->getName() < B.Fn->getName();
    return A.TotalCost > B.TotalCost;
  });

  if (SML) {
    SML << "Worklist\n";
    for (const auto &FWD : WorkList) {
      SML << "[root] " << getName(*FWD.Fn) << " (totalCost:" << FWD.TotalCost
          << " indirect:" << FWD.HasIndirectCall
          << " hasNonDuplicatableDep:" << FWD.HasNonDuplicatableDependecy
          << ")\n";
      // Sort function names before printing to ensure determinism.
      SmallVector<std::string> SortedDepNames;
      SortedDepNames.reserve(FWD.Dependencies.size());
      for (const auto *Dep : FWD.Dependencies)
        SortedDepNames.push_back(getName(*Dep));
      sort(SortedDepNames);

      for (const auto &Name : SortedDepNames)
        SML << "  [dependency] " << Name << '\n';
    }
  }

  // This performs all of the partitioning work.
  auto Partitions = doPartitioning(SML, M, N, ModuleCost, FnCosts, WorkList);
  assert(Partitions.size() == N);

  // If we didn't externalize GVs, then local GVs need to be conservatively
  // imported into every module (including their initializers), and then cleaned
  // up afterwards.
  const auto NeedsConservativeImport = [&](const GlobalValue *GV) {
    // We conservatively import private/internal GVs into every module and clean
    // them up afterwards.
    const auto *Var = dyn_cast<GlobalVariable>(GV);
    return Var && Var->hasLocalLinkage();
  };

  SML << "Creating " << N << " modules...\n";
  unsigned TotalFnImpls = 0;
  for (unsigned I = 0; I < N; ++I) {
    const auto &FnsInPart = Partitions[I];

    // Clone the module, keeping only the globals selected by the predicate.
    ValueToValueMapTy VMap;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(GV))
            return FnsInPart.contains(Fn);

          if (NeedsConservativeImport(GV))
            return true;

          // Everything else goes in the first partition.
          return I == 0;
        }));

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(MPart->globals())) {
      if (NeedsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    // Count the definitions (and kernels) that ended up in this partition,
    // for logging purposes only.
    unsigned NumAllFns = 0, NumKernels = 0;
    for (auto &Cur : *MPart) {
      if (!Cur.isDeclaration()) {
        ++NumAllFns;
        if (isEntryPoint(&Cur))
          ++NumKernels;
      }
    }
    TotalFnImpls += NumAllFns;
    SML << "  - Module " << I << " with " << NumAllFns << " functions ("
        << NumKernels << " kernels)\n";
    ModuleCallback(std::move(MPart));
  }

  // >100% here means functions were duplicated across partitions.
  SML << TotalFnImpls << " function definitions across all modules ("
      << format("%0.2f", (float(TotalFnImpls) / FnCosts.size()) * 100)
      << "% of original module)\n";
}
774 } // namespace
775 
776 PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
777                                              ModuleAnalysisManager &MAM) {
778   FunctionAnalysisManager &FAM =
779       MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
780   const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
781     return FAM.getResult<TargetIRAnalysis>(F);
782   };
783   splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
784   // We don't change the original module.
785   return PreservedAnalyses::all();
786 }
787