xref: /llvm-project/libc/utils/gpu/server/rpc_server.cpp (revision e85a9f5540f5399b20a32c8d87474e6fc906ad33)
1 //===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // Workaround for missing __has_builtin in < GCC 10.
10 #ifndef __has_builtin
11 #define __has_builtin(x) 0
12 #endif
13 
14 // Make sure these are included first so they don't conflict with the system.
15 #include <limits.h>
16 
17 #include "shared/rpc.h"
18 #include "shared/rpc_opcodes.h"
19 
20 #include "src/__support/arg_list.h"
21 #include "src/stdio/printf_core/converter.h"
22 #include "src/stdio/printf_core/parser.h"
23 #include "src/stdio/printf_core/writer.h"
24 
25 #include <algorithm>
26 #include <atomic>
27 #include <cstdio>
28 #include <cstring>
29 #include <memory>
30 #include <mutex>
31 #include <unordered_map>
32 #include <variant>
33 #include <vector>
34 
35 using namespace LIBC_NAMESPACE;
36 using namespace LIBC_NAMESPACE::printf_core;
37 
namespace {
// Scratch allocator for a single RPC request. Each alloc() call returns a
// fresh, zero-initialized buffer of `size` bytes that stays valid until this
// object is destroyed, at which point every buffer is released together.
struct TempStorage {
  char *alloc(size_t size) {
    // Since C++17 emplace_back returns a reference to the inserted element.
    return storage.emplace_back(std::make_unique<char[]>(size)).get();
  }

  // Owning handles for every buffer handed out so far.
  std::vector<std::unique_ptr<char[]>> storage;
};
} // namespace
48 
// Tag held in the two low bits of a stream handle received from the client
// (decoded by to_stream): File means the remaining bits are a host FILE
// pointer, the others select the host's standard streams.
enum Stream { File, Stdin, Stdout, Stderr };
55 
56 // Get the associated stream out of an encoded number.
57 LIBC_INLINE ::FILE *to_stream(uintptr_t f) {
58   ::FILE *stream = reinterpret_cast<FILE *>(f & ~0x3ull);
59   Stream type = static_cast<Stream>(f & 0x3ull);
60   if (type == Stdin)
61     return stdin;
62   if (type == Stdout)
63     return stdout;
64   if (type == Stderr)
65     return stderr;
66   return stream;
67 }
68 
69 template <bool packed, uint32_t num_lanes>
70 static void handle_printf(rpc::Server::Port &port, TempStorage &temp_storage) {
71   FILE *files[num_lanes] = {nullptr};
72   // Get the appropriate output stream to use.
73   if (port.get_opcode() == LIBC_PRINTF_TO_STREAM ||
74       port.get_opcode() == LIBC_PRINTF_TO_STREAM_PACKED)
75     port.recv([&](rpc::Buffer *buffer, uint32_t id) {
76       files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
77     });
78   else if (port.get_opcode() == LIBC_PRINTF_TO_STDOUT ||
79            port.get_opcode() == LIBC_PRINTF_TO_STDOUT_PACKED)
80     std::fill(files, files + num_lanes, stdout);
81   else
82     std::fill(files, files + num_lanes, stderr);
83 
84   uint64_t format_sizes[num_lanes] = {0};
85   void *format[num_lanes] = {nullptr};
86 
87   uint64_t args_sizes[num_lanes] = {0};
88   void *args[num_lanes] = {nullptr};
89 
90   // Recieve the format string and arguments from the client.
91   port.recv_n(format, format_sizes,
92               [&](uint64_t size) { return temp_storage.alloc(size); });
93 
94   // Parse the format string to get the expected size of the buffer.
95   for (uint32_t lane = 0; lane < num_lanes; ++lane) {
96     if (!format[lane])
97       continue;
98 
99     WriteBuffer wb(nullptr, 0);
100     Writer writer(&wb);
101 
102     internal::DummyArgList<packed> printf_args;
103     Parser<internal::DummyArgList<packed> &> parser(
104         reinterpret_cast<const char *>(format[lane]), printf_args);
105 
106     for (FormatSection cur_section = parser.get_next_section();
107          !cur_section.raw_string.empty();
108          cur_section = parser.get_next_section())
109       ;
110     args_sizes[lane] = printf_args.read_count();
111   }
112   port.send([&](rpc::Buffer *buffer, uint32_t id) {
113     buffer->data[0] = args_sizes[id];
114   });
115   port.recv_n(args, args_sizes,
116               [&](uint64_t size) { return temp_storage.alloc(size); });
117 
118   // Identify any arguments that are actually pointers to strings on the client.
119   // Additionally we want to determine how much buffer space we need to print.
120   std::vector<void *> strs_to_copy[num_lanes];
121   int buffer_size[num_lanes] = {0};
122   for (uint32_t lane = 0; lane < num_lanes; ++lane) {
123     if (!format[lane])
124       continue;
125 
126     WriteBuffer wb(nullptr, 0);
127     Writer writer(&wb);
128 
129     internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
130     Parser<internal::StructArgList<packed>> parser(
131         reinterpret_cast<const char *>(format[lane]), printf_args);
132 
133     for (FormatSection cur_section = parser.get_next_section();
134          !cur_section.raw_string.empty();
135          cur_section = parser.get_next_section()) {
136       if (cur_section.has_conv && cur_section.conv_name == 's' &&
137           cur_section.conv_val_ptr) {
138         strs_to_copy[lane].emplace_back(cur_section.conv_val_ptr);
139         // Get the minimum size of the string in the case of padding.
140         char c = '\0';
141         cur_section.conv_val_ptr = &c;
142         convert(&writer, cur_section);
143       } else if (cur_section.has_conv) {
144         // Ignore conversion errors for the first pass.
145         convert(&writer, cur_section);
146       } else {
147         writer.write(cur_section.raw_string);
148       }
149     }
150     buffer_size[lane] = writer.get_chars_written();
151   }
152 
153   // Recieve any strings from the client and push them into a buffer.
154   std::vector<void *> copied_strs[num_lanes];
155   while (std::any_of(std::begin(strs_to_copy), std::end(strs_to_copy),
156                      [](const auto &v) { return !v.empty() && v.back(); })) {
157     port.send([&](rpc::Buffer *buffer, uint32_t id) {
158       void *ptr = !strs_to_copy[id].empty() ? strs_to_copy[id].back() : nullptr;
159       buffer->data[1] = reinterpret_cast<uintptr_t>(ptr);
160       if (!strs_to_copy[id].empty())
161         strs_to_copy[id].pop_back();
162     });
163     uint64_t str_sizes[num_lanes] = {0};
164     void *strs[num_lanes] = {nullptr};
165     port.recv_n(strs, str_sizes,
166                 [&](uint64_t size) { return temp_storage.alloc(size); });
167     for (uint32_t lane = 0; lane < num_lanes; ++lane) {
168       if (!strs[lane])
169         continue;
170 
171       copied_strs[lane].emplace_back(strs[lane]);
172       buffer_size[lane] += str_sizes[lane];
173     }
174   }
175 
176   // Perform the final formatting and printing using the LLVM C library printf.
177   int results[num_lanes] = {0};
178   for (uint32_t lane = 0; lane < num_lanes; ++lane) {
179     if (!format[lane])
180       continue;
181 
182     char *buffer = temp_storage.alloc(buffer_size[lane]);
183     WriteBuffer wb(buffer, buffer_size[lane]);
184     Writer writer(&wb);
185 
186     internal::StructArgList<packed> printf_args(args[lane], args_sizes[lane]);
187     Parser<internal::StructArgList<packed>> parser(
188         reinterpret_cast<const char *>(format[lane]), printf_args);
189 
190     // Parse and print the format string using the arguments we copied from
191     // the client.
192     int ret = 0;
193     for (FormatSection cur_section = parser.get_next_section();
194          !cur_section.raw_string.empty();
195          cur_section = parser.get_next_section()) {
196       // If this argument was a string we use the memory buffer we copied from
197       // the client by replacing the raw pointer with the copied one.
198       if (cur_section.has_conv && cur_section.conv_name == 's') {
199         if (!copied_strs[lane].empty()) {
200           cur_section.conv_val_ptr = copied_strs[lane].back();
201           copied_strs[lane].pop_back();
202         } else {
203           cur_section.conv_val_ptr = nullptr;
204         }
205       }
206       if (cur_section.has_conv) {
207         ret = convert(&writer, cur_section);
208         if (ret == -1)
209           break;
210       } else {
211         writer.write(cur_section.raw_string);
212       }
213     }
214 
215     results[lane] = fwrite(buffer, 1, writer.get_chars_written(), files[lane]);
216     if (results[lane] != writer.get_chars_written() || ret == -1)
217       results[lane] = -1;
218   }
219 
220   // Send the final return value and signal completion by setting the string
221   // argument to null.
222   port.send([&](rpc::Buffer *buffer, uint32_t id) {
223     buffer->data[0] = static_cast<uint64_t>(results[id]);
224     buffer->data[1] = reinterpret_cast<uintptr_t>(nullptr);
225   });
226 }
227 
228 template <uint32_t num_lanes>
229 rpc::Status handle_port_impl(rpc::Server::Port &port) {
230   TempStorage temp_storage;
231 
232   switch (port.get_opcode()) {
233   case LIBC_WRITE_TO_STREAM:
234   case LIBC_WRITE_TO_STDERR:
235   case LIBC_WRITE_TO_STDOUT:
236   case LIBC_WRITE_TO_STDOUT_NEWLINE: {
237     uint64_t sizes[num_lanes] = {0};
238     void *strs[num_lanes] = {nullptr};
239     FILE *files[num_lanes] = {nullptr};
240     if (port.get_opcode() == LIBC_WRITE_TO_STREAM) {
241       port.recv([&](rpc::Buffer *buffer, uint32_t id) {
242         files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
243       });
244     } else if (port.get_opcode() == LIBC_WRITE_TO_STDERR) {
245       std::fill(files, files + num_lanes, stderr);
246     } else {
247       std::fill(files, files + num_lanes, stdout);
248     }
249 
250     port.recv_n(strs, sizes,
251                 [&](uint64_t size) { return temp_storage.alloc(size); });
252     port.send([&](rpc::Buffer *buffer, uint32_t id) {
253       flockfile(files[id]);
254       buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
255       if (port.get_opcode() == LIBC_WRITE_TO_STDOUT_NEWLINE &&
256           buffer->data[0] == sizes[id])
257         buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
258       funlockfile(files[id]);
259     });
260     break;
261   }
262   case LIBC_READ_FROM_STREAM: {
263     uint64_t sizes[num_lanes] = {0};
264     void *data[num_lanes] = {nullptr};
265     port.recv([&](rpc::Buffer *buffer, uint32_t id) {
266       data[id] = temp_storage.alloc(buffer->data[0]);
267       sizes[id] =
268           fread(data[id], 1, buffer->data[0], to_stream(buffer->data[1]));
269     });
270     port.send_n(data, sizes);
271     port.send([&](rpc::Buffer *buffer, uint32_t id) {
272       std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
273     });
274     break;
275   }
276   case LIBC_READ_FGETS: {
277     uint64_t sizes[num_lanes] = {0};
278     void *data[num_lanes] = {nullptr};
279     port.recv([&](rpc::Buffer *buffer, uint32_t id) {
280       data[id] = temp_storage.alloc(buffer->data[0]);
281       const char *str = fgets(reinterpret_cast<char *>(data[id]),
282                               buffer->data[0], to_stream(buffer->data[1]));
283       sizes[id] = !str ? 0 : std::strlen(str) + 1;
284     });
285     port.send_n(data, sizes);
286     break;
287   }
288   case LIBC_OPEN_FILE: {
289     uint64_t sizes[num_lanes] = {0};
290     void *paths[num_lanes] = {nullptr};
291     port.recv_n(paths, sizes,
292                 [&](uint64_t size) { return temp_storage.alloc(size); });
293     port.recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
294       FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
295                          reinterpret_cast<char *>(buffer->data));
296       buffer->data[0] = reinterpret_cast<uintptr_t>(file);
297     });
298     break;
299   }
300   case LIBC_CLOSE_FILE: {
301     port.recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
302       FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
303       buffer->data[0] = fclose(file);
304     });
305     break;
306   }
307   case LIBC_EXIT: {
308     // Send a response to the client to signal that we are ready to exit.
309     port.recv_and_send([](rpc::Buffer *, uint32_t) {});
310     port.recv([](rpc::Buffer *buffer, uint32_t) {
311       int status = 0;
312       std::memcpy(&status, buffer->data, sizeof(int));
313       exit(status);
314     });
315     break;
316   }
317   case LIBC_ABORT: {
318     // Send a response to the client to signal that we are ready to abort.
319     port.recv_and_send([](rpc::Buffer *, uint32_t) {});
320     port.recv([](rpc::Buffer *, uint32_t) {});
321     abort();
322     break;
323   }
324   case LIBC_HOST_CALL: {
325     uint64_t sizes[num_lanes] = {0};
326     unsigned long long results[num_lanes] = {0};
327     void *args[num_lanes] = {nullptr};
328     port.recv_n(args, sizes,
329                 [&](uint64_t size) { return temp_storage.alloc(size); });
330     port.recv([&](rpc::Buffer *buffer, uint32_t id) {
331       using func_ptr_t = unsigned long long (*)(void *);
332       auto func = reinterpret_cast<func_ptr_t>(buffer->data[0]);
333       results[id] = func(args[id]);
334     });
335     port.send([&](rpc::Buffer *buffer, uint32_t id) {
336       buffer->data[0] = static_cast<uint64_t>(results[id]);
337     });
338     break;
339   }
340   case LIBC_FEOF: {
341     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
342       buffer->data[0] = feof(to_stream(buffer->data[0]));
343     });
344     break;
345   }
346   case LIBC_FERROR: {
347     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
348       buffer->data[0] = ferror(to_stream(buffer->data[0]));
349     });
350     break;
351   }
352   case LIBC_CLEARERR: {
353     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
354       clearerr(to_stream(buffer->data[0]));
355     });
356     break;
357   }
358   case LIBC_FSEEK: {
359     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
360       buffer->data[0] =
361           fseek(to_stream(buffer->data[0]), static_cast<long>(buffer->data[1]),
362                 static_cast<int>(buffer->data[2]));
363     });
364     break;
365   }
366   case LIBC_FTELL: {
367     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
368       buffer->data[0] = ftell(to_stream(buffer->data[0]));
369     });
370     break;
371   }
372   case LIBC_FFLUSH: {
373     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
374       buffer->data[0] = fflush(to_stream(buffer->data[0]));
375     });
376     break;
377   }
378   case LIBC_UNGETC: {
379     port.recv_and_send([](rpc::Buffer *buffer, uint32_t) {
380       buffer->data[0] =
381           ungetc(static_cast<int>(buffer->data[0]), to_stream(buffer->data[1]));
382     });
383     break;
384   }
385   case LIBC_PRINTF_TO_STREAM_PACKED:
386   case LIBC_PRINTF_TO_STDOUT_PACKED:
387   case LIBC_PRINTF_TO_STDERR_PACKED: {
388     handle_printf<true, num_lanes>(port, temp_storage);
389     break;
390   }
391   case LIBC_PRINTF_TO_STREAM:
392   case LIBC_PRINTF_TO_STDOUT:
393   case LIBC_PRINTF_TO_STDERR: {
394     handle_printf<false, num_lanes>(port, temp_storage);
395     break;
396   }
397   case LIBC_REMOVE: {
398     uint64_t sizes[num_lanes] = {0};
399     void *args[num_lanes] = {nullptr};
400     port.recv_n(args, sizes,
401                 [&](uint64_t size) { return temp_storage.alloc(size); });
402     port.send([&](rpc::Buffer *buffer, uint32_t id) {
403       buffer->data[0] = static_cast<uint64_t>(
404           remove(reinterpret_cast<const char *>(args[id])));
405     });
406     break;
407   }
408   case LIBC_RENAME: {
409     uint64_t oldsizes[num_lanes] = {0};
410     uint64_t newsizes[num_lanes] = {0};
411     void *oldpath[num_lanes] = {nullptr};
412     void *newpath[num_lanes] = {nullptr};
413     port.recv_n(oldpath, oldsizes,
414                 [&](uint64_t size) { return temp_storage.alloc(size); });
415     port.recv_n(newpath, newsizes,
416                 [&](uint64_t size) { return temp_storage.alloc(size); });
417     port.send([&](rpc::Buffer *buffer, uint32_t id) {
418       buffer->data[0] = static_cast<uint64_t>(
419           rename(reinterpret_cast<const char *>(oldpath[id]),
420                  reinterpret_cast<const char *>(newpath[id])));
421     });
422     break;
423   }
424   case LIBC_SYSTEM: {
425     uint64_t sizes[num_lanes] = {0};
426     void *args[num_lanes] = {nullptr};
427     port.recv_n(args, sizes,
428                 [&](uint64_t size) { return temp_storage.alloc(size); });
429     port.send([&](rpc::Buffer *buffer, uint32_t id) {
430       buffer->data[0] = static_cast<uint64_t>(
431           system(reinterpret_cast<const char *>(args[id])));
432     });
433     break;
434   }
435   case LIBC_NOOP: {
436     port.recv([](rpc::Buffer *, uint32_t) {});
437     break;
438   }
439   default:
440     return rpc::RPC_UNHANDLED_OPCODE;
441   }
442 
443   return rpc::RPC_SUCCESS;
444 }
445 
446 namespace rpc {
447 // The implementation of this function currently lives in the utility directory
448 // at 'utils/gpu/server/rpc_server.cpp'.
449 rpc::Status handle_libc_opcodes(rpc::Server::Port &port, uint32_t num_lanes) {
450   switch (num_lanes) {
451   case 1:
452     return handle_port_impl<1>(port);
453   case 32:
454     return handle_port_impl<32>(port);
455   case 64:
456     return handle_port_impl<64>(port);
457   default:
458     return rpc::RPC_ERROR;
459   }
460 }
461 } // namespace rpc
462