xref: /llvm-project/bolt/tools/merge-fdata/merge-fdata.cpp (revision 86526084044167b3c753d32ef8dbf79d57cba0c4)
1 //===- bolt/tools/merge-fdata/merge-fdata.cpp -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Tool for merging profile in fdata format:
10 //
11 //   $ merge-fdata 1.fdata 2.fdata 3.fdata > merged.fdata
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "bolt/Profile/ProfileYAMLMapping.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/FileSystem.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/PrettyStackTrace.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/ThreadPool.h"
24 #include <algorithm>
25 #include <fstream>
26 #include <mutex>
27 #include <unordered_map>
28 
29 using namespace llvm;
30 using namespace llvm::yaml::bolt;
31 
32 namespace opts {
33 
34 cl::OptionCategory MergeFdataCategory("merge-fdata options");
35 
36 enum SortType : char {
37   ST_NONE,
38   ST_EXEC_COUNT,      /// Sort based on function execution count.
39   ST_TOTAL_BRANCHES,  /// Sort based on all branches in the function.
40 };
41 
42 static cl::list<std::string>
43 InputDataFilenames(
44   cl::Positional,
45   cl::CommaSeparated,
46   cl::desc("<fdata1> [<fdata2>]..."),
47   cl::OneOrMore,
48   cl::cat(MergeFdataCategory));
49 
50 static cl::opt<SortType>
51 PrintFunctionList("print",
52   cl::desc("print the list of objects with count to stderr"),
53   cl::init(ST_NONE),
54   cl::values(clEnumValN(ST_NONE,
55       "none",
56       "do not print objects/functions"),
57     clEnumValN(ST_EXEC_COUNT,
58       "exec",
59       "print functions sorted by execution count"),
60     clEnumValN(ST_TOTAL_BRANCHES,
61       "branches",
62       "print functions sorted by total branch count")),
63   cl::cat(MergeFdataCategory));
64 
65 static cl::opt<bool>
66 SuppressMergedDataOutput("q",
67   cl::desc("do not print merged data to stdout"),
68   cl::init(false),
69   cl::Optional,
70   cl::cat(MergeFdataCategory));
71 
72 static cl::opt<std::string>
73 OutputFilePath("o",
74   cl::value_desc("file"),
75   cl::desc("Write output to <file>"),
76   cl::cat(MergeFdataCategory));
77 
78 } // namespace opts
79 
80 namespace {
81 
82 static StringRef ToolName;
83 
84 static void report_error(StringRef Message, std::error_code EC) {
85   assert(EC);
86   errs() << ToolName << ": '" << Message << "': " << EC.message() << ".\n";
87   exit(1);
88 }
89 
90 static void report_error(Twine Message, StringRef CustomError) {
91   errs() << ToolName << ": '" << Message << "': " << CustomError << ".\n";
92   exit(1);
93 }
94 
95 static raw_fd_ostream &output() {
96   if (opts::OutputFilePath.empty() || opts::OutputFilePath == "-")
97     return outs();
98   else {
99     std::error_code EC;
100     static raw_fd_ostream Output(opts::OutputFilePath, EC);
101     if (EC)
102       report_error(opts::OutputFilePath, EC);
103     return Output;
104   }
105 }
106 
107 void mergeProfileHeaders(BinaryProfileHeader &MergedHeader,
108                          const BinaryProfileHeader &Header) {
109   if (MergedHeader.FileName.empty())
110     MergedHeader.FileName = Header.FileName;
111 
112   if (!MergedHeader.FileName.empty() &&
113       MergedHeader.FileName != Header.FileName)
114     errs() << "WARNING: merging profile from a binary for " << Header.FileName
115            << " into a profile for binary " << MergedHeader.FileName << '\n';
116 
117   if (MergedHeader.Id.empty())
118     MergedHeader.Id = Header.Id;
119 
120   if (!MergedHeader.Id.empty() && (MergedHeader.Id != Header.Id))
121     errs() << "WARNING: build-ids in merged profiles do not match\n";
122 
123   // Cannot merge samples profile with LBR profile.
124   if (!MergedHeader.Flags)
125     MergedHeader.Flags = Header.Flags;
126 
127   constexpr auto Mask = llvm::bolt::BinaryFunction::PF_LBR |
128                         llvm::bolt::BinaryFunction::PF_SAMPLE;
129   if ((MergedHeader.Flags & Mask) != (Header.Flags & Mask)) {
130     errs() << "ERROR: cannot merge LBR profile with non-LBR profile\n";
131     exit(1);
132   }
133   MergedHeader.Flags = MergedHeader.Flags | Header.Flags;
134 
135   if (!Header.Origin.empty()) {
136     if (MergedHeader.Origin.empty())
137       MergedHeader.Origin = Header.Origin;
138     else if (MergedHeader.Origin != Header.Origin)
139       MergedHeader.Origin += "; " + Header.Origin;
140   }
141 
142   if (MergedHeader.EventNames.empty())
143     MergedHeader.EventNames = Header.EventNames;
144 
145   if (MergedHeader.EventNames != Header.EventNames) {
146     errs() << "WARNING: merging profiles with different sampling events\n";
147     MergedHeader.EventNames += "," + Header.EventNames;
148   }
149 
150   if (MergedHeader.HashFunction != Header.HashFunction)
151     report_error("merge conflict",
152                  "cannot merge profiles with different hash functions");
153 }
154 
155 void mergeBasicBlockProfile(BinaryBasicBlockProfile &MergedBB,
156                             BinaryBasicBlockProfile &&BB,
157                             const BinaryFunctionProfile &BF) {
158   // Verify that the blocks match.
159   if (BB.NumInstructions != MergedBB.NumInstructions)
160     report_error(BF.Name + " : BB #" + Twine(BB.Index),
161                  "number of instructions in block mismatch");
162   if (BB.Hash != MergedBB.Hash)
163     report_error(BF.Name + " : BB #" + Twine(BB.Index),
164                  "basic block hash mismatch");
165 
166   // Update the execution count.
167   MergedBB.ExecCount += BB.ExecCount;
168 
169   // Update the event count.
170   MergedBB.EventCount += BB.EventCount;
171 
172   // Merge calls sites.
173   std::unordered_map<uint32_t, CallSiteInfo *> CSByOffset;
174   for (CallSiteInfo &CS : BB.CallSites)
175     CSByOffset.emplace(std::make_pair(CS.Offset, &CS));
176 
177   for (CallSiteInfo &MergedCS : MergedBB.CallSites) {
178     auto CSI = CSByOffset.find(MergedCS.Offset);
179     if (CSI == CSByOffset.end())
180       continue;
181     yaml::bolt::CallSiteInfo &CS = *CSI->second;
182     if (CS != MergedCS)
183       continue;
184 
185     MergedCS.Count += CS.Count;
186     MergedCS.Mispreds += CS.Mispreds;
187 
188     CSByOffset.erase(CSI);
189   }
190 
191   // Append the rest of call sites.
192   for (std::pair<const uint32_t, CallSiteInfo *> CSI : CSByOffset)
193     MergedBB.CallSites.emplace_back(std::move(*CSI.second));
194 
195   // Merge successor info.
196   std::vector<SuccessorInfo *> SIByIndex(BF.NumBasicBlocks);
197   for (SuccessorInfo &SI : BB.Successors) {
198     if (SI.Index >= BF.NumBasicBlocks)
199       report_error(BF.Name, "bad successor index");
200     SIByIndex[SI.Index] = &SI;
201   }
202   for (SuccessorInfo &MergedSI : MergedBB.Successors) {
203     if (!SIByIndex[MergedSI.Index])
204       continue;
205     SuccessorInfo &SI = *SIByIndex[MergedSI.Index];
206 
207     MergedSI.Count += SI.Count;
208     MergedSI.Mispreds += SI.Mispreds;
209 
210     SIByIndex[MergedSI.Index] = nullptr;
211   }
212   for (SuccessorInfo *SI : SIByIndex)
213     if (SI)
214       MergedBB.Successors.emplace_back(std::move(*SI));
215 }
216 
217 void mergeFunctionProfile(BinaryFunctionProfile &MergedBF,
218                           BinaryFunctionProfile &&BF) {
219   // Validate that we are merging the correct function.
220   if (BF.NumBasicBlocks != MergedBF.NumBasicBlocks)
221     report_error(BF.Name, "number of basic blocks mismatch");
222   if (BF.Id != MergedBF.Id)
223     report_error(BF.Name, "ID mismatch");
224   if (BF.Hash != MergedBF.Hash)
225     report_error(BF.Name, "hash mismatch");
226 
227   // Update the execution count.
228   MergedBF.ExecCount += BF.ExecCount;
229 
230   // Merge basic blocks profile.
231   std::vector<BinaryBasicBlockProfile *> BlockByIndex(BF.NumBasicBlocks);
232   for (BinaryBasicBlockProfile &BB : BF.Blocks) {
233     if (BB.Index >= BF.NumBasicBlocks)
234       report_error(BF.Name + " : BB #" + Twine(BB.Index),
235                    "bad basic block index");
236     BlockByIndex[BB.Index] = &BB;
237   }
238   for (BinaryBasicBlockProfile &MergedBB : MergedBF.Blocks) {
239     if (!BlockByIndex[MergedBB.Index])
240       continue;
241     BinaryBasicBlockProfile &BB = *BlockByIndex[MergedBB.Index];
242 
243     mergeBasicBlockProfile(MergedBB, std::move(BB), MergedBF);
244 
245     // Ignore this block in the future.
246     BlockByIndex[MergedBB.Index] = nullptr;
247   }
248 
249   // Append blocks unique to BF (i.e. those that are not in MergedBF).
250   for (BinaryBasicBlockProfile *BB : BlockByIndex)
251     if (BB)
252       MergedBF.Blocks.emplace_back(std::move(*BB));
253 }
254 
255 bool isYAML(const StringRef Filename) {
256   ErrorOr<std::unique_ptr<MemoryBuffer>> MB =
257       MemoryBuffer::getFileOrSTDIN(Filename);
258   if (std::error_code EC = MB.getError())
259     report_error(Filename, EC);
260   StringRef Buffer = MB.get()->getBuffer();
261   if (Buffer.starts_with("---\n"))
262     return true;
263   return false;
264 }
265 
266 void mergeLegacyProfiles(const SmallVectorImpl<std::string> &Filenames) {
267   errs() << "Using legacy profile format.\n";
268   std::optional<bool> BoltedCollection;
269   std::optional<bool> NoLBRCollection;
270   std::mutex BoltedCollectionMutex;
271   struct CounterTy {
272     uint64_t Exec{0};
273     uint64_t Mispred{0};
274     CounterTy &operator+=(const CounterTy &O) {
275       Exec += O.Exec;
276       Mispred += O.Mispred;
277       return *this;
278     }
279     CounterTy operator+(const CounterTy &O) { return *this += O; }
280   };
281   typedef StringMap<CounterTy> ProfileTy;
282 
283   auto ParseProfile = [&](const std::string &Filename, auto &Profiles) {
284     const llvm::thread::id tid = llvm::this_thread::get_id();
285 
286     if (isYAML(Filename))
287       report_error(Filename, "cannot mix YAML and legacy formats");
288 
289     std::ifstream FdataFile(Filename, std::ios::in);
290     std::string FdataLine;
291     std::getline(FdataFile, FdataLine);
292 
293     auto checkMode = [&](const std::string &Key, std::optional<bool> &Flag) {
294       const bool KeyIsSet = FdataLine.rfind(Key, 0) == 0;
295 
296       if (!Flag.has_value())
297         Flag = KeyIsSet;
298       else if (*Flag != KeyIsSet)
299         report_error(Filename, "cannot mix profile with and without " + Key);
300       if (KeyIsSet)
301         // Advance line
302         std::getline(FdataFile, FdataLine);
303     };
304 
305     ProfileTy *Profile;
306     {
307       std::lock_guard<std::mutex> Lock(BoltedCollectionMutex);
308       // Check if the string "boltedcollection" is in the first line
309       checkMode("boltedcollection", BoltedCollection);
310       // Check if the string "no_lbr" is in the first line
311       // (or second line if BoltedCollection is true)
312       checkMode("no_lbr", NoLBRCollection);
313       Profile = &Profiles[tid];
314     }
315 
316     do {
317       StringRef Line(FdataLine);
318       CounterTy Count;
319       auto [Signature, ExecCount] = Line.rsplit(' ');
320       if (ExecCount.getAsInteger(10, Count.Exec))
321         report_error(Filename, "Malformed / corrupted execution count");
322       // Only LBR profile has misprediction field
323       if (!NoLBRCollection.value_or(false)) {
324         auto [SignatureLBR, MispredCount] = Signature.rsplit(' ');
325         Signature = SignatureLBR;
326         if (MispredCount.getAsInteger(10, Count.Mispred))
327           report_error(Filename, "Malformed / corrupted misprediction count");
328       }
329 
330       Count += Profile->lookup(Signature);
331       Profile->insert_or_assign(Signature, Count);
332     } while (std::getline(FdataFile, FdataLine));
333   };
334 
335   // The final reduction has non-trivial cost, make sure each thread has at
336   // least 4 tasks.
337   ThreadPoolStrategy S = optimal_concurrency(
338       std::max(Filenames.size() / 4, static_cast<size_t>(1)));
339   DefaultThreadPool Pool(S);
340   DenseMap<llvm::thread::id, ProfileTy> ParsedProfiles(
341       Pool.getMaxConcurrency());
342   for (const auto &Filename : Filenames)
343     Pool.async(ParseProfile, std::cref(Filename), std::ref(ParsedProfiles));
344   Pool.wait();
345 
346   ProfileTy MergedProfile;
347   for (const auto &[Thread, Profile] : ParsedProfiles)
348     for (const auto &[Key, Value] : Profile) {
349       CounterTy Count = MergedProfile.lookup(Key) + Value;
350       MergedProfile.insert_or_assign(Key, Count);
351     }
352 
353   if (BoltedCollection.value_or(false))
354     output() << "boltedcollection\n";
355   if (NoLBRCollection.value_or(false))
356     output() << "no_lbr\n";
357   for (const auto &[Key, Value] : MergedProfile) {
358     output() << Key << " ";
359     if (!NoLBRCollection.value_or(false))
360       output() << Value.Mispred << " ";
361     output() << Value.Exec << "\n";
362   }
363 
364   errs() << "Profile from " << Filenames.size() << " files merged.\n";
365 }
366 
367 } // anonymous namespace
368 
369 int main(int argc, char **argv) {
370   // Print a stack trace if we signal out.
371   sys::PrintStackTraceOnErrorSignal(argv[0]);
372   PrettyStackTraceProgram X(argc, argv);
373 
374   llvm_shutdown_obj Y; // Call llvm_shutdown() on exit.
375 
376   cl::HideUnrelatedOptions(opts::MergeFdataCategory);
377 
378   cl::ParseCommandLineOptions(argc, argv,
379                               "merge multiple fdata into a single file");
380 
381   ToolName = argv[0];
382 
383   // Recursively expand input directories into input file lists.
384   SmallVector<std::string> Inputs;
385   for (std::string &InputDataFilename : opts::InputDataFilenames) {
386     if (!llvm::sys::fs::exists(InputDataFilename))
387       report_error(InputDataFilename,
388                    std::make_error_code(std::errc::no_such_file_or_directory));
389     if (llvm::sys::fs::is_regular_file(InputDataFilename))
390       Inputs.emplace_back(InputDataFilename);
391     else if (llvm::sys::fs::is_directory(InputDataFilename)) {
392       std::error_code EC;
393       for (llvm::sys::fs::recursive_directory_iterator F(InputDataFilename, EC),
394            E;
395            F != E && !EC; F.increment(EC))
396         if (llvm::sys::fs::is_regular_file(F->path()))
397           Inputs.emplace_back(F->path());
398       if (EC)
399         report_error(InputDataFilename, EC);
400     }
401   }
402 
403   if (!isYAML(Inputs.front())) {
404     mergeLegacyProfiles(Inputs);
405     return 0;
406   }
407 
408   // Merged header.
409   BinaryProfileHeader MergedHeader;
410   MergedHeader.Version = 1;
411 
412   // Merged information for all functions.
413   StringMap<BinaryFunctionProfile> MergedBFs;
414 
415   bool FirstHeader = true;
416   for (std::string &InputDataFilename : Inputs) {
417     ErrorOr<std::unique_ptr<MemoryBuffer>> MB =
418         MemoryBuffer::getFileOrSTDIN(InputDataFilename);
419     if (std::error_code EC = MB.getError())
420       report_error(InputDataFilename, EC);
421     yaml::Input YamlInput(MB.get()->getBuffer());
422     YamlInput.setAllowUnknownKeys(true);
423 
424     errs() << "Merging data from " << InputDataFilename << "...\n";
425 
426     BinaryProfile BP;
427     YamlInput >> BP;
428     if (YamlInput.error())
429       report_error(InputDataFilename, YamlInput.error());
430 
431     // Sanity check.
432     if (BP.Header.Version != 1) {
433       errs() << "Unable to merge data from profile using version "
434              << BP.Header.Version << '\n';
435       exit(1);
436     }
437 
438     // Merge the header.
439     if (FirstHeader) {
440       MergedHeader = BP.Header;
441       FirstHeader = false;
442     } else {
443       mergeProfileHeaders(MergedHeader, BP.Header);
444     }
445 
446     // Do the function merge.
447     for (BinaryFunctionProfile &BF : BP.Functions) {
448       if (!MergedBFs.count(BF.Name)) {
449         MergedBFs.insert(std::make_pair(BF.Name, BF));
450         continue;
451       }
452 
453       BinaryFunctionProfile &MergedBF = MergedBFs.find(BF.Name)->second;
454       mergeFunctionProfile(MergedBF, std::move(BF));
455     }
456   }
457 
458   if (!opts::SuppressMergedDataOutput) {
459     yaml::Output YamlOut(output());
460 
461     BinaryProfile MergedProfile;
462     MergedProfile.Header = MergedHeader;
463     MergedProfile.Functions.resize(MergedBFs.size());
464     llvm::copy(llvm::make_second_range(MergedBFs),
465                MergedProfile.Functions.begin());
466 
467     // For consistency, sort functions by their IDs.
468     llvm::sort(MergedProfile.Functions,
469                [](const BinaryFunctionProfile &A,
470                   const BinaryFunctionProfile &B) { return A.Id < B.Id; });
471 
472     YamlOut << MergedProfile;
473   }
474 
475   errs() << "Data for " << MergedBFs.size()
476          << " unique objects successfully merged.\n";
477 
478   if (opts::PrintFunctionList != opts::ST_NONE) {
479     // List of function names with execution count.
480     std::vector<std::pair<uint64_t, StringRef>> FunctionList(MergedBFs.size());
481     using CountFuncType = std::function<std::pair<uint64_t, StringRef>(
482         const StringMapEntry<BinaryFunctionProfile> &)>;
483     CountFuncType ExecCountFunc =
484         [](const StringMapEntry<BinaryFunctionProfile> &V) {
485           return std::make_pair(V.second.ExecCount, StringRef(V.second.Name));
486         };
487     CountFuncType BranchCountFunc =
488         [](const StringMapEntry<BinaryFunctionProfile> &V) {
489           // Return total branch count.
490           uint64_t BranchCount = 0;
491           for (const BinaryBasicBlockProfile &BI : V.second.Blocks)
492             for (const SuccessorInfo &SI : BI.Successors)
493               BranchCount += SI.Count;
494           return std::make_pair(BranchCount, StringRef(V.second.Name));
495         };
496 
497     CountFuncType CountFunc = (opts::PrintFunctionList == opts::ST_EXEC_COUNT)
498                                   ? ExecCountFunc
499                                   : BranchCountFunc;
500     llvm::transform(MergedBFs, FunctionList.begin(), CountFunc);
501     llvm::stable_sort(reverse(FunctionList));
502     errs() << "Functions sorted by "
503            << (opts::PrintFunctionList == opts::ST_EXEC_COUNT ? "execution"
504                                                               : "total branch")
505            << " count:\n";
506     for (std::pair<uint64_t, StringRef> &FI : FunctionList)
507       errs() << FI.second << " : " << FI.first << '\n';
508   }
509 
510   return 0;
511 }
512