1 //===-- Socket.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "lldb/Host/Socket.h" 10 11 #include "lldb/Host/Config.h" 12 #include "lldb/Host/Host.h" 13 #include "lldb/Host/MainLoop.h" 14 #include "lldb/Host/SocketAddress.h" 15 #include "lldb/Host/common/TCPSocket.h" 16 #include "lldb/Host/common/UDPSocket.h" 17 #include "lldb/Utility/LLDBLog.h" 18 #include "lldb/Utility/Log.h" 19 20 #include "llvm/ADT/STLExtras.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/Support/Errno.h" 23 #include "llvm/Support/Error.h" 24 #include "llvm/Support/Regex.h" 25 #include "llvm/Support/WindowsError.h" 26 27 #if LLDB_ENABLE_POSIX 28 #include "lldb/Host/posix/DomainSocket.h" 29 30 #include <arpa/inet.h> 31 #include <netdb.h> 32 #include <netinet/in.h> 33 #include <netinet/tcp.h> 34 #include <sys/socket.h> 35 #include <sys/un.h> 36 #include <unistd.h> 37 #endif 38 39 #ifdef __linux__ 40 #include "lldb/Host/linux/AbstractSocket.h" 41 #endif 42 43 using namespace lldb; 44 using namespace lldb_private; 45 46 #if defined(_WIN32) 47 typedef const char *set_socket_option_arg_type; 48 typedef char *get_socket_option_arg_type; 49 const NativeSocket Socket::kInvalidSocketValue = INVALID_SOCKET; 50 const shared_fd_t SharedSocket::kInvalidFD = LLDB_INVALID_PIPE; 51 #else // #if defined(_WIN32) 52 typedef const void *set_socket_option_arg_type; 53 typedef void *get_socket_option_arg_type; 54 const NativeSocket Socket::kInvalidSocketValue = -1; 55 const shared_fd_t SharedSocket::kInvalidFD = Socket::kInvalidSocketValue; 56 #endif // #if defined(_WIN32) 57 58 static bool IsInterrupted() { 59 #if defined(_WIN32) 60 return ::WSAGetLastError() == WSAEINTR; 61 #else 62 return errno == EINTR; 63 #endif 64 } 65 66 SharedSocket::SharedSocket(const Socket *socket, Status &error) { 67 #ifdef _WIN32 68 m_socket = socket->GetNativeSocket(); 69 m_fd = kInvalidFD; 70 71 // Create a pipe to transfer WSAPROTOCOL_INFO to the child process. 72 error = m_socket_pipe.CreateNew(true); 73 if (error.Fail()) 74 return; 75 76 m_fd = m_socket_pipe.GetReadPipe(); 77 #else 78 m_fd = socket->GetNativeSocket(); 79 error = Status(); 80 #endif 81 } 82 83 Status SharedSocket::CompleteSending(lldb::pid_t child_pid) { 84 #ifdef _WIN32 85 // Transfer WSAPROTOCOL_INFO to the child process. 86 m_socket_pipe.CloseReadFileDescriptor(); 87 88 WSAPROTOCOL_INFO protocol_info; 89 if (::WSADuplicateSocket(m_socket, child_pid, &protocol_info) == 90 SOCKET_ERROR) { 91 int last_error = ::WSAGetLastError(); 92 return Status::FromErrorStringWithFormat( 93 "WSADuplicateSocket() failed, error: %d", last_error); 94 } 95 96 size_t num_bytes; 97 Status error = 98 m_socket_pipe.WriteWithTimeout(&protocol_info, sizeof(protocol_info), 99 std::chrono::seconds(10), num_bytes); 100 if (error.Fail()) 101 return error; 102 if (num_bytes != sizeof(protocol_info)) 103 return Status::FromErrorStringWithFormatv( 104 "WriteWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", num_bytes); 105 #endif 106 return Status(); 107 } 108 109 Status SharedSocket::GetNativeSocket(shared_fd_t fd, NativeSocket &socket) { 110 #ifdef _WIN32 111 socket = Socket::kInvalidSocketValue; 112 // Read WSAPROTOCOL_INFO from the parent process and create NativeSocket. 113 WSAPROTOCOL_INFO protocol_info; 114 { 115 Pipe socket_pipe(fd, LLDB_INVALID_PIPE); 116 size_t num_bytes; 117 Status error = 118 socket_pipe.ReadWithTimeout(&protocol_info, sizeof(protocol_info), 119 std::chrono::seconds(10), num_bytes); 120 if (error.Fail()) 121 return error; 122 if (num_bytes != sizeof(protocol_info)) { 123 return Status::FromErrorStringWithFormatv( 124 "socket_pipe.ReadWithTimeout(WSAPROTOCOL_INFO) failed: {0} bytes", 125 num_bytes); 126 } 127 } 128 socket = ::WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, 129 FROM_PROTOCOL_INFO, &protocol_info, 0, 0); 130 if (socket == INVALID_SOCKET) { 131 return Status::FromErrorStringWithFormatv( 132 "WSASocket(FROM_PROTOCOL_INFO) failed: error {0}", ::WSAGetLastError()); 133 } 134 return Status(); 135 #else 136 socket = fd; 137 return Status(); 138 #endif 139 } 140 141 struct SocketScheme { 142 const char *m_scheme; 143 const Socket::SocketProtocol m_protocol; 144 }; 145 146 static SocketScheme socket_schemes[] = { 147 {"tcp", Socket::ProtocolTcp}, 148 {"udp", Socket::ProtocolUdp}, 149 {"unix", Socket::ProtocolUnixDomain}, 150 {"unix-abstract", Socket::ProtocolUnixAbstract}, 151 }; 152 153 const char * 154 Socket::FindSchemeByProtocol(const Socket::SocketProtocol protocol) { 155 for (auto s : socket_schemes) { 156 if (s.m_protocol == protocol) 157 return s.m_scheme; 158 } 159 return nullptr; 160 } 161 162 bool Socket::FindProtocolByScheme(const char *scheme, 163 Socket::SocketProtocol &protocol) { 164 for (auto s : socket_schemes) { 165 if (!strcmp(s.m_scheme, scheme)) { 166 protocol = s.m_protocol; 167 return true; 168 } 169 } 170 return false; 171 } 172 173 Socket::Socket(SocketProtocol protocol, bool should_close) 174 : IOObject(eFDTypeSocket), m_protocol(protocol), 175 m_socket(kInvalidSocketValue), m_should_close_fd(should_close) {} 176 177 Socket::~Socket() { Close(); } 178 179 llvm::Error Socket::Initialize() { 180 #if defined(_WIN32) 181 auto wVersion = WINSOCK_VERSION; 182 WSADATA wsaData; 183 int err = ::WSAStartup(wVersion, &wsaData); 184 if (err == 0) { 185 if (wsaData.wVersion < wVersion) { 186 WSACleanup(); 187 return llvm::createStringError("WSASock version is not expected."); 188 } 189 } else { 190 return llvm::errorCodeToError(llvm::mapWindowsError(::WSAGetLastError())); 191 } 192 #endif 193 194 return llvm::Error::success(); 195 } 196 197 void Socket::Terminate() { 198 #if defined(_WIN32) 199 ::WSACleanup(); 200 #endif 201 } 202 203 std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol, 204 Status &error) { 205 error.Clear(); 206 207 const bool should_close = true; 208 std::unique_ptr<Socket> socket_up; 209 switch (protocol) { 210 case ProtocolTcp: 211 socket_up = std::make_unique<TCPSocket>(should_close); 212 break; 213 case ProtocolUdp: 214 socket_up = std::make_unique<UDPSocket>(should_close); 215 break; 216 case ProtocolUnixDomain: 217 #if LLDB_ENABLE_POSIX 218 socket_up = std::make_unique<DomainSocket>(should_close); 219 #else 220 error = Status::FromErrorString( 221 "Unix domain sockets are not supported on this platform."); 222 #endif 223 break; 224 case ProtocolUnixAbstract: 225 #ifdef __linux__ 226 socket_up = std::make_unique<AbstractSocket>(); 227 #else 228 error = Status::FromErrorString( 229 "Abstract domain sockets are not supported on this platform."); 230 #endif 231 break; 232 } 233 234 if (error.Fail()) 235 socket_up.reset(); 236 237 return socket_up; 238 } 239 240 llvm::Expected<std::unique_ptr<Socket>> 241 Socket::TcpConnect(llvm::StringRef host_and_port) { 242 Log *log = GetLog(LLDBLog::Connection); 243 LLDB_LOG(log, "host_and_port = {0}", host_and_port); 244 245 Status error; 246 std::unique_ptr<Socket> connect_socket = Create(ProtocolTcp, error); 247 if (error.Fail()) 248 return error.ToError(); 249 250 error = connect_socket->Connect(host_and_port); 251 if (error.Success()) 252 return std::move(connect_socket); 253 254 return error.ToError(); 255 } 256 257 llvm::Expected<std::unique_ptr<TCPSocket>> 258 Socket::TcpListen(llvm::StringRef host_and_port, int backlog) { 259 Log *log = GetLog(LLDBLog::Connection); 260 LLDB_LOG(log, "host_and_port = {0}", host_and_port); 261 262 std::unique_ptr<TCPSocket> listen_socket( 263 new TCPSocket(/*should_close=*/true)); 264 265 Status error = listen_socket->Listen(host_and_port, backlog); 266 if (error.Fail()) 267 return error.ToError(); 268 269 return std::move(listen_socket); 270 } 271 272 llvm::Expected<std::unique_ptr<UDPSocket>> 273 Socket::UdpConnect(llvm::StringRef host_and_port) { 274 return UDPSocket::CreateConnected(host_and_port); 275 } 276 277 llvm::Expected<Socket::HostAndPort> Socket::DecodeHostAndPort(llvm::StringRef host_and_port) { 278 static llvm::Regex g_regex("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)"); 279 HostAndPort ret; 280 llvm::SmallVector<llvm::StringRef, 3> matches; 281 if (g_regex.match(host_and_port, &matches)) { 282 ret.hostname = matches[1].str(); 283 // IPv6 addresses are wrapped in [] when specified with ports 284 if (ret.hostname.front() == '[' && ret.hostname.back() == ']') 285 ret.hostname = ret.hostname.substr(1, ret.hostname.size() - 2); 286 if (to_integer(matches[2], ret.port, 10)) 287 return ret; 288 } else { 289 // If this was unsuccessful, then check if it's simply an unsigned 16-bit 290 // integer, representing a port with an empty host. 291 if (to_integer(host_and_port, ret.port, 10)) 292 return ret; 293 } 294 295 return llvm::createStringError(llvm::inconvertibleErrorCode(), 296 "invalid host:port specification: '%s'", 297 host_and_port.str().c_str()); 298 } 299 300 IOObject::WaitableHandle Socket::GetWaitableHandle() { 301 // TODO: On Windows, use WSAEventSelect 302 return m_socket; 303 } 304 305 Status Socket::Read(void *buf, size_t &num_bytes) { 306 Status error; 307 int bytes_received = 0; 308 do { 309 bytes_received = ::recv(m_socket, static_cast<char *>(buf), num_bytes, 0); 310 } while (bytes_received < 0 && IsInterrupted()); 311 312 if (bytes_received < 0) { 313 SetLastError(error); 314 num_bytes = 0; 315 } else 316 num_bytes = bytes_received; 317 318 Log *log = GetLog(LLDBLog::Communication); 319 if (log) { 320 LLDB_LOGF(log, 321 "%p Socket::Read() (socket = %" PRIu64 322 ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 323 " (error = %s)", 324 static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf, 325 static_cast<uint64_t>(num_bytes), 326 static_cast<int64_t>(bytes_received), error.AsCString()); 327 } 328 329 return error; 330 } 331 332 Status Socket::Write(const void *buf, size_t &num_bytes) { 333 const size_t src_len = num_bytes; 334 Status error; 335 int bytes_sent = 0; 336 do { 337 bytes_sent = Send(buf, num_bytes); 338 } while (bytes_sent < 0 && IsInterrupted()); 339 340 if (bytes_sent < 0) { 341 SetLastError(error); 342 num_bytes = 0; 343 } else 344 num_bytes = bytes_sent; 345 346 Log *log = GetLog(LLDBLog::Communication); 347 if (log) { 348 LLDB_LOGF(log, 349 "%p Socket::Write() (socket = %" PRIu64 350 ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 351 " (error = %s)", 352 static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf, 353 static_cast<uint64_t>(src_len), 354 static_cast<int64_t>(bytes_sent), error.AsCString()); 355 } 356 357 return error; 358 } 359 360 Status Socket::Close() { 361 Status error; 362 if (!IsValid() || !m_should_close_fd) 363 return error; 364 365 Log *log = GetLog(LLDBLog::Connection); 366 LLDB_LOGF(log, "%p Socket::Close (fd = %" PRIu64 ")", 367 static_cast<void *>(this), static_cast<uint64_t>(m_socket)); 368 369 bool success = CloseSocket(m_socket) == 0; 370 // A reference to a FD was passed in, set it to an invalid value 371 m_socket = kInvalidSocketValue; 372 if (!success) { 373 SetLastError(error); 374 } 375 376 return error; 377 } 378 379 int Socket::GetOption(NativeSocket sockfd, int level, int option_name, 380 int &option_value) { 381 get_socket_option_arg_type option_value_p = 382 reinterpret_cast<get_socket_option_arg_type>(&option_value); 383 socklen_t option_value_size = sizeof(int); 384 return ::getsockopt(sockfd, level, option_name, option_value_p, 385 &option_value_size); 386 } 387 388 int Socket::SetOption(NativeSocket sockfd, int level, int option_name, 389 int option_value) { 390 set_socket_option_arg_type option_value_p = 391 reinterpret_cast<set_socket_option_arg_type>(&option_value); 392 return ::setsockopt(sockfd, level, option_name, option_value_p, 393 sizeof(option_value)); 394 } 395 396 size_t Socket::Send(const void *buf, const size_t num_bytes) { 397 return ::send(m_socket, static_cast<const char *>(buf), num_bytes, 0); 398 } 399 400 void Socket::SetLastError(Status &error) { 401 #if defined(_WIN32) 402 error = Status(::WSAGetLastError(), lldb::eErrorTypeWin32); 403 #else 404 error = Status::FromErrno(); 405 #endif 406 } 407 408 Status Socket::GetLastError() { 409 std::error_code EC; 410 #ifdef _WIN32 411 EC = llvm::mapWindowsError(WSAGetLastError()); 412 #else 413 EC = std::error_code(errno, std::generic_category()); 414 #endif 415 return EC; 416 } 417 418 int Socket::CloseSocket(NativeSocket sockfd) { 419 #ifdef _WIN32 420 return ::closesocket(sockfd); 421 #else 422 return ::close(sockfd); 423 #endif 424 } 425 426 NativeSocket Socket::CreateSocket(const int domain, const int type, 427 const int protocol, Status &error) { 428 error.Clear(); 429 auto socket_type = type; 430 #ifdef SOCK_CLOEXEC 431 socket_type |= SOCK_CLOEXEC; 432 #endif 433 auto sock = ::socket(domain, socket_type, protocol); 434 if (sock == kInvalidSocketValue) 435 SetLastError(error); 436 437 return sock; 438 } 439 440 Status Socket::Accept(const Timeout<std::micro> &timeout, Socket *&socket) { 441 socket = nullptr; 442 MainLoop accept_loop; 443 llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles = 444 Accept(accept_loop, 445 [&accept_loop, &socket](std::unique_ptr<Socket> sock) { 446 socket = sock.release(); 447 accept_loop.RequestTermination(); 448 }); 449 if (!expected_handles) 450 return Status::FromError(expected_handles.takeError()); 451 if (timeout) { 452 accept_loop.AddCallback( 453 [](MainLoopBase &loop) { loop.RequestTermination(); }, *timeout); 454 } 455 if (Status status = accept_loop.Run(); status.Fail()) 456 return status; 457 if (socket) 458 return Status(); 459 return Status(std::make_error_code(std::errc::timed_out)); 460 } 461 462 NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr, 463 socklen_t *addrlen, Status &error) { 464 error.Clear(); 465 #if defined(SOCK_CLOEXEC) && defined(HAVE_ACCEPT4) 466 int flags = SOCK_CLOEXEC; 467 NativeSocket fd = llvm::sys::RetryAfterSignal( 468 static_cast<NativeSocket>(-1), ::accept4, sockfd, addr, addrlen, flags); 469 #else 470 NativeSocket fd = llvm::sys::RetryAfterSignal( 471 static_cast<NativeSocket>(-1), ::accept, sockfd, addr, addrlen); 472 #endif 473 if (fd == kInvalidSocketValue) 474 SetLastError(error); 475 return fd; 476 } 477 478 llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS, 479 const Socket::HostAndPort &HP) { 480 return OS << '[' << HP.hostname << ']' << ':' << HP.port; 481 } 482