1 //===-- Generic device loader interface -----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H 10 #define LLVM_LIBC_UTILS_GPU_LOADER_LOADER_H 11 12 #include "include/llvm-libc-types/test_rpc_opcodes_t.h" 13 14 #include "shared/rpc.h" 15 #include "shared/rpc_opcodes.h" 16 17 #include <cstddef> 18 #include <cstdint> 19 #include <cstdio> 20 #include <cstdlib> 21 #include <cstring> 22 23 /// Generic launch parameters for configuration the number of blocks / threads. 24 struct LaunchParameters { 25 uint32_t num_threads_x; 26 uint32_t num_threads_y; 27 uint32_t num_threads_z; 28 uint32_t num_blocks_x; 29 uint32_t num_blocks_y; 30 uint32_t num_blocks_z; 31 }; 32 33 /// The arguments to the '_begin' kernel. 34 struct begin_args_t { 35 int argc; 36 void *argv; 37 void *envp; 38 }; 39 40 /// The arguments to the '_start' kernel. 41 struct start_args_t { 42 int argc; 43 void *argv; 44 void *envp; 45 void *ret; 46 }; 47 48 /// The arguments to the '_end' kernel. 49 struct end_args_t { 50 int argc; 51 }; 52 53 /// Generic interface to load the \p image and launch execution of the _start 54 /// kernel on the target device. Copies \p argc and \p argv to the device. 55 /// Returns the final value of the `main` function on the device. 56 int load(int argc, const char **argv, const char **evnp, void *image, 57 size_t size, const LaunchParameters ¶ms, 58 bool print_resource_usage); 59 60 /// Return \p V aligned "upwards" according to \p Align. 
template <typename V, typename A> inline V align_up(V val, A align) {
  // Round up by biasing with (unit - 1) before the truncating division.
  const V unit = V(align);
  return (val + unit - 1) / unit * unit;
}

/// Pack \p argc strings from \p argv into one buffer obtained from \p alloc.
///
/// The buffer starts with a table of argc + 1 pointers (the last one null),
/// immediately followed by every string and its terminating null byte, stored
/// back to back. Each table entry points at its string inside the same buffer.
/// Returns the buffer, or nullptr if \p alloc returned null.
template <typename Allocator>
void *copy_argument_vector(int argc, const char **argv, Allocator alloc) {
  // Size of the pointer table, including the trailing null entry.
  const size_t table_bytes = sizeof(char *) * (argc + 1);

  // Total bytes needed for the strings themselves, terminators included.
  size_t string_bytes = 0;
  for (int idx = 0; idx < argc; ++idx)
    string_bytes += strlen(argv[idx]) + 1;

  void *block = alloc(table_bytes + string_bytes);
  if (block == nullptr)
    return nullptr;

  void **table = static_cast<void **>(block);
  uint8_t *cursor = static_cast<uint8_t *>(block) + table_bytes;
  for (int idx = 0; idx < argc; ++idx) {
    const size_t len = strlen(argv[idx]) + 1;
    std::memcpy(cursor, argv[idx], len);
    table[idx] = cursor;
    cursor += len;
  }

  // Terminate the pointer table.
  table[argc] = nullptr;
  return block;
}

/// Copy the null-terminated environment \p envp into a single buffer
/// allocated using \p alloc, laid out like an argument vector.
93 template <typename Allocator> 94 void *copy_environment(const char **envp, Allocator alloc) { 95 int envc = 0; 96 for (const char **env = envp; *env != 0; ++env) 97 ++envc; 98 99 return copy_argument_vector(envc, envp, alloc); 100 } 101 102 inline void handle_error_impl(const char *file, int32_t line, const char *msg) { 103 fprintf(stderr, "%s:%d:0: Error: %s\n", file, line, msg); 104 exit(EXIT_FAILURE); 105 } 106 #define handle_error(X) handle_error_impl(__FILE__, __LINE__, X) 107 108 template <uint32_t num_lanes, typename Alloc, typename Free> 109 inline uint32_t handle_server(rpc::Server &server, uint32_t index, 110 Alloc &&alloc, Free &&free) { 111 auto port = server.try_open(num_lanes, index); 112 if (!port) 113 return 0; 114 index = port->get_index() + 1; 115 116 int status = rpc::RPC_SUCCESS; 117 switch (port->get_opcode()) { 118 case RPC_TEST_INCREMENT: { 119 port->recv_and_send([](rpc::Buffer *buffer, uint32_t) { 120 reinterpret_cast<uint64_t *>(buffer->data)[0] += 1; 121 }); 122 break; 123 } 124 case RPC_TEST_INTERFACE: { 125 bool end_with_recv; 126 uint64_t cnt; 127 port->recv([&](rpc::Buffer *buffer, uint32_t) { 128 end_with_recv = buffer->data[0]; 129 }); 130 port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; }); 131 port->send([&](rpc::Buffer *buffer, uint32_t) { 132 buffer->data[0] = cnt = cnt + 1; 133 }); 134 port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; }); 135 port->send([&](rpc::Buffer *buffer, uint32_t) { 136 buffer->data[0] = cnt = cnt + 1; 137 }); 138 port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; }); 139 port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; }); 140 port->send([&](rpc::Buffer *buffer, uint32_t) { 141 buffer->data[0] = cnt = cnt + 1; 142 }); 143 port->send([&](rpc::Buffer *buffer, uint32_t) { 144 buffer->data[0] = cnt = cnt + 1; 145 }); 146 if (end_with_recv) 147 port->recv([&](rpc::Buffer *buffer, uint32_t) { cnt = buffer->data[0]; }); 148 
else 149 port->send([&](rpc::Buffer *buffer, uint32_t) { 150 buffer->data[0] = cnt = cnt + 1; 151 }); 152 153 break; 154 } 155 case RPC_TEST_STREAM: { 156 uint64_t sizes[num_lanes] = {0}; 157 void *dst[num_lanes] = {nullptr}; 158 port->recv_n(dst, sizes, 159 [](uint64_t size) -> void * { return new char[size]; }); 160 port->send_n(dst, sizes); 161 for (uint64_t i = 0; i < num_lanes; ++i) { 162 if (dst[i]) 163 delete[] reinterpret_cast<uint8_t *>(dst[i]); 164 } 165 break; 166 } 167 case RPC_TEST_NOOP: { 168 port->recv([&](rpc::Buffer *, uint32_t) {}); 169 break; 170 } 171 case LIBC_MALLOC: { 172 port->recv_and_send([&](rpc::Buffer *buffer, uint32_t) { 173 buffer->data[0] = reinterpret_cast<uintptr_t>(alloc(buffer->data[0])); 174 }); 175 break; 176 } 177 case LIBC_FREE: { 178 port->recv([&](rpc::Buffer *buffer, uint32_t) { 179 free(reinterpret_cast<void *>(buffer->data[0])); 180 }); 181 break; 182 } 183 default: 184 status = handle_libc_opcodes(*port, num_lanes); 185 break; 186 } 187 188 // Handle all of the `libc` specific opcodes. 189 if (status != rpc::RPC_SUCCESS) 190 handle_error("Error handling RPC server"); 191 192 port->close(); 193 194 return index; 195 } 196 197 #endif 198