xref: /llvm-project/lldb/source/Host/common/Socket.cpp (revision 641694729df1564710c91d6778ca4f9c841b561a)
1 //===-- Socket.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "lldb/Host/Socket.h"
10 
11 #include "lldb/Host/Config.h"
12 #include "lldb/Host/Host.h"
13 #include "lldb/Host/MainLoop.h"
14 #include "lldb/Host/SocketAddress.h"
15 #include "lldb/Host/common/TCPSocket.h"
16 #include "lldb/Host/common/UDPSocket.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Errno.h"
23 #include "llvm/Support/Error.h"
24 #include "llvm/Support/Regex.h"
25 #include "llvm/Support/WindowsError.h"
26 
27 #if LLDB_ENABLE_POSIX
28 #include "lldb/Host/posix/DomainSocket.h"
29 
30 #include <arpa/inet.h>
31 #include <netdb.h>
32 #include <netinet/in.h>
33 #include <netinet/tcp.h>
34 #include <sys/socket.h>
35 #include <sys/un.h>
36 #include <unistd.h>
37 #endif
38 
39 #ifdef __linux__
40 #include "lldb/Host/linux/AbstractSocket.h"
41 #endif
42 
43 using namespace lldb;
44 using namespace lldb_private;
45 
46 #if defined(_WIN32)
47 typedef const char *set_socket_option_arg_type;
48 typedef char *get_socket_option_arg_type;
49 const NativeSocket Socket::kInvalidSocketValue = INVALID_SOCKET;
50 const shared_fd_t SharedSocket::kInvalidFD = LLDB_INVALID_PIPE;
51 #else  // #if defined(_WIN32)
52 typedef const void *set_socket_option_arg_type;
53 typedef void *get_socket_option_arg_type;
54 const NativeSocket Socket::kInvalidSocketValue = -1;
55 const shared_fd_t SharedSocket::kInvalidFD = Socket::kInvalidSocketValue;
56 #endif // #if defined(_WIN32)
57 
// Reports whether the current thread's last error code (WSAGetLastError() on
// Windows, errno elsewhere) is the "interrupted" code.
static bool IsInterrupted() {
#if defined(_WIN32)
  const int last_error = ::WSAGetLastError();
  return last_error == WSAEINTR;
#else
  const int last_error = errno;
  return last_error == EINTR;
#endif
}
65 
66 SharedSocket::SharedSocket(const Socket *socket, Status &error) {
67 #ifdef _WIN32
68   m_socket = socket->GetNativeSocket();
69   m_fd = kInvalidFD;
70 
71   // Create a pipe to transfer WSAPROTOCOL_INFO to the child process.
72   error = m_socket_pipe.CreateNew(true);
73   if (error.Fail())
74     return;
75 
76   m_fd = m_socket_pipe.GetReadPipe();
77 #else
78   m_fd = socket->GetNativeSocket();
79   error = Status();
80 #endif
81 }
82 
83 Status SharedSocket::CompleteSending(lldb::pid_t child_pid) {
84 #ifdef _WIN32
85   // Transfer WSAPROTOCOL_INFO to the child process.
86   m_socket_pipe.CloseReadFileDescriptor();
87 
88   WSAPROTOCOL_INFO protocol_info;
89   if (::WSADuplicateSocket(m_socket, child_pid, &protocol_info) ==
90       SOCKET_ERROR) {
91     int last_error = ::WSAGetLastError();
92     return Status::FromErrorStringWithFormat(
93         "WSADuplicateSocket() failed, error: %d", last_error);
94   }
95 
96   size_t num_bytes;
97   Status error =
98       m_socket_pipe.WriteWithTimeout(&protocol_info, sizeof(protocol_info),
99                                      std::chrono::seconds(10), num_bytes);
100   if (error.Fail())
101     return error;
102   if (num_bytes != sizeof(protocol_info))
103     return Status::FromErrorStringWithFormatv(
104         "WriteWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", num_bytes);
105 #endif
106   return Status();
107 }
108 
109 Status SharedSocket::GetNativeSocket(shared_fd_t fd, NativeSocket &socket) {
110 #ifdef _WIN32
111   socket = Socket::kInvalidSocketValue;
112   // Read WSAPROTOCOL_INFO from the parent process and create NativeSocket.
113   WSAPROTOCOL_INFO protocol_info;
114   {
115     Pipe socket_pipe(fd, LLDB_INVALID_PIPE);
116     size_t num_bytes;
117     Status error =
118         socket_pipe.ReadWithTimeout(&protocol_info, sizeof(protocol_info),
119                                     std::chrono::seconds(10), num_bytes);
120     if (error.Fail())
121       return error;
122     if (num_bytes != sizeof(protocol_info)) {
123       return Status::FromErrorStringWithFormatv(
124           "socket_pipe.ReadWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes",
125           num_bytes);
126     }
127   }
128   socket = ::WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO,
129                        FROM_PROTOCOL_INFO, &protocol_info, 0, 0);
130   if (socket == INVALID_SOCKET) {
131     return Status::FromErrorStringWithFormatv(
132         "WSASocket(FROM_PROTOCOL_INFO) failed: error {0}", ::WSAGetLastError());
133   }
134   return Status();
135 #else
136   socket = fd;
137   return Status();
138 #endif
139 }
140 
141 struct SocketScheme {
142   const char *m_scheme;
143   const Socket::SocketProtocol m_protocol;
144 };
145 
146 static SocketScheme socket_schemes[] = {
147     {"tcp", Socket::ProtocolTcp},
148     {"udp", Socket::ProtocolUdp},
149     {"unix", Socket::ProtocolUnixDomain},
150     {"unix-abstract", Socket::ProtocolUnixAbstract},
151 };
152 
153 const char *
154 Socket::FindSchemeByProtocol(const Socket::SocketProtocol protocol) {
155   for (auto s : socket_schemes) {
156     if (s.m_protocol == protocol)
157       return s.m_scheme;
158   }
159   return nullptr;
160 }
161 
162 bool Socket::FindProtocolByScheme(const char *scheme,
163                                   Socket::SocketProtocol &protocol) {
164   for (auto s : socket_schemes) {
165     if (!strcmp(s.m_scheme, scheme)) {
166       protocol = s.m_protocol;
167       return true;
168     }
169   }
170   return false;
171 }
172 
173 Socket::Socket(SocketProtocol protocol, bool should_close)
174     : IOObject(eFDTypeSocket), m_protocol(protocol),
175       m_socket(kInvalidSocketValue), m_should_close_fd(should_close) {}
176 
177 Socket::~Socket() { Close(); }
178 
179 llvm::Error Socket::Initialize() {
180 #if defined(_WIN32)
181   auto wVersion = WINSOCK_VERSION;
182   WSADATA wsaData;
183   int err = ::WSAStartup(wVersion, &wsaData);
184   if (err == 0) {
185     if (wsaData.wVersion < wVersion) {
186       WSACleanup();
187       return llvm::createStringError("WSASock version is not expected.");
188     }
189   } else {
190     return llvm::errorCodeToError(llvm::mapWindowsError(::WSAGetLastError()));
191   }
192 #endif
193 
194   return llvm::Error::success();
195 }
196 
197 void Socket::Terminate() {
198 #if defined(_WIN32)
199   ::WSACleanup();
200 #endif
201 }
202 
203 std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol,
204                                        Status &error) {
205   error.Clear();
206 
207   const bool should_close = true;
208   std::unique_ptr<Socket> socket_up;
209   switch (protocol) {
210   case ProtocolTcp:
211     socket_up = std::make_unique<TCPSocket>(should_close);
212     break;
213   case ProtocolUdp:
214     socket_up = std::make_unique<UDPSocket>(should_close);
215     break;
216   case ProtocolUnixDomain:
217 #if LLDB_ENABLE_POSIX
218     socket_up = std::make_unique<DomainSocket>(should_close);
219 #else
220     error = Status::FromErrorString(
221         "Unix domain sockets are not supported on this platform.");
222 #endif
223     break;
224   case ProtocolUnixAbstract:
225 #ifdef __linux__
226     socket_up = std::make_unique<AbstractSocket>();
227 #else
228     error = Status::FromErrorString(
229         "Abstract domain sockets are not supported on this platform.");
230 #endif
231     break;
232   }
233 
234   if (error.Fail())
235     socket_up.reset();
236 
237   return socket_up;
238 }
239 
240 llvm::Expected<std::unique_ptr<Socket>>
241 Socket::TcpConnect(llvm::StringRef host_and_port) {
242   Log *log = GetLog(LLDBLog::Connection);
243   LLDB_LOG(log, "host_and_port = {0}", host_and_port);
244 
245   Status error;
246   std::unique_ptr<Socket> connect_socket = Create(ProtocolTcp, error);
247   if (error.Fail())
248     return error.ToError();
249 
250   error = connect_socket->Connect(host_and_port);
251   if (error.Success())
252     return std::move(connect_socket);
253 
254   return error.ToError();
255 }
256 
257 llvm::Expected<std::unique_ptr<TCPSocket>>
258 Socket::TcpListen(llvm::StringRef host_and_port, int backlog) {
259   Log *log = GetLog(LLDBLog::Connection);
260   LLDB_LOG(log, "host_and_port = {0}", host_and_port);
261 
262   std::unique_ptr<TCPSocket> listen_socket(
263       new TCPSocket(/*should_close=*/true));
264 
265   Status error = listen_socket->Listen(host_and_port, backlog);
266   if (error.Fail())
267     return error.ToError();
268 
269   return std::move(listen_socket);
270 }
271 
272 llvm::Expected<std::unique_ptr<UDPSocket>>
273 Socket::UdpConnect(llvm::StringRef host_and_port) {
274   return UDPSocket::CreateConnected(host_and_port);
275 }
276 
277 llvm::Expected<Socket::HostAndPort> Socket::DecodeHostAndPort(llvm::StringRef host_and_port) {
278   static llvm::Regex g_regex("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)");
279   HostAndPort ret;
280   llvm::SmallVector<llvm::StringRef, 3> matches;
281   if (g_regex.match(host_and_port, &matches)) {
282     ret.hostname = matches[1].str();
283     // IPv6 addresses are wrapped in [] when specified with ports
284     if (ret.hostname.front() == '[' && ret.hostname.back() == ']')
285       ret.hostname = ret.hostname.substr(1, ret.hostname.size() - 2);
286     if (to_integer(matches[2], ret.port, 10))
287       return ret;
288   } else {
289     // If this was unsuccessful, then check if it's simply an unsigned 16-bit
290     // integer, representing a port with an empty host.
291     if (to_integer(host_and_port, ret.port, 10))
292       return ret;
293   }
294 
295   return llvm::createStringError(llvm::inconvertibleErrorCode(),
296                                  "invalid host:port specification: '%s'",
297                                  host_and_port.str().c_str());
298 }
299 
300 IOObject::WaitableHandle Socket::GetWaitableHandle() {
301   // TODO: On Windows, use WSAEventSelect
302   return m_socket;
303 }
304 
305 Status Socket::Read(void *buf, size_t &num_bytes) {
306   Status error;
307   int bytes_received = 0;
308   do {
309     bytes_received = ::recv(m_socket, static_cast<char *>(buf), num_bytes, 0);
310   } while (bytes_received < 0 && IsInterrupted());
311 
312   if (bytes_received < 0) {
313     SetLastError(error);
314     num_bytes = 0;
315   } else
316     num_bytes = bytes_received;
317 
318   Log *log = GetLog(LLDBLog::Communication);
319   if (log) {
320     LLDB_LOGF(log,
321               "%p Socket::Read() (socket = %" PRIu64
322               ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
323               " (error = %s)",
324               static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
325               static_cast<uint64_t>(num_bytes),
326               static_cast<int64_t>(bytes_received), error.AsCString());
327   }
328 
329   return error;
330 }
331 
332 Status Socket::Write(const void *buf, size_t &num_bytes) {
333   const size_t src_len = num_bytes;
334   Status error;
335   int bytes_sent = 0;
336   do {
337     bytes_sent = Send(buf, num_bytes);
338   } while (bytes_sent < 0 && IsInterrupted());
339 
340   if (bytes_sent < 0) {
341     SetLastError(error);
342     num_bytes = 0;
343   } else
344     num_bytes = bytes_sent;
345 
346   Log *log = GetLog(LLDBLog::Communication);
347   if (log) {
348     LLDB_LOGF(log,
349               "%p Socket::Write() (socket = %" PRIu64
350               ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
351               " (error = %s)",
352               static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
353               static_cast<uint64_t>(src_len),
354               static_cast<int64_t>(bytes_sent), error.AsCString());
355   }
356 
357   return error;
358 }
359 
360 Status Socket::Close() {
361   Status error;
362   if (!IsValid() || !m_should_close_fd)
363     return error;
364 
365   Log *log = GetLog(LLDBLog::Connection);
366   LLDB_LOGF(log, "%p Socket::Close (fd = %" PRIu64 ")",
367             static_cast<void *>(this), static_cast<uint64_t>(m_socket));
368 
369   bool success = CloseSocket(m_socket) == 0;
370   // A reference to a FD was passed in, set it to an invalid value
371   m_socket = kInvalidSocketValue;
372   if (!success) {
373     SetLastError(error);
374   }
375 
376   return error;
377 }
378 
379 int Socket::GetOption(NativeSocket sockfd, int level, int option_name,
380                       int &option_value) {
381   get_socket_option_arg_type option_value_p =
382       reinterpret_cast<get_socket_option_arg_type>(&option_value);
383   socklen_t option_value_size = sizeof(int);
384   return ::getsockopt(sockfd, level, option_name, option_value_p,
385                       &option_value_size);
386 }
387 
388 int Socket::SetOption(NativeSocket sockfd, int level, int option_name,
389                       int option_value) {
390   set_socket_option_arg_type option_value_p =
391       reinterpret_cast<set_socket_option_arg_type>(&option_value);
392   return ::setsockopt(sockfd, level, option_name, option_value_p,
393                       sizeof(option_value));
394 }
395 
396 size_t Socket::Send(const void *buf, const size_t num_bytes) {
397   return ::send(m_socket, static_cast<const char *>(buf), num_bytes, 0);
398 }
399 
400 void Socket::SetLastError(Status &error) {
401 #if defined(_WIN32)
402   error = Status(::WSAGetLastError(), lldb::eErrorTypeWin32);
403 #else
404   error = Status::FromErrno();
405 #endif
406 }
407 
408 Status Socket::GetLastError() {
409   std::error_code EC;
410 #ifdef _WIN32
411   EC = llvm::mapWindowsError(WSAGetLastError());
412 #else
413   EC = std::error_code(errno, std::generic_category());
414 #endif
415   return EC;
416 }
417 
418 int Socket::CloseSocket(NativeSocket sockfd) {
419 #ifdef _WIN32
420   return ::closesocket(sockfd);
421 #else
422   return ::close(sockfd);
423 #endif
424 }
425 
426 NativeSocket Socket::CreateSocket(const int domain, const int type,
427                                   const int protocol, Status &error) {
428   error.Clear();
429   auto socket_type = type;
430 #ifdef SOCK_CLOEXEC
431   socket_type |= SOCK_CLOEXEC;
432 #endif
433   auto sock = ::socket(domain, socket_type, protocol);
434   if (sock == kInvalidSocketValue)
435     SetLastError(error);
436 
437   return sock;
438 }
439 
440 Status Socket::Accept(const Timeout<std::micro> &timeout, Socket *&socket) {
441   socket = nullptr;
442   MainLoop accept_loop;
443   llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles =
444       Accept(accept_loop,
445              [&accept_loop, &socket](std::unique_ptr<Socket> sock) {
446                socket = sock.release();
447                accept_loop.RequestTermination();
448              });
449   if (!expected_handles)
450     return Status::FromError(expected_handles.takeError());
451   if (timeout) {
452     accept_loop.AddCallback(
453         [](MainLoopBase &loop) { loop.RequestTermination(); }, *timeout);
454   }
455   if (Status status = accept_loop.Run(); status.Fail())
456     return status;
457   if (socket)
458     return Status();
459   return Status(std::make_error_code(std::errc::timed_out));
460 }
461 
462 NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr,
463                                   socklen_t *addrlen, Status &error) {
464   error.Clear();
465 #if defined(SOCK_CLOEXEC) && defined(HAVE_ACCEPT4)
466   int flags = SOCK_CLOEXEC;
467   NativeSocket fd = llvm::sys::RetryAfterSignal(
468       static_cast<NativeSocket>(-1), ::accept4, sockfd, addr, addrlen, flags);
469 #else
470   NativeSocket fd = llvm::sys::RetryAfterSignal(
471       static_cast<NativeSocket>(-1), ::accept, sockfd, addr, addrlen);
472 #endif
473   if (fd == kInvalidSocketValue)
474     SetLastError(error);
475   return fd;
476 }
477 
478 llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS,
479                                             const Socket::HostAndPort &HP) {
480   return OS << '[' << HP.hostname << ']' << ':' << HP.port;
481 }
482