1 //===-- clang-offload-wrapper/ClangOffloadWrapper.cpp -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
/// Implementation of the offload wrapper tool. It takes offload target binaries
/// as input and creates a wrapper bitcode file containing the target binaries
/// packaged as data. The wrapper bitcode also includes initialization code that
/// registers the target binaries with the offloading runtime at program startup.
14 ///
15 //===----------------------------------------------------------------------===//
16
17 #include "clang/Basic/Version.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/Bitcode/BitcodeWriter.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/GlobalVariable.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/LLVMContext.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Errc.h"
28 #include "llvm/Support/Error.h"
29 #include "llvm/Support/ErrorOr.h"
30 #include "llvm/Support/FileSystem.h"
31 #include "llvm/Support/MemoryBuffer.h"
32 #include "llvm/Support/Signals.h"
33 #include "llvm/Support/ToolOutputFile.h"
34 #include "llvm/Support/WithColor.h"
35 #include "llvm/Support/raw_ostream.h"
36 #include "llvm/Transforms/Utils/ModuleUtils.h"
37 #include <cassert>
38 #include <cstdint>
39
40 using namespace llvm;
41
42 static cl::opt<bool> Help("h", cl::desc("Alias for -help"), cl::Hidden);
43
44 // Mark all our options with this category, everything else (except for -version
45 // and -help) will be hidden.
46 static cl::OptionCategory
47 ClangOffloadWrapperCategory("clang-offload-wrapper options");
48
49 static cl::opt<std::string> Output("o", cl::Required,
50 cl::desc("Output filename"),
51 cl::value_desc("filename"),
52 cl::cat(ClangOffloadWrapperCategory));
53
54 static cl::list<std::string> Inputs(cl::Positional, cl::OneOrMore,
55 cl::desc("<input files>"),
56 cl::cat(ClangOffloadWrapperCategory));
57
58 static cl::opt<std::string>
59 Target("target", cl::Required,
60 cl::desc("Target triple for the output module"),
61 cl::value_desc("triple"), cl::cat(ClangOffloadWrapperCategory));
62
63 namespace {
64
65 class BinaryWrapper {
66 LLVMContext C;
67 Module M;
68
69 StructType *EntryTy = nullptr;
70 StructType *ImageTy = nullptr;
71 StructType *DescTy = nullptr;
72
73 private:
getSizeTTy()74 IntegerType *getSizeTTy() {
75 switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
76 case 4u:
77 return Type::getInt32Ty(C);
78 case 8u:
79 return Type::getInt64Ty(C);
80 }
81 llvm_unreachable("unsupported pointer type size");
82 }
83
84 // struct __tgt_offload_entry {
85 // void *addr;
86 // char *name;
87 // size_t size;
88 // int32_t flags;
89 // int32_t reserved;
90 // };
getEntryTy()91 StructType *getEntryTy() {
92 if (!EntryTy)
93 EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
94 Type::getInt8PtrTy(C), getSizeTTy(),
95 Type::getInt32Ty(C), Type::getInt32Ty(C));
96 return EntryTy;
97 }
98
getEntryPtrTy()99 PointerType *getEntryPtrTy() { return PointerType::getUnqual(getEntryTy()); }
100
101 // struct __tgt_device_image {
102 // void *ImageStart;
103 // void *ImageEnd;
104 // __tgt_offload_entry *EntriesBegin;
105 // __tgt_offload_entry *EntriesEnd;
106 // };
getDeviceImageTy()107 StructType *getDeviceImageTy() {
108 if (!ImageTy)
109 ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
110 Type::getInt8PtrTy(C), getEntryPtrTy(),
111 getEntryPtrTy());
112 return ImageTy;
113 }
114
getDeviceImagePtrTy()115 PointerType *getDeviceImagePtrTy() {
116 return PointerType::getUnqual(getDeviceImageTy());
117 }
118
119 // struct __tgt_bin_desc {
120 // int32_t NumDeviceImages;
121 // __tgt_device_image *DeviceImages;
122 // __tgt_offload_entry *HostEntriesBegin;
123 // __tgt_offload_entry *HostEntriesEnd;
124 // };
getBinDescTy()125 StructType *getBinDescTy() {
126 if (!DescTy)
127 DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
128 getDeviceImagePtrTy(), getEntryPtrTy(),
129 getEntryPtrTy());
130 return DescTy;
131 }
132
getBinDescPtrTy()133 PointerType *getBinDescPtrTy() {
134 return PointerType::getUnqual(getBinDescTy());
135 }
136
137 /// Creates binary descriptor for the given device images. Binary descriptor
138 /// is an object that is passed to the offloading runtime at program startup
139 /// and it describes all device images available in the executable or shared
140 /// library. It is defined as follows
141 ///
142 /// __attribute__((visibility("hidden")))
143 /// extern __tgt_offload_entry *__start_omp_offloading_entries;
144 /// __attribute__((visibility("hidden")))
145 /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
146 ///
147 /// static const char Image0[] = { <Bufs.front() contents> };
148 /// ...
149 /// static const char ImageN[] = { <Bufs.back() contents> };
150 ///
151 /// static const __tgt_device_image Images[] = {
152 /// {
153 /// Image0, /*ImageStart*/
154 /// Image0 + sizeof(Image0), /*ImageEnd*/
155 /// __start_omp_offloading_entries, /*EntriesBegin*/
156 /// __stop_omp_offloading_entries /*EntriesEnd*/
157 /// },
158 /// ...
159 /// {
160 /// ImageN, /*ImageStart*/
161 /// ImageN + sizeof(ImageN), /*ImageEnd*/
162 /// __start_omp_offloading_entries, /*EntriesBegin*/
163 /// __stop_omp_offloading_entries /*EntriesEnd*/
164 /// }
165 /// };
166 ///
167 /// static const __tgt_bin_desc BinDesc = {
168 /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
169 /// Images, /*DeviceImages*/
170 /// __start_omp_offloading_entries, /*HostEntriesBegin*/
171 /// __stop_omp_offloading_entries /*HostEntriesEnd*/
172 /// };
173 ///
174 /// Global variable that represents BinDesc is returned.
createBinDesc(ArrayRef<ArrayRef<char>> Bufs)175 GlobalVariable *createBinDesc(ArrayRef<ArrayRef<char>> Bufs) {
176 // Create external begin/end symbols for the offload entries table.
177 auto *EntriesB = new GlobalVariable(
178 M, getEntryTy(), /*isConstant*/ true, GlobalValue::ExternalLinkage,
179 /*Initializer*/ nullptr, "__start_omp_offloading_entries");
180 EntriesB->setVisibility(GlobalValue::HiddenVisibility);
181 auto *EntriesE = new GlobalVariable(
182 M, getEntryTy(), /*isConstant*/ true, GlobalValue::ExternalLinkage,
183 /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
184 EntriesE->setVisibility(GlobalValue::HiddenVisibility);
185
186 // We assume that external begin/end symbols that we have created above will
187 // be defined by the linker. But linker will do that only if linker inputs
188 // have section with "omp_offloading_entries" name which is not guaranteed.
189 // So, we just create dummy zero sized object in the offload entries section
190 // to force linker to define those symbols.
191 auto *DummyInit =
192 ConstantAggregateZero::get(ArrayType::get(getEntryTy(), 0u));
193 auto *DummyEntry = new GlobalVariable(
194 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage,
195 DummyInit, "__dummy.omp_offloading.entry");
196 DummyEntry->setSection("omp_offloading_entries");
197 DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
198
199 auto *Zero = ConstantInt::get(getSizeTTy(), 0u);
200 Constant *ZeroZero[] = {Zero, Zero};
201
202 // Create initializer for the images array.
203 SmallVector<Constant *, 4u> ImagesInits;
204 ImagesInits.reserve(Bufs.size());
205 for (ArrayRef<char> Buf : Bufs) {
206 auto *Data = ConstantDataArray::get(C, Buf);
207 auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
208 GlobalVariable::InternalLinkage, Data,
209 ".omp_offloading.device_image");
210 Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
211
212 auto *Size = ConstantInt::get(getSizeTTy(), Buf.size());
213 Constant *ZeroSize[] = {Zero, Size};
214
215 auto *ImageB = ConstantExpr::getGetElementPtr(Image->getValueType(),
216 Image, ZeroZero);
217 auto *ImageE = ConstantExpr::getGetElementPtr(Image->getValueType(),
218 Image, ZeroSize);
219
220 ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(), ImageB,
221 ImageE, EntriesB, EntriesE));
222 }
223
224 // Then create images array.
225 auto *ImagesData = ConstantArray::get(
226 ArrayType::get(getDeviceImageTy(), ImagesInits.size()), ImagesInits);
227
228 auto *Images =
229 new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
230 GlobalValue::InternalLinkage, ImagesData,
231 ".omp_offloading.device_images");
232 Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
233
234 auto *ImagesB = ConstantExpr::getGetElementPtr(Images->getValueType(),
235 Images, ZeroZero);
236
237 // And finally create the binary descriptor object.
238 auto *DescInit = ConstantStruct::get(
239 getBinDescTy(),
240 ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
241 EntriesB, EntriesE);
242
243 return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
244 GlobalValue::InternalLinkage, DescInit,
245 ".omp_offloading.descriptor");
246 }
247
createRegisterFunction(GlobalVariable * BinDesc)248 void createRegisterFunction(GlobalVariable *BinDesc) {
249 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
250 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
251 ".omp_offloading.descriptor_reg", &M);
252 Func->setSection(".text.startup");
253
254 // Get __tgt_register_lib function declaration.
255 auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(),
256 /*isVarArg*/ false);
257 FunctionCallee RegFuncC =
258 M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
259
260 // Construct function body
261 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
262 Builder.CreateCall(RegFuncC, BinDesc);
263 Builder.CreateRetVoid();
264
265 // Add this function to constructors.
266 // Set priority to 1 so that __tgt_register_lib is executed AFTER
267 // __tgt_register_requires (we want to know what requirements have been
268 // asked for before we load a libomptarget plugin so that by the time the
269 // plugin is loaded it can report how many devices there are which can
270 // satisfy these requirements).
271 appendToGlobalCtors(M, Func, /*Priority*/ 1);
272 }
273
createUnregisterFunction(GlobalVariable * BinDesc)274 void createUnregisterFunction(GlobalVariable *BinDesc) {
275 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
276 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
277 ".omp_offloading.descriptor_unreg", &M);
278 Func->setSection(".text.startup");
279
280 // Get __tgt_unregister_lib function declaration.
281 auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(),
282 /*isVarArg*/ false);
283 FunctionCallee UnRegFuncC =
284 M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
285
286 // Construct function body
287 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
288 Builder.CreateCall(UnRegFuncC, BinDesc);
289 Builder.CreateRetVoid();
290
291 // Add this function to global destructors.
292 // Match priority of __tgt_register_lib
293 appendToGlobalDtors(M, Func, /*Priority*/ 1);
294 }
295
296 public:
BinaryWrapper(StringRef Target)297 BinaryWrapper(StringRef Target) : M("offload.wrapper.object", C) {
298 M.setTargetTriple(Target);
299 }
300
wrapBinaries(ArrayRef<ArrayRef<char>> Binaries)301 const Module &wrapBinaries(ArrayRef<ArrayRef<char>> Binaries) {
302 GlobalVariable *Desc = createBinDesc(Binaries);
303 assert(Desc && "no binary descriptor");
304 createRegisterFunction(Desc);
305 createUnregisterFunction(Desc);
306 return M;
307 }
308 };
309
310 } // anonymous namespace
311
main(int argc,const char ** argv)312 int main(int argc, const char **argv) {
313 sys::PrintStackTraceOnErrorSignal(argv[0]);
314
315 cl::HideUnrelatedOptions(ClangOffloadWrapperCategory);
316 cl::SetVersionPrinter([](raw_ostream &OS) {
317 OS << clang::getClangToolFullVersion("clang-offload-wrapper") << '\n';
318 });
319 cl::ParseCommandLineOptions(
320 argc, argv,
321 "A tool to create a wrapper bitcode for offload target binaries. Takes "
322 "offload\ntarget binaries as input and produces bitcode file containing "
323 "target binaries packaged\nas data and initialization code which "
324 "registers target binaries in offload runtime.\n");
325
326 if (Help) {
327 cl::PrintHelpMessage();
328 return 0;
329 }
330
331 auto reportError = [argv](Error E) {
332 logAllUnhandledErrors(std::move(E), WithColor::error(errs(), argv[0]));
333 };
334
335 if (Triple(Target).getArch() == Triple::UnknownArch) {
336 reportError(createStringError(
337 errc::invalid_argument, "'" + Target + "': unsupported target triple"));
338 return 1;
339 }
340
341 // Read device binaries.
342 SmallVector<std::unique_ptr<MemoryBuffer>, 4u> Buffers;
343 SmallVector<ArrayRef<char>, 4u> Images;
344 Buffers.reserve(Inputs.size());
345 Images.reserve(Inputs.size());
346 for (const std::string &File : Inputs) {
347 ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
348 MemoryBuffer::getFileOrSTDIN(File);
349 if (!BufOrErr) {
350 reportError(createFileError(File, BufOrErr.getError()));
351 return 1;
352 }
353 const std::unique_ptr<MemoryBuffer> &Buf =
354 Buffers.emplace_back(std::move(*BufOrErr));
355 Images.emplace_back(Buf->getBufferStart(), Buf->getBufferSize());
356 }
357
358 // Create the output file to write the resulting bitcode to.
359 std::error_code EC;
360 ToolOutputFile Out(Output, EC, sys::fs::OF_None);
361 if (EC) {
362 reportError(createFileError(Output, EC));
363 return 1;
364 }
365
366 // Create a wrapper for device binaries and write its bitcode to the file.
367 WriteBitcodeToFile(BinaryWrapper(Target).wrapBinaries(
368 makeArrayRef(Images.data(), Images.size())),
369 Out.os());
370 if (Out.os().has_error()) {
371 reportError(createFileError(Output, Out.os().error()));
372 return 1;
373 }
374
375 // Success.
376 Out.keep();
377 return 0;
378 }
379