xref: /llvm-project/clang/tools/clang-nvlink-wrapper/ClangNVLinkWrapper.cpp (revision dd647e3e608ed0b2bac7c588d5859b80ef4a5976)
1 //===-- clang-nvlink-wrapper/ClangNVLinkWrapper.cpp - NVIDIA linker util --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This tool wraps around the NVIDIA linker called 'nvlink'. The NVIDIA linker
10 // is required to create NVPTX applications, but does not support common
11 // features like LTO or archives. This utility wraps around the tool to cover
12 // its deficiencies. This tool can be removed once NVIDIA improves their linker
13 // or ports it to `ld.lld`.
14 //
15 //===---------------------------------------------------------------------===//
16 
#include "clang/Basic/Version.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/LTO/LTO.h"
#include "llvm/Object/Archive.h"
#include "llvm/Object/ArchiveWriter.h"
#include "llvm/Object/Binary.h"
#include "llvm/Object/ELFObjectFile.h"
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Object/OffloadBinary.h"
#include "llvm/Option/ArgList.h"
#include "llvm/Option/OptTable.h"
#include "llvm/Option/Option.h"
#include "llvm/Remarks/HotnessThresholdParser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileOutputBuffer.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/WithColor.h"

#include <mutex>
47 
48 using namespace llvm;
49 using namespace llvm::opt;
50 using namespace llvm::object;
51 
// Various tools (e.g., llc and opt) duplicate this series of declarations for
// options related to passes and remarks.
//
// These are llvm::cl options, not wrapper (OptTable) options: main() feeds the
// '-mllvm' and plugin-opt values to cl::ParseCommandLineOptions, and
// createLTO() copies the resulting values into the lto::Config.

/// Include profile counts in optimization remarks (meaningful with PGO).
static cl::opt<bool> RemarksWithHotness(
    "pass-remarks-with-hotness",
    cl::desc("With PGO, include profile count in optimization remarks"),
    cl::Hidden);

/// Minimum profile count a remark needs to be emitted; accepts a number or
/// 'auto' (parsed by remarks::HotnessThresholdParser). Defaults to 0.
static cl::opt<std::optional<uint64_t>, false, remarks::HotnessThresholdParser>
    RemarksHotnessThreshold(
        "pass-remarks-hotness-threshold",
        cl::desc("Minimum profile count required for "
                 "an optimization remark to be output. "
                 "Use 'auto' to apply the threshold from profile summary."),
        cl::value_desc("N or 'auto'"), cl::init(0), cl::Hidden);

/// File that serialized optimization remarks are written to.
static cl::opt<std::string>
    RemarksFilename("pass-remarks-output",
                    cl::desc("Output filename for pass remarks"),
                    cl::value_desc("filename"));

/// Regular expression restricting which passes' remarks are recorded.
static cl::opt<std::string>
    RemarksPasses("pass-remarks-filter",
                  cl::desc("Only record optimization remarks from passes whose "
                           "names match the given regular expression"),
                  cl::value_desc("regex"));

/// Serialization format for remarks; defaults to "yaml".
static cl::opt<std::string> RemarksFormat(
    "pass-remarks-format",
    cl::desc("The format used for serializing remarks (default: YAML)"),
    cl::value_desc("format"), cl::init("yaml"));

/// Pass plugin libraries to load into the LTO pipeline (may be repeated).
static cl::list<std::string>
    PassPlugins("load-pass-plugin",
                cl::desc("Load passes from plugin library"));
86 
87 static void printVersion(raw_ostream &OS) {
88   OS << clang::getClangToolFullVersion("clang-nvlink-wrapper") << '\n';
89 }
90 
/// The value of `argv[0]` when run.
///
/// Set at the start of main() and used as the prefix of the error, warning
/// and note diagnostics printed by this tool.
static const char *Executable;

/// Temporary files to be cleaned up.
///
/// createTempFile() appends every intermediate file name here; main() deletes
/// them at exit unless OPT_save_temps was given. NOTE: this is a SmallVector,
/// so growing it may move existing elements and invalidate StringRefs that
/// point into them.
static SmallVector<SmallString<128>> TempFiles;

/// Codegen flags for LTO backend.
///
/// Presumably registers the standard codegen cl::opt flags on construction;
/// createLTO() reads them back via codegen::InitTargetOptionsFromCodeGenFlags
/// and codegen::getMAttrs.
static codegen::RegisterCodeGenFlags CodeGenFlags;
99 
100 namespace {
// Must not overlap with llvm::opt::DriverFlag.
// Options carrying this flag are consumed by the wrapper itself; runNVLink()
// does not forward them to 'nvlink'.
enum WrapperFlags { WrapperOnlyOption = (1 << 4) };

/// Option identifiers (OPT_*), one per OPTION record in NVLinkOpts.inc.
enum ID {
  OPT_INVALID = 0, // This is not an option ID.
#define OPTION(...) LLVM_MAKE_OPT_ID(__VA_ARGS__),
#include "NVLinkOpts.inc"
  LastOption
#undef OPTION
};

// Defines OptionStrTable, the string table with every option name.
#define OPTTABLE_STR_TABLE_CODE
#include "NVLinkOpts.inc"
#undef OPTTABLE_STR_TABLE_CODE

// Defines OptionPrefixesTable, the option prefixes (e.g. '-', '--').
#define OPTTABLE_PREFIXES_TABLE_CODE
#include "NVLinkOpts.inc"
#undef OPTTABLE_PREFIXES_TABLE_CODE

/// Per-option metadata consumed by WrapperOptTable.
static constexpr OptTable::Info InfoTable[] = {
#define OPTION(...) LLVM_CONSTRUCT_OPT_INFO(__VA_ARGS__),
#include "NVLinkOpts.inc"
#undef OPTION
};
125 
126 class WrapperOptTable : public opt::GenericOptTable {
127 public:
128   WrapperOptTable()
129       : opt::GenericOptTable(OptionStrTable, OptionPrefixesTable, InfoTable) {}
130 };
131 
132 const OptTable &getOptTable() {
133   static const WrapperOptTable *Table = []() {
134     auto Result = std::make_unique<WrapperOptTable>();
135     return Result.release();
136   }();
137   return *Table;
138 }
139 
140 [[noreturn]] void reportError(Error E) {
141   outs().flush();
142   logAllUnhandledErrors(std::move(E), WithColor::error(errs(), Executable));
143   exit(EXIT_FAILURE);
144 }
145 
146 void diagnosticHandler(const DiagnosticInfo &DI) {
147   std::string ErrStorage;
148   raw_string_ostream OS(ErrStorage);
149   DiagnosticPrinterRawOStream DP(OS);
150   DI.print(DP);
151 
152   switch (DI.getSeverity()) {
153   case DS_Error:
154     WithColor::error(errs(), Executable) << ErrStorage << "\n";
155     break;
156   case DS_Warning:
157     WithColor::warning(errs(), Executable) << ErrStorage << "\n";
158     break;
159   case DS_Note:
160     WithColor::note(errs(), Executable) << ErrStorage << "\n";
161     break;
162   case DS_Remark:
163     WithColor::remark(errs()) << ErrStorage << "\n";
164     break;
165   }
166 }
167 
168 Expected<StringRef> createTempFile(const ArgList &Args, const Twine &Prefix,
169                                    StringRef Extension) {
170   SmallString<128> OutputFile;
171   if (Args.hasArg(OPT_save_temps)) {
172     (Prefix + "." + Extension).toNullTerminatedStringRef(OutputFile);
173   } else {
174     if (std::error_code EC =
175             sys::fs::createTemporaryFile(Prefix, Extension, OutputFile))
176       return createFileError(OutputFile, EC);
177   }
178 
179   TempFiles.emplace_back(std::move(OutputFile));
180   return TempFiles.back();
181 }
182 
183 Expected<std::string> findProgram(const ArgList &Args, StringRef Name,
184                                   ArrayRef<StringRef> Paths) {
185   if (Args.hasArg(OPT_dry_run))
186     return Name.str();
187   ErrorOr<std::string> Path = sys::findProgramByName(Name, Paths);
188   if (!Path)
189     Path = sys::findProgramByName(Name);
190   if (!Path)
191     return createStringError(Path.getError(),
192                              "Unable to find '" + Name + "' in path");
193   return *Path;
194 }
195 
196 std::optional<std::string> findFile(StringRef Dir, StringRef Root,
197                                     const Twine &Name) {
198   SmallString<128> Path;
199   if (Dir.starts_with("="))
200     sys::path::append(Path, Root, Dir.substr(1), Name);
201   else
202     sys::path::append(Path, Dir, Name);
203 
204   if (sys::fs::exists(Path))
205     return static_cast<std::string>(Path);
206   return std::nullopt;
207 }
208 
209 std::optional<std::string>
210 findFromSearchPaths(StringRef Name, StringRef Root,
211                     ArrayRef<StringRef> SearchPaths) {
212   for (StringRef Dir : SearchPaths)
213     if (std::optional<std::string> File = findFile(Dir, Root, Name))
214       return File;
215   return std::nullopt;
216 }
217 
218 std::optional<std::string>
219 searchLibraryBaseName(StringRef Name, StringRef Root,
220                       ArrayRef<StringRef> SearchPaths) {
221   for (StringRef Dir : SearchPaths)
222     if (std::optional<std::string> File =
223             findFile(Dir, Root, "lib" + Name + ".a"))
224       return File;
225   return std::nullopt;
226 }
227 
228 /// Search for static libraries in the linker's library path given input like
229 /// `-lfoo` or `-l:libfoo.a`.
230 std::optional<std::string> searchLibrary(StringRef Input, StringRef Root,
231                                          ArrayRef<StringRef> SearchPaths) {
232   if (Input.starts_with(":"))
233     return findFromSearchPaths(Input.drop_front(), Root, SearchPaths);
234   return searchLibraryBaseName(Input, Root, SearchPaths);
235 }
236 
237 void printCommands(ArrayRef<StringRef> CmdArgs) {
238   if (CmdArgs.empty())
239     return;
240 
241   errs() << " \"" << CmdArgs.front() << "\" ";
242   errs() << join(std::next(CmdArgs.begin()), CmdArgs.end(), " ") << "\n";
243 }
244 
245 /// A minimum symbol interface that provides the necessary information to
246 /// extract archive members and resolve LTO symbols.
247 struct Symbol {
248   enum Flags {
249     None = 0,
250     Undefined = 1 << 0,
251     Weak = 1 << 1,
252   };
253 
254   Symbol() : File(), Flags(None), UsedInRegularObj(false) {}
255   Symbol(Symbol::Flags Flags) : File(), Flags(Flags), UsedInRegularObj(true) {}
256 
257   Symbol(MemoryBufferRef File, const irsymtab::Reader::SymbolRef Sym)
258       : File(File), Flags(0), UsedInRegularObj(false) {
259     if (Sym.isUndefined())
260       Flags |= Undefined;
261     if (Sym.isWeak())
262       Flags |= Weak;
263   }
264 
265   Symbol(MemoryBufferRef File, const SymbolRef Sym)
266       : File(File), Flags(0), UsedInRegularObj(false) {
267     auto FlagsOrErr = Sym.getFlags();
268     if (!FlagsOrErr)
269       reportError(FlagsOrErr.takeError());
270     if (*FlagsOrErr & SymbolRef::SF_Undefined)
271       Flags |= Undefined;
272     if (*FlagsOrErr & SymbolRef::SF_Weak)
273       Flags |= Weak;
274 
275     auto NameOrErr = Sym.getName();
276     if (!NameOrErr)
277       reportError(NameOrErr.takeError());
278   }
279 
280   bool isWeak() const { return Flags & Weak; }
281   bool isUndefined() const { return Flags & Undefined; }
282 
283   MemoryBufferRef File;
284   uint32_t Flags;
285   bool UsedInRegularObj;
286 };
287 
288 Expected<StringRef> runPTXAs(StringRef File, const ArgList &Args) {
289   std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str();
290   std::string GivenPath = Args.getLastArgValue(OPT_ptxas_path_EQ).str();
291   Expected<std::string> PTXAsPath =
292       findProgram(Args, "ptxas", {CudaPath + "/bin", GivenPath});
293   if (!PTXAsPath)
294     return PTXAsPath.takeError();
295   if (!Args.hasArg(OPT_arch))
296     return createStringError(
297         "must pass in an explicit nvptx64 gpu architecture to 'ptxas'");
298 
299   auto TempFileOrErr = createTempFile(
300       Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "cubin");
301   if (!TempFileOrErr)
302     return TempFileOrErr.takeError();
303 
304   SmallVector<StringRef> AssemblerArgs({*PTXAsPath, "-m64", "-c", File});
305   if (Args.hasArg(OPT_verbose))
306     AssemblerArgs.push_back("-v");
307   if (Args.hasArg(OPT_g)) {
308     if (Args.hasArg(OPT_O))
309       WithColor::warning(errs(), Executable)
310           << "Optimized debugging not supported, overriding to '-O0'\n";
311     AssemblerArgs.push_back("-O0");
312   } else
313     AssemblerArgs.push_back(
314         Args.MakeArgString("-O" + Args.getLastArgValue(OPT_O, "3")));
315   AssemblerArgs.append({"-arch", Args.getLastArgValue(OPT_arch)});
316   AssemblerArgs.append({"-o", *TempFileOrErr});
317 
318   if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose))
319     printCommands(AssemblerArgs);
320   if (Args.hasArg(OPT_dry_run))
321     return Args.MakeArgString(*TempFileOrErr);
322   if (sys::ExecuteAndWait(*PTXAsPath, AssemblerArgs))
323     return createStringError("'" + sys::path::filename(*PTXAsPath) + "'" +
324                              " failed");
325   return Args.MakeArgString(*TempFileOrErr);
326 }
327 
328 Expected<std::unique_ptr<lto::LTO>> createLTO(const ArgList &Args) {
329   const llvm::Triple Triple("nvptx64-nvidia-cuda");
330   lto::Config Conf;
331   lto::ThinBackend Backend;
332   unsigned Jobs = 0;
333   if (auto *Arg = Args.getLastArg(OPT_jobs))
334     if (!to_integer(Arg->getValue(), Jobs) || Jobs == 0)
335       reportError(createStringError("%s: expected a positive integer, got '%s'",
336                                     Arg->getSpelling().data(),
337                                     Arg->getValue()));
338   Backend =
339       lto::createInProcessThinBackend(heavyweight_hardware_concurrency(Jobs));
340 
341   Conf.CPU = Args.getLastArgValue(OPT_arch);
342   Conf.Options = codegen::InitTargetOptionsFromCodeGenFlags(Triple);
343 
344   Conf.RemarksFilename = RemarksFilename;
345   Conf.RemarksPasses = RemarksPasses;
346   Conf.RemarksWithHotness = RemarksWithHotness;
347   Conf.RemarksHotnessThreshold = RemarksHotnessThreshold;
348   Conf.RemarksFormat = RemarksFormat;
349 
350   Conf.MAttrs = llvm::codegen::getMAttrs();
351   std::optional<CodeGenOptLevel> CGOptLevelOrNone =
352       CodeGenOpt::parseLevel(Args.getLastArgValue(OPT_O, "2")[0]);
353   assert(CGOptLevelOrNone && "Invalid optimization level");
354   Conf.CGOptLevel = *CGOptLevelOrNone;
355   Conf.OptLevel = Args.getLastArgValue(OPT_O, "2")[0] - '0';
356   Conf.DefaultTriple = Triple.getTriple();
357 
358   Conf.OptPipeline = Args.getLastArgValue(OPT_lto_newpm_passes, "");
359   Conf.PassPlugins = PassPlugins;
360   Conf.DebugPassManager = Args.hasArg(OPT_lto_debug_pass_manager);
361 
362   Conf.DiagHandler = diagnosticHandler;
363   Conf.CGFileType = CodeGenFileType::AssemblyFile;
364 
365   if (Args.hasArg(OPT_lto_emit_llvm)) {
366     Conf.PreCodeGenModuleHook = [&](size_t, const Module &M) {
367       std::error_code EC;
368       raw_fd_ostream LinkedBitcode(Args.getLastArgValue(OPT_o, "a.out"), EC);
369       if (EC)
370         reportError(errorCodeToError(EC));
371       WriteBitcodeToFile(M, LinkedBitcode);
372       return false;
373     };
374   }
375 
376   if (Args.hasArg(OPT_save_temps))
377     if (Error Err = Conf.addSaveTemps(
378             (Args.getLastArgValue(OPT_o, "a.out") + ".").str()))
379       return Err;
380 
381   unsigned Partitions = 1;
382   if (auto *Arg = Args.getLastArg(OPT_lto_partitions))
383     if (!to_integer(Arg->getValue(), Partitions) || Partitions == 0)
384       reportError(createStringError("%s: expected a positive integer, got '%s'",
385                                     Arg->getSpelling().data(),
386                                     Arg->getValue()));
387   lto::LTO::LTOKind Kind = Args.hasArg(OPT_thinlto) ? lto::LTO::LTOK_UnifiedThin
388                                                     : lto::LTO::LTOK_Default;
389   return std::make_unique<lto::LTO>(std::move(Conf), Backend, Partitions, Kind);
390 }
391 
/// Adds the global symbols of the bitcode file \p Buffer to \p SymTab.
///
/// \p IsLazy marks an archive member that should only join the link if it
/// resolves an existing reference. For a lazy file, names not yet present in
/// \p SymTab are staged in a side table and committed only if the file ends
/// up extracted.
///
/// \returns true if the file should be part of the link (always true for
/// non-lazy inputs), or an error if the IR symbol table cannot be read.
Expected<bool> getSymbolsFromBitcode(MemoryBufferRef Buffer,
                                     StringMap<Symbol> &SymTab, bool IsLazy) {
  Expected<IRSymtabFile> IRSymtabOrErr = readIRSymtab(Buffer);
  if (!IRSymtabOrErr)
    return IRSymtabOrErr.takeError();
  // Non-lazy inputs are always linked.
  bool Extracted = !IsLazy;
  StringMap<Symbol> PendingSymbols;
  for (unsigned I = 0; I != IRSymtabOrErr->Mods.size(); ++I) {
    for (const auto &IRSym : IRSymtabOrErr->TheReader.module_symbols(I)) {
      // Only global, non-format-specific symbols take part in resolution.
      if (IRSym.isFormatSpecific() || !IRSym.isGlobal())
        continue;

      // New names from a lazy file go to the staging table so they do not
      // enter SymTab unless this file is extracted.
      Symbol &OldSym = !SymTab.count(IRSym.getName()) && IsLazy
                           ? PendingSymbols[IRSym.getName()]
                           : SymTab[IRSym.getName()];
      Symbol Sym = Symbol(Buffer, IRSym);
      // First sighting of this name: adopt this file's symbol.
      if (OldSym.File.getBuffer().empty())
        OldSym = Sym;

      // A definition resolves the entry if the entry is undefined, or if a
      // non-weak definition replaces a weak entry. A lazy file is never
      // extracted just to satisfy a weak undefined reference.
      bool ResolvesReference =
          !Sym.isUndefined() &&
          (OldSym.isUndefined() || (OldSym.isWeak() && !Sym.isWeak())) &&
          !(OldSym.isWeak() && OldSym.isUndefined() && IsLazy);
      Extracted |= ResolvesReference;

      // Keep the "used by a regular object" bit across the replacement.
      Sym.UsedInRegularObj = OldSym.UsedInRegularObj;
      if (ResolvesReference)
        OldSym = Sym;
    }
  }
  // Commit the staged names only if this file joins the link.
  if (Extracted)
    for (const auto &[Name, Symbol] : PendingSymbols)
      SymTab[Name] = Symbol;
  return Extracted;
}
427 
/// Adds the symbols of the regular object file \p ObjFile to \p SymTab.
///
/// Like getSymbolsFromBitcode(), a lazy (archive) member's previously unseen
/// names are staged and committed only if the member is extracted. Every
/// symbol seen here is marked as used by a regular object.
///
/// NOTE(review): unlike the bitcode path, symbols are not filtered by binding
/// (local symbols are entered too). Also, a non-weak definition does not
/// replace an existing weak definition here, only an undefined entry.
///
/// \returns true if the file should be part of the link (always true for
/// non-lazy inputs), or an error if a symbol name cannot be read.
Expected<bool> getSymbolsFromObject(ObjectFile &ObjFile,
                                    StringMap<Symbol> &SymTab, bool IsLazy) {
  // Non-lazy inputs are always linked.
  bool Extracted = !IsLazy;
  StringMap<Symbol> PendingSymbols;
  for (SymbolRef ObjSym : ObjFile.symbols()) {
    auto NameOrErr = ObjSym.getName();
    if (!NameOrErr)
      return NameOrErr.takeError();

    // New names from a lazy file go to the staging table so they do not
    // enter SymTab unless this file is extracted.
    Symbol &OldSym = !SymTab.count(*NameOrErr) && IsLazy
                         ? PendingSymbols[*NameOrErr]
                         : SymTab[*NameOrErr];
    Symbol Sym = Symbol(ObjFile.getMemoryBufferRef(), ObjSym);
    // First sighting of this name: adopt this file's symbol.
    if (OldSym.File.getBuffer().empty())
      OldSym = Sym;

    // A definition resolves an undefined entry, except that a lazy member is
    // not extracted just to satisfy a weak undefined reference.
    bool ResolvesReference = OldSym.isUndefined() && !Sym.isUndefined() &&
                             (!OldSym.isWeak() || !IsLazy);
    Extracted |= ResolvesReference;

    if (ResolvesReference)
      OldSym = Sym;
    // Record the regular-object use; LTO keeps such symbols visible.
    OldSym.UsedInRegularObj = true;
  }
  // Commit the staged names only if this file joins the link.
  if (Extracted)
    for (const auto &[Name, Symbol] : PendingSymbols)
      SymTab[Name] = Symbol;
  return Extracted;
}
457 
458 Expected<bool> getSymbols(MemoryBufferRef Buffer, StringMap<Symbol> &SymTab,
459                           bool IsLazy) {
460   switch (identify_magic(Buffer.getBuffer())) {
461   case file_magic::bitcode: {
462     return getSymbolsFromBitcode(Buffer, SymTab, IsLazy);
463   }
464   case file_magic::elf_relocatable: {
465     Expected<std::unique_ptr<ObjectFile>> ObjFile =
466         ObjectFile::createObjectFile(Buffer);
467     if (!ObjFile)
468       return ObjFile.takeError();
469     return getSymbolsFromObject(**ObjFile, SymTab, IsLazy);
470   }
471   default:
472     return createStringError("Unsupported file type");
473   }
474 }
475 
476 Expected<SmallVector<StringRef>> getInput(const ArgList &Args) {
477   SmallVector<StringRef> LibraryPaths;
478   for (const opt::Arg *Arg : Args.filtered(OPT_library_path))
479     LibraryPaths.push_back(Arg->getValue());
480 
481   bool WholeArchive = false;
482   SmallVector<std::pair<std::unique_ptr<MemoryBuffer>, bool>> InputFiles;
483   for (const opt::Arg *Arg : Args.filtered(
484            OPT_INPUT, OPT_library, OPT_whole_archive, OPT_no_whole_archive)) {
485     if (Arg->getOption().matches(OPT_whole_archive) ||
486         Arg->getOption().matches(OPT_no_whole_archive)) {
487       WholeArchive = Arg->getOption().matches(OPT_whole_archive);
488       continue;
489     }
490 
491     std::optional<std::string> Filename =
492         Arg->getOption().matches(OPT_library)
493             ? searchLibrary(Arg->getValue(), /*Root=*/"", LibraryPaths)
494             : std::string(Arg->getValue());
495 
496     if (!Filename && Arg->getOption().matches(OPT_library))
497       return createStringError("unable to find library -l%s", Arg->getValue());
498 
499     if (!Filename || !sys::fs::exists(*Filename) ||
500         sys::fs::is_directory(*Filename))
501       continue;
502 
503     ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr =
504         MemoryBuffer::getFileOrSTDIN(*Filename);
505     if (std::error_code EC = BufferOrErr.getError())
506       return createFileError(*Filename, EC);
507 
508     MemoryBufferRef Buffer = **BufferOrErr;
509     switch (identify_magic(Buffer.getBuffer())) {
510     case file_magic::bitcode:
511     case file_magic::elf_relocatable:
512       InputFiles.emplace_back(std::move(*BufferOrErr), /*IsLazy=*/false);
513       break;
514     case file_magic::archive: {
515       Expected<std::unique_ptr<object::Archive>> LibFile =
516           object::Archive::create(Buffer);
517       if (!LibFile)
518         return LibFile.takeError();
519       Error Err = Error::success();
520       for (auto Child : (*LibFile)->children(Err)) {
521         auto ChildBufferOrErr = Child.getMemoryBufferRef();
522         if (!ChildBufferOrErr)
523           return ChildBufferOrErr.takeError();
524         std::unique_ptr<MemoryBuffer> ChildBuffer =
525             MemoryBuffer::getMemBufferCopy(
526                 ChildBufferOrErr->getBuffer(),
527                 ChildBufferOrErr->getBufferIdentifier());
528         InputFiles.emplace_back(std::move(ChildBuffer), !WholeArchive);
529       }
530       if (Err)
531         return Err;
532       break;
533     }
534     default:
535       return createStringError("Unsupported file type");
536     }
537   }
538 
539   bool Extracted = true;
540   StringMap<Symbol> SymTab;
541   for (auto &Sym : Args.getAllArgValues(OPT_u))
542     SymTab[Sym] = Symbol(Symbol::Undefined);
543   SmallVector<std::unique_ptr<MemoryBuffer>> LinkerInput;
544   while (Extracted) {
545     Extracted = false;
546     for (auto &[Input, IsLazy] : InputFiles) {
547       if (!Input)
548         continue;
549 
550       // Archive members only extract if they define needed symbols. We will
551       // re-scan all the inputs if any files were extracted for the link job.
552       Expected<bool> ExtractOrErr = getSymbols(*Input, SymTab, IsLazy);
553       if (!ExtractOrErr)
554         return ExtractOrErr.takeError();
555 
556       Extracted |= *ExtractOrErr;
557       if (!*ExtractOrErr)
558         continue;
559 
560       LinkerInput.emplace_back(std::move(Input));
561     }
562   }
563   InputFiles.clear();
564 
565   // Extract any bitcode files to be passed to the LTO pipeline.
566   SmallVector<std::unique_ptr<MemoryBuffer>> BitcodeFiles;
567   for (auto &Input : LinkerInput)
568     if (identify_magic(Input->getBuffer()) == file_magic::bitcode)
569       BitcodeFiles.emplace_back(std::move(Input));
570   erase_if(LinkerInput, [](const auto &F) { return !F; });
571 
572   // Run the LTO pipeline on the extracted inputs.
573   SmallVector<StringRef> Files;
574   if (!BitcodeFiles.empty()) {
575     auto LTOBackendOrErr = createLTO(Args);
576     if (!LTOBackendOrErr)
577       return LTOBackendOrErr.takeError();
578     lto::LTO &LTOBackend = **LTOBackendOrErr;
579     for (auto &BitcodeFile : BitcodeFiles) {
580       Expected<std::unique_ptr<lto::InputFile>> BitcodeFileOrErr =
581           lto::InputFile::create(*BitcodeFile);
582       if (!BitcodeFileOrErr)
583         return BitcodeFileOrErr.takeError();
584 
585       const auto Symbols = (*BitcodeFileOrErr)->symbols();
586       SmallVector<lto::SymbolResolution, 16> Resolutions(Symbols.size());
587       size_t Idx = 0;
588       for (auto &Sym : Symbols) {
589         lto::SymbolResolution &Res = Resolutions[Idx++];
590         Symbol ObjSym = SymTab[Sym.getName()];
591         // We will use this as the prevailing symbol in LTO if it is not
592         // undefined and it is from the file that contained the canonical
593         // definition.
594         Res.Prevailing = !Sym.isUndefined() && ObjSym.File == *BitcodeFile;
595 
596         // We need LTO to preseve the following global symbols:
597         // 1) All symbols during a relocatable link.
598         // 2) Symbols used in regular objects.
599         // 3) Prevailing symbols that are needed visible to the gpu runtime.
600         Res.VisibleToRegularObj =
601             Args.hasArg(OPT_relocatable) || ObjSym.UsedInRegularObj ||
602             (Res.Prevailing &&
603              (Sym.getVisibility() != GlobalValue::HiddenVisibility &&
604               !Sym.canBeOmittedFromSymbolTable()));
605 
606         // Identify symbols that must be exported dynamically and can be
607         // referenced by other files, (i.e. the runtime).
608         Res.ExportDynamic =
609             Sym.getVisibility() != GlobalValue::HiddenVisibility &&
610             !Sym.canBeOmittedFromSymbolTable();
611 
612         // The NVIDIA platform does not support any symbol preemption.
613         Res.FinalDefinitionInLinkageUnit = true;
614 
615         // We do not support linker redefined symbols (e.g. --wrap) for device
616         // image linking, so the symbols will not be changed after LTO.
617         Res.LinkerRedefined = false;
618       }
619 
620       // Add the bitcode file with its resolved symbols to the LTO job.
621       if (Error Err = LTOBackend.add(std::move(*BitcodeFileOrErr), Resolutions))
622         return Err;
623     }
624 
625     // Run the LTO job to compile the bitcode.
626     size_t MaxTasks = LTOBackend.getMaxTasks();
627     SmallVector<StringRef> LTOFiles(MaxTasks);
628     auto AddStream =
629         [&](size_t Task,
630             const Twine &ModuleName) -> std::unique_ptr<CachedFileStream> {
631       int FD = -1;
632       auto &TempFile = LTOFiles[Task];
633       if (Args.hasArg(OPT_lto_emit_asm))
634         TempFile = Args.getLastArgValue(OPT_o, "a.out");
635       else {
636         auto TempFileOrErr = createTempFile(
637             Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "s");
638         if (!TempFileOrErr)
639           reportError(TempFileOrErr.takeError());
640         TempFile = Args.MakeArgString(*TempFileOrErr);
641       }
642       if (std::error_code EC = sys::fs::openFileForWrite(TempFile, FD))
643         reportError(errorCodeToError(EC));
644       return std::make_unique<CachedFileStream>(
645           std::make_unique<raw_fd_ostream>(FD, true));
646     };
647 
648     if (Error Err = LTOBackend.run(AddStream))
649       return Err;
650 
651     if (Args.hasArg(OPT_lto_emit_llvm) || Args.hasArg(OPT_lto_emit_asm))
652       return Files;
653 
654     for (StringRef LTOFile : LTOFiles) {
655       auto FileOrErr = runPTXAs(LTOFile, Args);
656       if (!FileOrErr)
657         return FileOrErr.takeError();
658       Files.emplace_back(*FileOrErr);
659     }
660   }
661 
662   // Create a copy for each file to a new file ending in `.cubin`. The 'nvlink'
663   // linker requires all NVPTX inputs to have this extension for some reason.
664   // We don't use a symbolic link because it's not supported on Windows and some
665   // of this input files could be extracted from an archive.
666   for (auto &Input : LinkerInput) {
667     auto TempFileOrErr = createTempFile(
668         Args, sys::path::stem(Input->getBufferIdentifier()), "cubin");
669     if (!TempFileOrErr)
670       return TempFileOrErr.takeError();
671     Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr =
672         FileOutputBuffer::create(*TempFileOrErr, Input->getBuffer().size());
673     if (!OutputOrErr)
674       return OutputOrErr.takeError();
675     std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr);
676     copy(Input->getBuffer(), Output->getBufferStart());
677     if (Error E = Output->commit())
678       return E;
679     Files.emplace_back(Args.MakeArgString(*TempFileOrErr));
680   }
681 
682   return Files;
683 }
684 
685 Error runNVLink(ArrayRef<StringRef> Files, const ArgList &Args) {
686   if (Args.hasArg(OPT_lto_emit_asm) || Args.hasArg(OPT_lto_emit_llvm))
687     return Error::success();
688 
689   std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str();
690   Expected<std::string> NVLinkPath =
691       findProgram(Args, "nvlink", {CudaPath + "/bin"});
692   if (!NVLinkPath)
693     return NVLinkPath.takeError();
694 
695   if (!Args.hasArg(OPT_arch))
696     return createStringError(
697         "must pass in an explicit nvptx64 gpu architecture to 'nvlink'");
698 
699   ArgStringList NewLinkerArgs;
700   for (const opt::Arg *Arg : Args) {
701     // Do not forward arguments only intended for the linker wrapper.
702     if (Arg->getOption().hasFlag(WrapperOnlyOption))
703       continue;
704 
705     // Do not forward any inputs that we have processed.
706     if (Arg->getOption().matches(OPT_INPUT) ||
707         Arg->getOption().matches(OPT_library))
708       continue;
709 
710     Arg->render(Args, NewLinkerArgs);
711   }
712 
713   transform(Files, std::back_inserter(NewLinkerArgs),
714             [&](StringRef Arg) { return Args.MakeArgString(Arg); });
715 
716   SmallVector<StringRef> LinkerArgs({*NVLinkPath});
717   if (!Args.hasArg(OPT_o))
718     LinkerArgs.append({"-o", "a.out"});
719   for (StringRef Arg : NewLinkerArgs)
720     LinkerArgs.push_back(Arg);
721 
722   if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose))
723     printCommands(LinkerArgs);
724   if (Args.hasArg(OPT_dry_run))
725     return Error::success();
726   if (sys::ExecuteAndWait(*NVLinkPath, LinkerArgs))
727     return createStringError("'" + sys::path::filename(*NVLinkPath) + "'" +
728                              " failed");
729   return Error::success();
730 }
731 
732 } // namespace
733 
734 int main(int argc, char **argv) {
735   InitLLVM X(argc, argv);
736   InitializeAllTargetInfos();
737   InitializeAllTargets();
738   InitializeAllTargetMCs();
739   InitializeAllAsmParsers();
740   InitializeAllAsmPrinters();
741 
742   Executable = argv[0];
743   sys::PrintStackTraceOnErrorSignal(argv[0]);
744 
745   const OptTable &Tbl = getOptTable();
746   BumpPtrAllocator Alloc;
747   StringSaver Saver(Alloc);
748   auto Args = Tbl.parseArgs(argc, argv, OPT_INVALID, Saver, [&](StringRef Err) {
749     reportError(createStringError(inconvertibleErrorCode(), Err));
750   });
751 
752   if (Args.hasArg(OPT_help) || Args.hasArg(OPT_help_hidden)) {
753     Tbl.printHelp(
754         outs(), "clang-nvlink-wrapper [options] <options to passed to nvlink>",
755         "A utility that wraps around the NVIDIA 'nvlink' linker.\n"
756         "This enables static linking and LTO handling for NVPTX targets.",
757         Args.hasArg(OPT_help_hidden), Args.hasArg(OPT_help_hidden));
758     return EXIT_SUCCESS;
759   }
760 
761   if (Args.hasArg(OPT_version))
762     printVersion(outs());
763 
764   // This forwards '-mllvm' arguments to LLVM if present.
765   SmallVector<const char *> NewArgv = {argv[0]};
766   for (const opt::Arg *Arg : Args.filtered(OPT_mllvm))
767     NewArgv.push_back(Arg->getValue());
768   for (const opt::Arg *Arg : Args.filtered(OPT_plugin_opt))
769     NewArgv.push_back(Arg->getValue());
770   cl::ParseCommandLineOptions(NewArgv.size(), &NewArgv[0]);
771 
772   // Get the input files to pass to 'nvlink'.
773   auto FilesOrErr = getInput(Args);
774   if (!FilesOrErr)
775     reportError(FilesOrErr.takeError());
776 
777   // Run 'nvlink' on the generated inputs.
778   if (Error Err = runNVLink(*FilesOrErr, Args))
779     reportError(std::move(Err));
780 
781   // Remove the temporary files created.
782   if (!Args.hasArg(OPT_save_temps))
783     for (const auto &TempFile : TempFiles)
784       if (std::error_code EC = sys::fs::remove(TempFile))
785         reportError(createFileError(TempFile, EC));
786 
787   return EXIT_SUCCESS;
788 }
789