xref: /llvm-project/lldb/source/Host/common/TCPSocket.cpp (revision d5ba143a6d8e8726c900dbfc381dab0e7d8b6a65)
1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12 
13 #include "lldb/Host/common/TCPSocket.h"
14 
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19 
20 #include "llvm/Config/llvm-config.h"
21 #include "llvm/Support/Errno.h"
22 #include "llvm/Support/Error.h"
23 #include "llvm/Support/WindowsError.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #if LLDB_ENABLE_POSIX
27 #include <arpa/inet.h>
28 #include <netinet/tcp.h>
29 #include <sys/socket.h>
30 #endif
31 
32 #if defined(_WIN32)
33 #include <winsock2.h>
34 #endif
35 
36 using namespace lldb;
37 using namespace lldb_private;
38 
// Socket type used for every descriptor this file creates: a TCP stream.
static constexpr int kType = SOCK_STREAM;
40 
41 TCPSocket::TCPSocket(bool should_close) : Socket(ProtocolTcp, should_close) {}
42 
43 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
44     : Socket(ProtocolTcp, listen_socket.m_should_close_fd) {
45   m_socket = socket;
46 }
47 
48 TCPSocket::TCPSocket(NativeSocket socket, bool should_close)
49     : Socket(ProtocolTcp, should_close) {
50   m_socket = socket;
51 }
52 
53 TCPSocket::~TCPSocket() { CloseListenSockets(); }
54 
55 bool TCPSocket::IsValid() const {
56   return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
57 }
58 
59 // Return the port number that is being used by the socket.
60 uint16_t TCPSocket::GetLocalPortNumber() const {
61   if (m_socket != kInvalidSocketValue) {
62     SocketAddress sock_addr;
63     socklen_t sock_addr_len = sock_addr.GetMaxLength();
64     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
65       return sock_addr.GetPort();
66   } else if (!m_listen_sockets.empty()) {
67     SocketAddress sock_addr;
68     socklen_t sock_addr_len = sock_addr.GetMaxLength();
69     if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
70                       &sock_addr_len) == 0)
71       return sock_addr.GetPort();
72   }
73   return 0;
74 }
75 
76 std::string TCPSocket::GetLocalIPAddress() const {
77   // We bound to port zero, so we need to figure out which port we actually
78   // bound to
79   if (m_socket != kInvalidSocketValue) {
80     SocketAddress sock_addr;
81     socklen_t sock_addr_len = sock_addr.GetMaxLength();
82     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
83       return sock_addr.GetIPAddress();
84   }
85   return "";
86 }
87 
88 uint16_t TCPSocket::GetRemotePortNumber() const {
89   if (m_socket != kInvalidSocketValue) {
90     SocketAddress sock_addr;
91     socklen_t sock_addr_len = sock_addr.GetMaxLength();
92     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
93       return sock_addr.GetPort();
94   }
95   return 0;
96 }
97 
98 std::string TCPSocket::GetRemoteIPAddress() const {
99   // We bound to port zero, so we need to figure out which port we actually
100   // bound to
101   if (m_socket != kInvalidSocketValue) {
102     SocketAddress sock_addr;
103     socklen_t sock_addr_len = sock_addr.GetMaxLength();
104     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
105       return sock_addr.GetIPAddress();
106   }
107   return "";
108 }
109 
110 std::string TCPSocket::GetRemoteConnectionURI() const {
111   if (m_socket != kInvalidSocketValue) {
112     return std::string(llvm::formatv(
113         "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
114   }
115   return "";
116 }
117 
118 std::vector<std::string> TCPSocket::GetListeningConnectionURI() const {
119   std::vector<std::string> URIs;
120   for (const auto &[fd, addr] : m_listen_sockets)
121     URIs.emplace_back(llvm::formatv("connection://[{0}]:{1}",
122                                     addr.GetIPAddress(), addr.GetPort()));
123   return URIs;
124 }
125 
126 Status TCPSocket::CreateSocket(int domain) {
127   Status error;
128   if (IsValid())
129     error = Close();
130   if (error.Fail())
131     return error;
132   m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP, error);
133   return error;
134 }
135 
136 Status TCPSocket::Connect(llvm::StringRef name) {
137 
138   Log *log = GetLog(LLDBLog::Communication);
139   LLDB_LOG(log, "Connect to host/port {0}", name);
140 
141   Status error;
142   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
143   if (!host_port)
144     return Status::FromError(host_port.takeError());
145 
146   std::vector<SocketAddress> addresses =
147       SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
148                                     AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
149   for (SocketAddress &address : addresses) {
150     error = CreateSocket(address.GetFamily());
151     if (error.Fail())
152       continue;
153 
154     address.SetPort(host_port->port);
155 
156     if (llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
157                                     &address.sockaddr(),
158                                     address.GetLength()) == -1) {
159       Close();
160       continue;
161     }
162 
163     if (SetOptionNoDelay() == -1) {
164       Close();
165       continue;
166     }
167 
168     error.Clear();
169     return error;
170   }
171 
172   error = Status::FromErrorString("Failed to connect port");
173   return error;
174 }
175 
176 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
177   Log *log = GetLog(LLDBLog::Connection);
178   LLDB_LOG(log, "Listen to {0}", name);
179 
180   Status error;
181   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
182   if (!host_port)
183     return Status::FromError(host_port.takeError());
184 
185   if (host_port->hostname == "*")
186     host_port->hostname = "0.0.0.0";
187   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
188       host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
189   for (SocketAddress &address : addresses) {
190     int fd =
191         Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP, error);
192     if (error.Fail() || fd < 0)
193       continue;
194 
195     // enable local address reuse
196     if (SetOption(fd, SOL_SOCKET, SO_REUSEADDR, 1) == -1) {
197       CloseSocket(fd);
198       continue;
199     }
200 
201     SocketAddress listen_address = address;
202     if(!listen_address.IsLocalhost())
203       listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
204     else
205       listen_address.SetPort(host_port->port);
206 
207     int err =
208         ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
209     if (err != -1)
210       err = ::listen(fd, backlog);
211 
212     if (err == -1) {
213       error = GetLastError();
214       CloseSocket(fd);
215       continue;
216     }
217 
218     if (host_port->port == 0) {
219       socklen_t sa_len = listen_address.GetLength();
220       if (getsockname(fd, &listen_address.sockaddr(), &sa_len) == 0)
221         host_port->port = listen_address.GetPort();
222     }
223     m_listen_sockets[fd] = listen_address;
224   }
225 
226   if (m_listen_sockets.empty()) {
227     assert(error.Fail());
228     return error;
229   }
230   return Status();
231 }
232 
233 void TCPSocket::CloseListenSockets() {
234   for (auto socket : m_listen_sockets)
235     CloseSocket(socket.first);
236   m_listen_sockets.clear();
237 }
238 
239 llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
240 TCPSocket::Accept(MainLoopBase &loop,
241                   std::function<void(std::unique_ptr<Socket> socket)> sock_cb) {
242   if (m_listen_sockets.size() == 0)
243     return llvm::createStringError("No open listening sockets!");
244 
245   std::vector<MainLoopBase::ReadHandleUP> handles;
246   for (auto socket : m_listen_sockets) {
247     auto fd = socket.first;
248     auto io_sp = std::make_shared<TCPSocket>(fd, false);
249     auto cb = [this, fd, sock_cb](MainLoopBase &loop) {
250       lldb_private::SocketAddress AcceptAddr;
251       socklen_t sa_len = AcceptAddr.GetMaxLength();
252       Status error;
253       NativeSocket sock =
254           AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, error);
255       Log *log = GetLog(LLDBLog::Host);
256       if (error.Fail()) {
257         LLDB_LOG(log, "AcceptSocket({0}): {1}", fd, error);
258         return;
259       }
260 
261       const lldb_private::SocketAddress &AddrIn = m_listen_sockets[fd];
262       if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
263         CloseSocket(sock);
264         LLDB_LOG(log, "rejecting incoming connection from {0} (expecting {1})",
265                  AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
266         return;
267       }
268       std::unique_ptr<TCPSocket> sock_up(new TCPSocket(sock, *this));
269 
270       // Keep our TCP packets coming without any delays.
271       sock_up->SetOptionNoDelay();
272 
273       sock_cb(std::move(sock_up));
274     };
275     Status error;
276     handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error));
277     if (error.Fail())
278       return error.ToError();
279   }
280 
281   return handles;
282 }
283 
284 int TCPSocket::SetOptionNoDelay() {
285   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
286 }
287 
288 int TCPSocket::SetOptionReuseAddress() {
289   return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
290 }
291