xref: /llvm-project/llvm/lib/Support/raw_socket_stream.cpp (revision d44ea7186befe38eb2b3804b15cd1ee1777458ed)
1c0d5d36dScriis //===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===//
2c0d5d36dScriis //
3c0d5d36dScriis // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c0d5d36dScriis // See https://llvm.org/LICENSE.txt for license information.
5c0d5d36dScriis // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c0d5d36dScriis //
7c0d5d36dScriis //===----------------------------------------------------------------------===//
8c0d5d36dScriis //
9c0d5d36dScriis // This file contains raw_ostream implementations for streams to communicate
10c0d5d36dScriis // via UNIX sockets
11c0d5d36dScriis //
12c0d5d36dScriis //===----------------------------------------------------------------------===//
13c0d5d36dScriis 
14c0d5d36dScriis #include "llvm/Support/raw_socket_stream.h"
15c0d5d36dScriis #include "llvm/Config/config.h"
16c0d5d36dScriis #include "llvm/Support/Error.h"
1787e6f87fSConnor Sughrue #include "llvm/Support/FileSystem.h"
1887e6f87fSConnor Sughrue 
1987e6f87fSConnor Sughrue #include <atomic>
2087e6f87fSConnor Sughrue #include <fcntl.h>
21*76321b9fSConnor Sughrue #include <functional>
22c0d5d36dScriis 
23c0d5d36dScriis #ifndef _WIN32
2487e6f87fSConnor Sughrue #include <poll.h>
25c0d5d36dScriis #include <sys/socket.h>
26c0d5d36dScriis #include <sys/un.h>
27c0d5d36dScriis #else
28c0d5d36dScriis #include "llvm/Support/Windows/WindowsSupport.h"
29c0d5d36dScriis // winsock2.h must be included before afunix.h. Briefly turn off clang-format to
30c0d5d36dScriis // avoid error.
31c0d5d36dScriis // clang-format off
32c0d5d36dScriis #include <winsock2.h>
33c0d5d36dScriis #include <afunix.h>
34c0d5d36dScriis // clang-format on
35c0d5d36dScriis #include <io.h>
36c0d5d36dScriis #endif // _WIN32
37c0d5d36dScriis 
38c0d5d36dScriis #if defined(HAVE_UNISTD_H)
39c0d5d36dScriis #include <unistd.h>
40c0d5d36dScriis #endif
41c0d5d36dScriis 
42c0d5d36dScriis using namespace llvm;
43c0d5d36dScriis 
#ifdef _WIN32
// Request Winsock version 2.2. Failure to initialize Winsock is treated as
// unrecoverable. The WSAStartup done here is paired with the WSACleanup in the
// destructor.
WSABalancer::WSABalancer() {
  WSADATA WsaData;
  ::memset(&WsaData, 0, sizeof(WsaData));
  const int StartupResult = WSAStartup(MAKEWORD(2, 2), &WsaData);
  if (StartupResult == 0)
    return;
  llvm::report_fatal_error("WSAStartup failed");
}

// Undo the WSAStartup performed by the constructor.
WSABalancer::~WSABalancer() { WSACleanup(); }
#endif // _WIN32
55c0d5d36dScriis 
56c0d5d36dScriis static std::error_code getLastSocketErrorCode() {
57c0d5d36dScriis #ifdef _WIN32
58c0d5d36dScriis   return std::error_code(::WSAGetLastError(), std::system_category());
59c0d5d36dScriis #else
60ba13fa2aSMichael Spencer   return errnoAsErrorCode();
61c0d5d36dScriis #endif
62c0d5d36dScriis }
63c0d5d36dScriis 
6487e6f87fSConnor Sughrue static sockaddr_un setSocketAddr(StringRef SocketPath) {
65c0d5d36dScriis   struct sockaddr_un Addr;
66c0d5d36dScriis   memset(&Addr, 0, sizeof(Addr));
67c0d5d36dScriis   Addr.sun_family = AF_UNIX;
68c0d5d36dScriis   strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
6987e6f87fSConnor Sughrue   return Addr;
7087e6f87fSConnor Sughrue }
71c0d5d36dScriis 
7287e6f87fSConnor Sughrue static Expected<int> getSocketFD(StringRef SocketPath) {
73c0d5d36dScriis #ifdef _WIN32
7487e6f87fSConnor Sughrue   SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
7587e6f87fSConnor Sughrue   if (Socket == INVALID_SOCKET) {
76c0d5d36dScriis #else
7787e6f87fSConnor Sughrue   int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
7887e6f87fSConnor Sughrue   if (Socket == -1) {
79c0d5d36dScriis #endif // _WIN32
80c0d5d36dScriis     return llvm::make_error<StringError>(getLastSocketErrorCode(),
81c0d5d36dScriis                                          "Create socket failed");
82c0d5d36dScriis   }
83c0d5d36dScriis 
8487e6f87fSConnor Sughrue   struct sockaddr_un Addr = setSocketAddr(SocketPath);
8587e6f87fSConnor Sughrue   if (::connect(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1)
86c0d5d36dScriis     return llvm::make_error<StringError>(getLastSocketErrorCode(),
87c0d5d36dScriis                                          "Connect socket failed");
8887e6f87fSConnor Sughrue 
89c0d5d36dScriis #ifdef _WIN32
9087e6f87fSConnor Sughrue   return _open_osfhandle(Socket, 0);
91c0d5d36dScriis #else
9287e6f87fSConnor Sughrue   return Socket;
93c0d5d36dScriis #endif // _WIN32
94c0d5d36dScriis }
95c0d5d36dScriis 
9687e6f87fSConnor Sughrue ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath,
9787e6f87fSConnor Sughrue                                  int PipeFD[2])
9887e6f87fSConnor Sughrue     : FD(SocketFD), SocketPath(SocketPath), PipeFD{PipeFD[0], PipeFD[1]} {}
9987e6f87fSConnor Sughrue 
10087e6f87fSConnor Sughrue ListeningSocket::ListeningSocket(ListeningSocket &&LS)
10187e6f87fSConnor Sughrue     : FD(LS.FD.load()), SocketPath(LS.SocketPath),
10287e6f87fSConnor Sughrue       PipeFD{LS.PipeFD[0], LS.PipeFD[1]} {
10387e6f87fSConnor Sughrue 
10487e6f87fSConnor Sughrue   LS.FD = -1;
10587e6f87fSConnor Sughrue   LS.SocketPath.clear();
10687e6f87fSConnor Sughrue   LS.PipeFD[0] = -1;
10787e6f87fSConnor Sughrue   LS.PipeFD[1] = -1;
10887e6f87fSConnor Sughrue }
10987e6f87fSConnor Sughrue 
11087e6f87fSConnor Sughrue Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
11187e6f87fSConnor Sughrue                                                       int MaxBacklog) {
11287e6f87fSConnor Sughrue 
11387e6f87fSConnor Sughrue   // Handle instances where the target socket address already exists and
11487e6f87fSConnor Sughrue   // differentiate between a preexisting file with and without a bound socket
11587e6f87fSConnor Sughrue   //
11687e6f87fSConnor Sughrue   // ::bind will return std::errc:address_in_use if a file at the socket address
11787e6f87fSConnor Sughrue   // already exists (e.g., the file was not properly unlinked due to a crash)
11887e6f87fSConnor Sughrue   // even if another socket has not yet binded to that address
11987e6f87fSConnor Sughrue   if (llvm::sys::fs::exists(SocketPath)) {
12087e6f87fSConnor Sughrue     Expected<int> MaybeFD = getSocketFD(SocketPath);
12187e6f87fSConnor Sughrue     if (!MaybeFD) {
12287e6f87fSConnor Sughrue 
12387e6f87fSConnor Sughrue       // Regardless of the error, notify the caller that a file already exists
12487e6f87fSConnor Sughrue       // at the desired socket address and that there is no bound socket at that
12587e6f87fSConnor Sughrue       // address. The file must be removed before ::bind can use the address
12687e6f87fSConnor Sughrue       consumeError(MaybeFD.takeError());
12787e6f87fSConnor Sughrue       return llvm::make_error<StringError>(
12887e6f87fSConnor Sughrue           std::make_error_code(std::errc::file_exists),
12987e6f87fSConnor Sughrue           "Socket address unavailable");
13087e6f87fSConnor Sughrue     }
13187e6f87fSConnor Sughrue     ::close(std::move(*MaybeFD));
13287e6f87fSConnor Sughrue 
13387e6f87fSConnor Sughrue     // Notify caller that the provided socket address already has a bound socket
13487e6f87fSConnor Sughrue     return llvm::make_error<StringError>(
13587e6f87fSConnor Sughrue         std::make_error_code(std::errc::address_in_use),
13687e6f87fSConnor Sughrue         "Socket address unavailable");
13787e6f87fSConnor Sughrue   }
13887e6f87fSConnor Sughrue 
13987e6f87fSConnor Sughrue #ifdef _WIN32
14087e6f87fSConnor Sughrue   WSABalancer _;
14187e6f87fSConnor Sughrue   SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
14287e6f87fSConnor Sughrue   if (Socket == INVALID_SOCKET)
14387e6f87fSConnor Sughrue #else
14487e6f87fSConnor Sughrue   int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
14587e6f87fSConnor Sughrue   if (Socket == -1)
14687e6f87fSConnor Sughrue #endif
14787e6f87fSConnor Sughrue     return llvm::make_error<StringError>(getLastSocketErrorCode(),
14887e6f87fSConnor Sughrue                                          "socket create failed");
14987e6f87fSConnor Sughrue 
15087e6f87fSConnor Sughrue   struct sockaddr_un Addr = setSocketAddr(SocketPath);
15187e6f87fSConnor Sughrue   if (::bind(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
15287e6f87fSConnor Sughrue     // Grab error code from call to ::bind before calling ::close
15387e6f87fSConnor Sughrue     std::error_code EC = getLastSocketErrorCode();
15487e6f87fSConnor Sughrue     ::close(Socket);
15587e6f87fSConnor Sughrue     return llvm::make_error<StringError>(EC, "Bind error");
15687e6f87fSConnor Sughrue   }
15787e6f87fSConnor Sughrue 
15887e6f87fSConnor Sughrue   // Mark socket as passive so incoming connections can be accepted
15987e6f87fSConnor Sughrue   if (::listen(Socket, MaxBacklog) == -1)
16087e6f87fSConnor Sughrue     return llvm::make_error<StringError>(getLastSocketErrorCode(),
16187e6f87fSConnor Sughrue                                          "Listen error");
16287e6f87fSConnor Sughrue 
16387e6f87fSConnor Sughrue   int PipeFD[2];
16487e6f87fSConnor Sughrue #ifdef _WIN32
16587e6f87fSConnor Sughrue   // Reserve 1 byte for the pipe and use default textmode
16687e6f87fSConnor Sughrue   if (::_pipe(PipeFD, 1, 0) == -1)
16787e6f87fSConnor Sughrue #else
16887e6f87fSConnor Sughrue   if (::pipe(PipeFD) == -1)
16987e6f87fSConnor Sughrue #endif // _WIN32
17087e6f87fSConnor Sughrue     return llvm::make_error<StringError>(getLastSocketErrorCode(),
17187e6f87fSConnor Sughrue                                          "pipe failed");
17287e6f87fSConnor Sughrue 
17387e6f87fSConnor Sughrue #ifdef _WIN32
17487e6f87fSConnor Sughrue   return ListeningSocket{_open_osfhandle(Socket, 0), SocketPath, PipeFD};
17587e6f87fSConnor Sughrue #else
17687e6f87fSConnor Sughrue   return ListeningSocket{Socket, SocketPath, PipeFD};
17787e6f87fSConnor Sughrue #endif // _WIN32
17887e6f87fSConnor Sughrue }
17987e6f87fSConnor Sughrue 
180*76321b9fSConnor Sughrue // If a file descriptor being monitored by ::poll is closed by another thread,
181*76321b9fSConnor Sughrue // the result is unspecified. In the case ::poll does not unblock and return,
182*76321b9fSConnor Sughrue // when ActiveFD is closed, you can provide another file descriptor via CancelFD
183*76321b9fSConnor Sughrue // that when written to will cause poll to return. Typically CancelFD is the
184*76321b9fSConnor Sughrue // read end of a unidirectional pipe.
185*76321b9fSConnor Sughrue //
186*76321b9fSConnor Sughrue // Timeout should be -1 to block indefinitly
187*76321b9fSConnor Sughrue //
188*76321b9fSConnor Sughrue // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
189*76321b9fSConnor Sughrue static std::error_code
190*76321b9fSConnor Sughrue manageTimeout(const std::chrono::milliseconds &Timeout,
191*76321b9fSConnor Sughrue               const std::function<int()> &getActiveFD,
192*76321b9fSConnor Sughrue               const std::optional<int> &CancelFD = std::nullopt) {
193*76321b9fSConnor Sughrue   struct pollfd FD[2];
194*76321b9fSConnor Sughrue   FD[0].events = POLLIN;
19587e6f87fSConnor Sughrue #ifdef _WIN32
196*76321b9fSConnor Sughrue   SOCKET WinServerSock = _get_osfhandle(getActiveFD());
197*76321b9fSConnor Sughrue   FD[0].fd = WinServerSock;
19887e6f87fSConnor Sughrue #else
199*76321b9fSConnor Sughrue   FD[0].fd = getActiveFD();
20087e6f87fSConnor Sughrue #endif
201*76321b9fSConnor Sughrue   uint8_t FDCount = 1;
202*76321b9fSConnor Sughrue   if (CancelFD.has_value()) {
203*76321b9fSConnor Sughrue     FD[1].events = POLLIN;
204*76321b9fSConnor Sughrue     FD[1].fd = CancelFD.value();
205*76321b9fSConnor Sughrue     FDCount++;
206*76321b9fSConnor Sughrue   }
20787e6f87fSConnor Sughrue 
208*76321b9fSConnor Sughrue   // Keep track of how much time has passed in case ::poll or WSAPoll are
209*76321b9fSConnor Sughrue   // interupted by a signal and need to be recalled
21087e6f87fSConnor Sughrue   auto Start = std::chrono::steady_clock::now();
211*76321b9fSConnor Sughrue   auto RemainingTimeout = Timeout;
212*76321b9fSConnor Sughrue   int PollStatus = 0;
213*76321b9fSConnor Sughrue   do {
214*76321b9fSConnor Sughrue     // If Timeout is -1 then poll should block and RemainingTimeout does not
215*76321b9fSConnor Sughrue     // need to be recalculated
216*76321b9fSConnor Sughrue     if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
217*76321b9fSConnor Sughrue       auto TotalElapsedTime =
218*76321b9fSConnor Sughrue           std::chrono::duration_cast<std::chrono::milliseconds>(
219*76321b9fSConnor Sughrue               std::chrono::steady_clock::now() - Start);
220*76321b9fSConnor Sughrue 
221*76321b9fSConnor Sughrue       if (TotalElapsedTime >= Timeout)
222*76321b9fSConnor Sughrue         return std::make_error_code(std::errc::operation_would_block);
223*76321b9fSConnor Sughrue 
224*76321b9fSConnor Sughrue       RemainingTimeout = Timeout - TotalElapsedTime;
225*76321b9fSConnor Sughrue     }
22687e6f87fSConnor Sughrue #ifdef _WIN32
227*76321b9fSConnor Sughrue     PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
228*76321b9fSConnor Sughrue   } while (PollStatus == SOCKET_ERROR &&
229*76321b9fSConnor Sughrue            getLastSocketErrorCode() == std::errc::interrupted);
23087e6f87fSConnor Sughrue #else
231*76321b9fSConnor Sughrue     PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
232*76321b9fSConnor Sughrue   } while (PollStatus == -1 &&
233*76321b9fSConnor Sughrue            getLastSocketErrorCode() == std::errc::interrupted);
234203232ffSConnor Sughrue #endif
235203232ffSConnor Sughrue 
236*76321b9fSConnor Sughrue   // If ActiveFD equals -1 or CancelFD has data to be read then the operation
237*76321b9fSConnor Sughrue   // has been canceled by another thread
238*76321b9fSConnor Sughrue   if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
239*76321b9fSConnor Sughrue     return std::make_error_code(std::errc::operation_canceled);
240203232ffSConnor Sughrue #if _WIN32
241*76321b9fSConnor Sughrue   if (PollStatus == SOCKET_ERROR)
242203232ffSConnor Sughrue #else
243*76321b9fSConnor Sughrue   if (PollStatus == -1)
24487e6f87fSConnor Sughrue #endif
245*76321b9fSConnor Sughrue     return getLastSocketErrorCode();
24687e6f87fSConnor Sughrue   if (PollStatus == 0)
247*76321b9fSConnor Sughrue     return std::make_error_code(std::errc::timed_out);
248*76321b9fSConnor Sughrue   if (FD[0].revents & POLLNVAL)
249*76321b9fSConnor Sughrue     return std::make_error_code(std::errc::bad_file_descriptor);
250*76321b9fSConnor Sughrue   return std::error_code();
25187e6f87fSConnor Sughrue }
25287e6f87fSConnor Sughrue 
253*76321b9fSConnor Sughrue Expected<std::unique_ptr<raw_socket_stream>>
254*76321b9fSConnor Sughrue ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
255*76321b9fSConnor Sughrue   auto getActiveFD = [this]() -> int { return FD; };
256*76321b9fSConnor Sughrue   std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
257*76321b9fSConnor Sughrue   if (TimeoutErr)
258*76321b9fSConnor Sughrue     return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
259*76321b9fSConnor Sughrue 
26087e6f87fSConnor Sughrue   int AcceptFD;
26187e6f87fSConnor Sughrue #ifdef _WIN32
262*76321b9fSConnor Sughrue   SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
26387e6f87fSConnor Sughrue   AcceptFD = _open_osfhandle(WinAcceptSock, 0);
26487e6f87fSConnor Sughrue #else
26587e6f87fSConnor Sughrue   AcceptFD = ::accept(FD, NULL, NULL);
26687e6f87fSConnor Sughrue #endif
26787e6f87fSConnor Sughrue 
26887e6f87fSConnor Sughrue   if (AcceptFD == -1)
26987e6f87fSConnor Sughrue     return llvm::make_error<StringError>(getLastSocketErrorCode(),
27087e6f87fSConnor Sughrue                                          "Socket accept failed");
27187e6f87fSConnor Sughrue   return std::make_unique<raw_socket_stream>(AcceptFD);
27287e6f87fSConnor Sughrue }
27387e6f87fSConnor Sughrue 
27487e6f87fSConnor Sughrue void ListeningSocket::shutdown() {
27587e6f87fSConnor Sughrue   int ObservedFD = FD.load();
27687e6f87fSConnor Sughrue 
27787e6f87fSConnor Sughrue   if (ObservedFD == -1)
27887e6f87fSConnor Sughrue     return;
27987e6f87fSConnor Sughrue 
28087e6f87fSConnor Sughrue   // If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then
28187e6f87fSConnor Sughrue   // another thread is responsible for shutdown so return
28287e6f87fSConnor Sughrue   if (!FD.compare_exchange_strong(ObservedFD, -1))
28387e6f87fSConnor Sughrue     return;
28487e6f87fSConnor Sughrue 
28587e6f87fSConnor Sughrue   ::close(ObservedFD);
28687e6f87fSConnor Sughrue   ::unlink(SocketPath.c_str());
28787e6f87fSConnor Sughrue 
288d4a01549SJay Foad   // Ensure ::poll returns if shutdown is called by a separate thread
28987e6f87fSConnor Sughrue   char Byte = 'A';
2906b46166eSJordan Rupprecht   ssize_t written = ::write(PipeFD[1], &Byte, 1);
2916b46166eSJordan Rupprecht 
2926b46166eSJordan Rupprecht   // Ignore any write() error
2936b46166eSJordan Rupprecht   (void)written;
29487e6f87fSConnor Sughrue }
29587e6f87fSConnor Sughrue 
29687e6f87fSConnor Sughrue ListeningSocket::~ListeningSocket() {
29787e6f87fSConnor Sughrue   shutdown();
29887e6f87fSConnor Sughrue 
29987e6f87fSConnor Sughrue   // Close the pipe's FDs in the destructor instead of within
30087e6f87fSConnor Sughrue   // ListeningSocket::shutdown to avoid unnecessary synchronization issues that
30187e6f87fSConnor Sughrue   // would occur as PipeFD's values would have to be changed to -1
30287e6f87fSConnor Sughrue   //
30387e6f87fSConnor Sughrue   // The move constructor sets PipeFD to -1
30487e6f87fSConnor Sughrue   if (PipeFD[0] != -1)
30587e6f87fSConnor Sughrue     ::close(PipeFD[0]);
30687e6f87fSConnor Sughrue   if (PipeFD[1] != -1)
30787e6f87fSConnor Sughrue     ::close(PipeFD[1]);
30887e6f87fSConnor Sughrue }
30987e6f87fSConnor Sughrue 
31087e6f87fSConnor Sughrue //===----------------------------------------------------------------------===//
31187e6f87fSConnor Sughrue //  raw_socket_stream
31287e6f87fSConnor Sughrue //===----------------------------------------------------------------------===//
31387e6f87fSConnor Sughrue 
314c0d5d36dScriis raw_socket_stream::raw_socket_stream(int SocketFD)
315c0d5d36dScriis     : raw_fd_stream(SocketFD, true) {}
316c0d5d36dScriis 
317*76321b9fSConnor Sughrue raw_socket_stream::~raw_socket_stream() {}
318*76321b9fSConnor Sughrue 
319c0d5d36dScriis Expected<std::unique_ptr<raw_socket_stream>>
320c0d5d36dScriis raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
321c0d5d36dScriis #ifdef _WIN32
322c0d5d36dScriis   WSABalancer _;
323c0d5d36dScriis #endif // _WIN32
32487e6f87fSConnor Sughrue   Expected<int> FD = getSocketFD(SocketPath);
325c0d5d36dScriis   if (!FD)
326c0d5d36dScriis     return FD.takeError();
327c0d5d36dScriis   return std::make_unique<raw_socket_stream>(*FD);
328c0d5d36dScriis }
329c0d5d36dScriis 
330*76321b9fSConnor Sughrue ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
331*76321b9fSConnor Sughrue                                 const std::chrono::milliseconds &Timeout) {
332*76321b9fSConnor Sughrue   auto getActiveFD = [this]() -> int { return this->get_fd(); };
333*76321b9fSConnor Sughrue   std::error_code Err = manageTimeout(Timeout, getActiveFD);
334*76321b9fSConnor Sughrue   // Mimic raw_fd_stream::read error handling behavior
335*76321b9fSConnor Sughrue   if (Err) {
336*76321b9fSConnor Sughrue     raw_fd_stream::error_detected(Err);
337*76321b9fSConnor Sughrue     return -1;
338*76321b9fSConnor Sughrue   }
339*76321b9fSConnor Sughrue   return raw_fd_stream::read(Ptr, Size);
340*76321b9fSConnor Sughrue }
341