xref: /llvm-project/llvm/lib/Object/OffloadBinary.cpp (revision e9c8106a90d49e75bac87341ade57c6049357a97)
1 //===- Offloading.cpp - Utilities for handling offloading code  -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Object/OffloadBinary.h"
10 
11 #include "llvm/ADT/StringSwitch.h"
12 #include "llvm/BinaryFormat/Magic.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/IRReader/IRReader.h"
16 #include "llvm/MC/StringTableBuilder.h"
17 #include "llvm/Object/Archive.h"
18 #include "llvm/Object/Binary.h"
19 #include "llvm/Object/ELFObjectFile.h"
20 #include "llvm/Object/Error.h"
21 #include "llvm/Object/IRObjectFile.h"
22 #include "llvm/Object/ObjectFile.h"
23 #include "llvm/Support/Alignment.h"
24 #include "llvm/Support/SourceMgr.h"
25 
26 using namespace llvm;
27 using namespace llvm::object;
28 
29 namespace {
30 
31 /// Attempts to extract all the embedded device images contained inside the
32 /// buffer \p Contents. The buffer is expected to contain a valid offloading
33 /// binary format.
34 Error extractOffloadFiles(MemoryBufferRef Contents,
35                           SmallVectorImpl<OffloadFile> &Binaries) {
36   uint64_t Offset = 0;
37   // There could be multiple offloading binaries stored at this section.
38   while (Offset < Contents.getBuffer().size()) {
39     std::unique_ptr<MemoryBuffer> Buffer =
40         MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
41                                    /*RequiresNullTerminator*/ false);
42     if (!isAddrAligned(Align(OffloadBinary::getAlignment()),
43                        Buffer->getBufferStart()))
44       Buffer = MemoryBuffer::getMemBufferCopy(Buffer->getBuffer(),
45                                               Buffer->getBufferIdentifier());
46     auto BinaryOrErr = OffloadBinary::create(*Buffer);
47     if (!BinaryOrErr)
48       return BinaryOrErr.takeError();
49     OffloadBinary &Binary = **BinaryOrErr;
50 
51     // Create a new owned binary with a copy of the original memory.
52     std::unique_ptr<MemoryBuffer> BufferCopy = MemoryBuffer::getMemBufferCopy(
53         Binary.getData().take_front(Binary.getSize()),
54         Contents.getBufferIdentifier());
55     auto NewBinaryOrErr = OffloadBinary::create(*BufferCopy);
56     if (!NewBinaryOrErr)
57       return NewBinaryOrErr.takeError();
58     Binaries.emplace_back(std::move(*NewBinaryOrErr), std::move(BufferCopy));
59 
60     Offset += Binary.getSize();
61   }
62 
63   return Error::success();
64 }
65 
66 // Extract offloading binaries from an Object file \p Obj.
67 Error extractFromObject(const ObjectFile &Obj,
68                         SmallVectorImpl<OffloadFile> &Binaries) {
69   assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");
70 
71   for (SectionRef Sec : Obj.sections()) {
72     // ELF files contain a section with the LLVM_OFFLOADING type.
73     if (Obj.isELF() &&
74         static_cast<ELFSectionRef>(Sec).getType() != ELF::SHT_LLVM_OFFLOADING)
75       continue;
76 
77     // COFF has no section types so we rely on the name of the section.
78     if (Obj.isCOFF()) {
79       Expected<StringRef> NameOrErr = Sec.getName();
80       if (!NameOrErr)
81         return NameOrErr.takeError();
82 
83       if (!NameOrErr->starts_with(".llvm.offloading"))
84         continue;
85     }
86 
87     Expected<StringRef> Buffer = Sec.getContents();
88     if (!Buffer)
89       return Buffer.takeError();
90 
91     MemoryBufferRef Contents(*Buffer, Obj.getFileName());
92     if (Error Err = extractOffloadFiles(Contents, Binaries))
93       return Err;
94   }
95 
96   return Error::success();
97 }
98 
99 Error extractFromBitcode(MemoryBufferRef Buffer,
100                          SmallVectorImpl<OffloadFile> &Binaries) {
101   LLVMContext Context;
102   SMDiagnostic Err;
103   std::unique_ptr<Module> M = getLazyIRModule(
104       MemoryBuffer::getMemBuffer(Buffer, /*RequiresNullTerminator=*/false), Err,
105       Context);
106   if (!M)
107     return createStringError(inconvertibleErrorCode(),
108                              "Failed to create module");
109 
110   // Extract offloading data from globals referenced by the
111   // `llvm.embedded.object` metadata with the `.llvm.offloading` section.
112   auto *MD = M->getNamedMetadata("llvm.embedded.objects");
113   if (!MD)
114     return Error::success();
115 
116   for (const MDNode *Op : MD->operands()) {
117     if (Op->getNumOperands() < 2)
118       continue;
119 
120     MDString *SectionID = dyn_cast<MDString>(Op->getOperand(1));
121     if (!SectionID || SectionID->getString() != ".llvm.offloading")
122       continue;
123 
124     GlobalVariable *GV =
125         mdconst::dyn_extract_or_null<GlobalVariable>(Op->getOperand(0));
126     if (!GV)
127       continue;
128 
129     auto *CDS = dyn_cast<ConstantDataSequential>(GV->getInitializer());
130     if (!CDS)
131       continue;
132 
133     MemoryBufferRef Contents(CDS->getAsString(), M->getName());
134     if (Error Err = extractOffloadFiles(Contents, Binaries))
135       return Err;
136   }
137 
138   return Error::success();
139 }
140 
141 Error extractFromArchive(const Archive &Library,
142                          SmallVectorImpl<OffloadFile> &Binaries) {
143   // Try to extract device code from each file stored in the static archive.
144   Error Err = Error::success();
145   for (auto Child : Library.children(Err)) {
146     auto ChildBufferOrErr = Child.getMemoryBufferRef();
147     if (!ChildBufferOrErr)
148       return ChildBufferOrErr.takeError();
149     std::unique_ptr<MemoryBuffer> ChildBuffer =
150         MemoryBuffer::getMemBuffer(*ChildBufferOrErr, false);
151 
152     // Check if the buffer has the required alignment.
153     if (!isAddrAligned(Align(OffloadBinary::getAlignment()),
154                        ChildBuffer->getBufferStart()))
155       ChildBuffer = MemoryBuffer::getMemBufferCopy(
156           ChildBufferOrErr->getBuffer(),
157           ChildBufferOrErr->getBufferIdentifier());
158 
159     if (Error Err = extractOffloadBinaries(*ChildBuffer, Binaries))
160       return Err;
161   }
162 
163   if (Err)
164     return Err;
165   return Error::success();
166 }
167 
168 } // namespace
169 
170 Expected<std::unique_ptr<OffloadBinary>>
171 OffloadBinary::create(MemoryBufferRef Buf) {
172   if (Buf.getBufferSize() < sizeof(Header) + sizeof(Entry))
173     return errorCodeToError(object_error::parse_failed);
174 
175   // Check for 0x10FF1OAD magic bytes.
176   if (identify_magic(Buf.getBuffer()) != file_magic::offload_binary)
177     return errorCodeToError(object_error::parse_failed);
178 
179   // Make sure that the data has sufficient alignment.
180   if (!isAddrAligned(Align(getAlignment()), Buf.getBufferStart()))
181     return errorCodeToError(object_error::parse_failed);
182 
183   const char *Start = Buf.getBufferStart();
184   const Header *TheHeader = reinterpret_cast<const Header *>(Start);
185   if (TheHeader->Version != OffloadBinary::Version)
186     return errorCodeToError(object_error::parse_failed);
187 
188   if (TheHeader->Size > Buf.getBufferSize() ||
189       TheHeader->Size < sizeof(Entry) || TheHeader->Size < sizeof(Header))
190     return errorCodeToError(object_error::unexpected_eof);
191 
192   if (TheHeader->EntryOffset > TheHeader->Size - sizeof(Entry) ||
193       TheHeader->EntrySize > TheHeader->Size - sizeof(Header))
194     return errorCodeToError(object_error::unexpected_eof);
195 
196   const Entry *TheEntry =
197       reinterpret_cast<const Entry *>(&Start[TheHeader->EntryOffset]);
198 
199   if (TheEntry->ImageOffset > Buf.getBufferSize() ||
200       TheEntry->StringOffset > Buf.getBufferSize())
201     return errorCodeToError(object_error::unexpected_eof);
202 
203   return std::unique_ptr<OffloadBinary>(
204       new OffloadBinary(Buf, TheHeader, TheEntry));
205 }
206 
207 SmallString<0> OffloadBinary::write(const OffloadingImage &OffloadingData) {
208   // Create a null-terminated string table with all the used strings.
209   StringTableBuilder StrTab(StringTableBuilder::ELF);
210   for (auto &KeyAndValue : OffloadingData.StringData) {
211     StrTab.add(KeyAndValue.first);
212     StrTab.add(KeyAndValue.second);
213   }
214   StrTab.finalize();
215 
216   uint64_t StringEntrySize =
217       sizeof(StringEntry) * OffloadingData.StringData.size();
218 
219   // Make sure the image we're wrapping around is aligned as well.
220   uint64_t BinaryDataSize = alignTo(sizeof(Header) + sizeof(Entry) +
221                                         StringEntrySize + StrTab.getSize(),
222                                     getAlignment());
223 
224   // Create the header and fill in the offsets. The entry will be directly
225   // placed after the header in memory. Align the size to the alignment of the
226   // header so this can be placed contiguously in a single section.
227   Header TheHeader;
228   TheHeader.Size = alignTo(
229       BinaryDataSize + OffloadingData.Image->getBufferSize(), getAlignment());
230   TheHeader.EntryOffset = sizeof(Header);
231   TheHeader.EntrySize = sizeof(Entry);
232 
233   // Create the entry using the string table offsets. The string table will be
234   // placed directly after the entry in memory, and the image after that.
235   Entry TheEntry;
236   TheEntry.TheImageKind = OffloadingData.TheImageKind;
237   TheEntry.TheOffloadKind = OffloadingData.TheOffloadKind;
238   TheEntry.Flags = OffloadingData.Flags;
239   TheEntry.StringOffset = sizeof(Header) + sizeof(Entry);
240   TheEntry.NumStrings = OffloadingData.StringData.size();
241 
242   TheEntry.ImageOffset = BinaryDataSize;
243   TheEntry.ImageSize = OffloadingData.Image->getBufferSize();
244 
245   SmallString<0> Data;
246   Data.reserve(TheHeader.Size);
247   raw_svector_ostream OS(Data);
248   OS << StringRef(reinterpret_cast<char *>(&TheHeader), sizeof(Header));
249   OS << StringRef(reinterpret_cast<char *>(&TheEntry), sizeof(Entry));
250   for (auto &KeyAndValue : OffloadingData.StringData) {
251     uint64_t Offset = sizeof(Header) + sizeof(Entry) + StringEntrySize;
252     StringEntry Map{Offset + StrTab.getOffset(KeyAndValue.first),
253                     Offset + StrTab.getOffset(KeyAndValue.second)};
254     OS << StringRef(reinterpret_cast<char *>(&Map), sizeof(StringEntry));
255   }
256   StrTab.write(OS);
257   // Add padding to required image alignment.
258   OS.write_zeros(TheEntry.ImageOffset - OS.tell());
259   OS << OffloadingData.Image->getBuffer();
260 
261   // Add final padding to required alignment.
262   assert(TheHeader.Size >= OS.tell() && "Too much data written?");
263   OS.write_zeros(TheHeader.Size - OS.tell());
264   assert(TheHeader.Size == OS.tell() && "Size mismatch");
265 
266   return Data;
267 }
268 
269 Error object::extractOffloadBinaries(MemoryBufferRef Buffer,
270                                      SmallVectorImpl<OffloadFile> &Binaries) {
271   file_magic Type = identify_magic(Buffer.getBuffer());
272   switch (Type) {
273   case file_magic::bitcode:
274     return extractFromBitcode(Buffer, Binaries);
275   case file_magic::elf_relocatable:
276   case file_magic::elf_executable:
277   case file_magic::elf_shared_object:
278   case file_magic::coff_object: {
279     Expected<std::unique_ptr<ObjectFile>> ObjFile =
280         ObjectFile::createObjectFile(Buffer, Type);
281     if (!ObjFile)
282       return ObjFile.takeError();
283     return extractFromObject(*ObjFile->get(), Binaries);
284   }
285   case file_magic::archive: {
286     Expected<std::unique_ptr<llvm::object::Archive>> LibFile =
287         object::Archive::create(Buffer);
288     if (!LibFile)
289       return LibFile.takeError();
290     return extractFromArchive(*LibFile->get(), Binaries);
291   }
292   case file_magic::offload_binary:
293     return extractOffloadFiles(Buffer, Binaries);
294   default:
295     return Error::success();
296   }
297 }
298 
299 OffloadKind object::getOffloadKind(StringRef Name) {
300   return llvm::StringSwitch<OffloadKind>(Name)
301       .Case("openmp", OFK_OpenMP)
302       .Case("cuda", OFK_Cuda)
303       .Case("hip", OFK_HIP)
304       .Default(OFK_None);
305 }
306 
307 StringRef object::getOffloadKindName(OffloadKind Kind) {
308   switch (Kind) {
309   case OFK_OpenMP:
310     return "openmp";
311   case OFK_Cuda:
312     return "cuda";
313   case OFK_HIP:
314     return "hip";
315   default:
316     return "none";
317   }
318 }
319 
320 ImageKind object::getImageKind(StringRef Name) {
321   return llvm::StringSwitch<ImageKind>(Name)
322       .Case("o", IMG_Object)
323       .Case("bc", IMG_Bitcode)
324       .Case("cubin", IMG_Cubin)
325       .Case("fatbin", IMG_Fatbinary)
326       .Case("s", IMG_PTX)
327       .Default(IMG_None);
328 }
329 
330 StringRef object::getImageKindName(ImageKind Kind) {
331   switch (Kind) {
332   case IMG_Object:
333     return "o";
334   case IMG_Bitcode:
335     return "bc";
336   case IMG_Cubin:
337     return "cubin";
338   case IMG_Fatbinary:
339     return "fatbin";
340   case IMG_PTX:
341     return "s";
342   default:
343     return "";
344   }
345 }
346 
347 bool object::areTargetsCompatible(const OffloadFile::TargetID &LHS,
348                                   const OffloadFile::TargetID &RHS) {
349   // Exact matches are not considered compatible because they are the same
350   // target. We are interested in different targets that are compatible.
351   if (LHS == RHS)
352     return false;
353 
354   // The triples must match at all times.
355   if (LHS.first != RHS.first)
356     return false;
357 
358   // If the architecture is "all" we assume it is always compatible.
359   if (LHS.second == "generic" || RHS.second == "generic")
360     return true;
361 
362   // Only The AMDGPU target requires additional checks.
363   llvm::Triple T(LHS.first);
364   if (!T.isAMDGPU())
365     return false;
366 
367   // The base processor must always match.
368   if (LHS.second.split(":").first != RHS.second.split(":").first)
369     return false;
370 
371   // Check combintions of on / off features that must match.
372   if (LHS.second.contains("xnack+") && RHS.second.contains("xnack-"))
373     return false;
374   if (LHS.second.contains("xnack-") && RHS.second.contains("xnack+"))
375     return false;
376   if (LHS.second.contains("sramecc-") && RHS.second.contains("sramecc+"))
377     return false;
378   if (LHS.second.contains("sramecc+") && RHS.second.contains("sramecc-"))
379     return false;
380   return true;
381 }
382