xref: /llvm-project/libc/utils/gpu/loader/Loader.h (revision e85a9f5540f5399b20a32c8d87474e6fc906ad33)
1 //===-- Generic device loader interface -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H
10 #define LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H
11 
12 #include "include/llvm-libc-types/test_rpc_opcodes_t.h"
13 
14 #include "shared/rpc.h"
15 #include "shared/rpc_opcodes.h"
16 
17 #include <cstddef>
18 #include <cstdint>
19 #include <cstdio>
20 #include <cstdlib>
21 #include <cstring>
22 
/// Generic launch parameters for configuring the number of blocks / threads.
///
/// Every dimension defaults to 1, so a default-constructed object describes a
/// single thread in a single block instead of holding indeterminate values.
/// Aggregate initialization such as `LaunchParameters{1, 1, 1, 1, 1, 1}` keeps
/// working unchanged.
struct LaunchParameters {
  uint32_t num_threads_x = 1; ///< Thread count along the x dimension.
  uint32_t num_threads_y = 1; ///< Thread count along the y dimension.
  uint32_t num_threads_z = 1; ///< Thread count along the z dimension.
  uint32_t num_blocks_x = 1;  ///< Block count along the x dimension.
  uint32_t num_blocks_y = 1;  ///< Block count along the y dimension.
  uint32_t num_blocks_z = 1;  ///< Block count along the z dimension.
};
32 
/// The arguments to the '_begin' kernel.
///
/// Members carry default initializers so that a default-constructed object is
/// fully defined; aggregate initialization (`{argc, argv, envp}`) is
/// unaffected.
struct begin_args_t {
  int argc = 0;         ///< Argument count handed to the kernel.
  void *argv = nullptr; ///< Opaque pointer to the argument vector.
  void *envp = nullptr; ///< Opaque pointer to the environment vector.
};
39 
/// The arguments to the '_start' kernel.
///
/// Members carry default initializers so that a default-constructed object is
/// fully defined; aggregate initialization (`{argc, argv, envp, ret}`) is
/// unaffected.
struct start_args_t {
  int argc = 0;         ///< Argument count handed to the kernel.
  void *argv = nullptr; ///< Opaque pointer to the argument vector.
  void *envp = nullptr; ///< Opaque pointer to the environment vector.
  // NOTE(review): presumably the location the kernel writes its result to;
  // confirm against the loader implementation.
  void *ret = nullptr;
};
47 
/// The arguments to the '_end' kernel.
///
/// The member carries a default initializer so that a default-constructed
/// object is fully defined; aggregate initialization (`{value}`) is unaffected.
struct end_args_t {
  // NOTE(review): the only argument of the '_end' kernel; despite the name,
  // its exact meaning is not visible from this header.
  int argc = 0;
};
52 
53 /// Generic interface to load the \p image and launch execution of the _start
54 /// kernel on the target device. Copies \p argc and \p argv to the device.
55 /// Returns the final value of the `main` function on the device.
56 int load(int argc, const char **argv, const char **evnp, void *image,
57          size_t size, const LaunchParameters &params,
58          bool print_resource_usage);
59 
/// Round \p val "upwards" to the nearest multiple of \p align.
///
/// \p align is converted to \p V before it is used. A value that is already a
/// multiple of \p align is returned unchanged. No validation is performed, so
/// an \p align of zero divides by zero.
template <typename V, typename A> inline V align_up(V val, A align) {
  const V step = V(align);
  // `auto` keeps whatever type the arithmetic promotes to, so nothing is
  // truncated before the final conversion back to V on return.
  const auto chunks = (val + step - 1) / step;
  return chunks * step;
}
64 
/// Copy the system's argument vector to GPU memory allocated using \p alloc.
///
/// A single allocation is requested from \p alloc. It starts with a table of
/// \p argc pointers plus a terminating null entry, immediately followed by the
/// bytes of every string (each including its NUL terminator) in order. Table
/// entry `i` points at the copy of `argv[i]` inside that same allocation.
///
/// \returns the start of the allocation, or nullptr if \p alloc returned null.
template <typename Allocator>
void *copy_argument_vector(int argc, const char **argv, Allocator alloc) {
  // Bytes needed for the pointer table, including its trailing null slot.
  const size_t table_bytes = sizeof(char *) * (argc + 1);

  // Bytes needed for all strings, each with its terminating NUL.
  size_t string_bytes = 0;
  for (int i = 0; i < argc; ++i)
    string_bytes += std::strlen(argv[i]) + 1;

  void *buffer = alloc(table_bytes + string_bytes);
  if (!buffer)
    return nullptr;

  // The strings are packed back to back right after the pointer table.
  void **table = static_cast<void **>(buffer);
  uint8_t *cursor = static_cast<uint8_t *>(buffer) + table_bytes;
  for (int i = 0; i < argc; ++i) {
    const size_t length = std::strlen(argv[i]) + 1;
    std::memcpy(cursor, argv[i], length);
    table[i] = cursor;
    cursor += length;
  }

  // Terminate the table so it can be walked like a regular argv.
  table[argc] = nullptr;
  return buffer;
}
91 
/// Copy the system's environment to GPU memory allocated using \p alloc.
///
/// Counts the entries of \p envp up to (not including) its null terminator and
/// hands that count, together with \p envp and \p alloc, to
/// copy_argument_vector, whose result is returned unchanged.
template <typename Allocator>
void *copy_environment(const char **envp, Allocator alloc) {
  int count = 0;
  while (envp[count] != nullptr)
    ++count;

  return copy_argument_vector(count, envp, alloc);
}
101 
/// Print `<file>:<line>:0: Error: <msg>` to stderr and terminate the process
/// with EXIT_FAILURE. This function never returns, which the `[[noreturn]]`
/// attribute makes visible to callers and to the compiler.
///
/// \param file Source file name to report.
/// \param line Source line number to report.
/// \param msg  Human readable description of the failure.
[[noreturn]] inline void handle_error_impl(const char *file, int32_t line,
                                           const char *msg) {
  // `%d` expects an `int`; convert explicitly so the call stays well-defined
  // even on a platform where `int32_t` is not an alias of `int`.
  std::fprintf(stderr, "%s:%d:0: Error: %s\n", file, static_cast<int>(line),
               msg);
  std::exit(EXIT_FAILURE);
}
/// Report \p X as a fatal error tagged with the current file and line.
#define handle_error(X) handle_error_impl(__FILE__, __LINE__, X)
107 
108 template <uint32_t num_lanes, typename Alloc, typename Free>
109 inline uint32_t handle_server(rpc::Server &server, uint32_t index,
110                               Alloc &&alloc, Free &&free) {
111   auto port = server.try_open(num_lanes, index);
112   if (!port)
113     return 0;
114   index = port->get_index() + 1;
115 
116   int status = rpc::RPC_SUCCESS;
117   switch (port->get_opcode()) {
118   case RPC_TEST_INCREMENT: {
119     port->recv_and_send([](rpc::Buffer *buffer, uint32_t) {
120       reinterpret_cast<uint64_t *>(buffer->data)[0] += 1;
121     });
122     break;
123   }
124   case RPC_TEST_INTERFACE: {
125     bool end_with_recv;
126     uint64_t cnt;
127     port->recv([&](rpc::Buffer *buffer, uint32_t) {
128       end_with_recv = buffer->data[0];
129     });
130     port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
131     port->send([&](rpc::Buffer *buffer, uint32_t) {
132       buffer->data[0] = cnt = cnt + 1;
133     });
134     port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
135     port->send([&](rpc::Buffer *buffer, uint32_t) {
136       buffer->data[0] = cnt = cnt + 1;
137     });
138     port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
139     port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
140     port->send([&](rpc::Buffer *buffer, uint32_t) {
141       buffer->data[0] = cnt = cnt + 1;
142     });
143     port->send([&](rpc::Buffer *buffer, uint32_t) {
144       buffer->data[0] = cnt = cnt + 1;
145     });
146     if (end_with_recv)
147       port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; });
148     else
149       port->send([&](rpc::Buffer *buffer, uint32_t) {
150         buffer->data[0] = cnt = cnt + 1;
151       });
152 
153     break;
154   }
155   case RPC_TEST_STREAM: {
156     uint64_t sizes[num_lanes] = {0};
157     void *dst[num_lanes] = {nullptr};
158     port->recv_n(dst, sizes,
159                  [](uint64_t size) -> void * { return new char[size]; });
160     port->send_n(dst, sizes);
161     for (uint64_t i = 0; i < num_lanes; ++i) {
162       if (dst[i])
163         delete[] reinterpret_cast<uint8_t *>(dst[i]);
164     }
165     break;
166   }
167   case RPC_TEST_NOOP: {
168     port->recv([&](rpc::Buffer *, uint32_t) {});
169     break;
170   }
171   case LIBC_MALLOC: {
172     port->recv_and_send([&](rpc::Buffer *buffer, uint32_t) {
173       buffer->data[0] = reinterpret_cast<uintptr_t>(alloc(buffer->data[0]));
174     });
175     break;
176   }
177   case LIBC_FREE: {
178     port->recv([&](rpc::Buffer *buffer, uint32_t) {
179       free(reinterpret_cast<void *>(buffer->data[0]));
180     });
181     break;
182   }
183   default:
184     status = handle_libc_opcodes(*port, num_lanes);
185     break;
186   }
187 
188   // Handle all of the `libc` specific opcodes.
189   if (status != rpc::RPC_SUCCESS)
190     handle_error("Error handling RPC server");
191 
192   port->close();
193 
194   return index;
195 }
196 
197 #endif
198