xref: /llvm-project/mlir/lib/Tools/lsp-server-support/Transport.cpp (revision e40c5b42fe8f489ea4bac4694ef58f09bd95263b)
//===- Transport.cpp - sending and receiving LSP messages over JSON -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/lsp-server-support/Transport.h"
10 #include "mlir/Support/ToolUtilities.h"
11 #include "mlir/Tools/lsp-server-support/Logging.h"
12 #include "mlir/Tools/lsp-server-support/Protocol.h"
13 #include "llvm/ADT/SmallString.h"
14 #include "llvm/Support/Errno.h"
15 #include "llvm/Support/Error.h"
16 #include <optional>
17 #include <system_error>
18 #include <utility>
19 
20 using namespace mlir;
21 using namespace mlir::lsp;
22 
23 //===----------------------------------------------------------------------===//
24 // Reply
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
/// Function object used to reply to a single LSP call.
///
/// Each instance is expected to be invoked exactly once:
///  - if it is invoked more than once, only the first invocation sends a
///    reply; later ones are logged as an error and trip an assertion.
///  - NOTE(review): nothing is sent if the object is destroyed without ever
///    being invoked (the class declares no destructor) — confirm whether a
///    fallback error reply is intended.
///
/// Instances can be move-constructed; copying and assignment are deleted.
class Reply {
public:
  /// Creates a reply object for the call identified by `id` to `method`. The
  /// reply is sent through `transport` while `transportOutputMutex` is held.
  Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
        std::mutex &transportOutputMutex);
  /// Takes over the state of `other`; `other` loses its transport pointer.
  Reply(Reply &&other);
  Reply &operator=(Reply &&) = delete;
  Reply(const Reply &) = delete;
  Reply &operator=(const Reply &) = delete;

  /// Sends `reply` (either a result value or an error) for this call.
  void operator()(llvm::Expected<llvm::json::Value> reply);

private:
  /// Name of the method being replied to; used in log messages.
  std::string method;
  /// Set by the first invocation of operator().
  std::atomic<bool> replied = {false};
  /// Identifier of the call being replied to.
  llvm::json::Value id;
  /// Transport used to send the reply; null in a moved-from object.
  JSONTransport *transport;
  /// Locked while the reply is logged and handed to the transport.
  std::mutex &transportOutputMutex;
};
50 } // namespace
51 
52 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
53              JSONTransport &transport, std::mutex &transportOutputMutex)
54     : method(method), id(id), transport(&transport),
55       transportOutputMutex(transportOutputMutex) {}
56 
57 Reply::Reply(Reply &&other)
58     : method(other.method), replied(other.replied.load()),
59       id(std::move(other.id)), transport(other.transport),
60       transportOutputMutex(other.transportOutputMutex) {
61   other.transport = nullptr;
62 }
63 
64 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
65   if (replied.exchange(true)) {
66     Logger::error("Replied twice to message {0}({1})", method, id);
67     assert(false && "must reply to each call only once!");
68     return;
69   }
70   assert(transport && "expected valid transport to reply to");
71 
72   std::lock_guard<std::mutex> transportLock(transportOutputMutex);
73   if (reply) {
74     Logger::info("--> reply:{0}({1})", method, id);
75     transport->reply(std::move(id), std::move(reply));
76   } else {
77     llvm::Error error = reply.takeError();
78     Logger::info("--> reply:{0}({1}): {2}", method, id, error);
79     transport->reply(std::move(id), std::move(error));
80   }
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // MessageHandler
85 //===----------------------------------------------------------------------===//
86 
87 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
88   Logger::info("--> {0}", method);
89 
90   if (method == "exit")
91     return false;
92   if (method == "$cancel") {
93     // TODO: Add support for cancelling requests.
94   } else {
95     auto it = notificationHandlers.find(method);
96     if (it != notificationHandlers.end())
97       it->second(std::move(value));
98   }
99   return true;
100 }
101 
102 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
103                             llvm::json::Value id) {
104   Logger::info("--> {0}({1})", method, id);
105 
106   Reply reply(id, method, transport, transportOutputMutex);
107 
108   auto it = methodHandlers.find(method);
109   if (it != methodHandlers.end()) {
110     it->second(std::move(params), std::move(reply));
111   } else {
112     reply(llvm::make_error<LSPError>("method not found: " + method.str(),
113                                      ErrorCode::MethodNotFound));
114   }
115   return true;
116 }
117 
118 bool MessageHandler::onReply(llvm::json::Value id,
119                              llvm::Expected<llvm::json::Value> result) {
120   // Find the response handler in the mapping. If it exists, move it out of the
121   // mapping and erase it.
122   ResponseHandlerTy responseHandler;
123   {
124     std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
125     auto it = responseHandlers.find(debugString(id));
126     if (it != responseHandlers.end()) {
127       responseHandler = std::move(it->second);
128       responseHandlers.erase(it);
129     }
130   }
131 
132   // If we found a response handler, invoke it. Otherwise, log an error.
133   if (responseHandler.second) {
134     Logger::info("--> reply:{0}({1})", responseHandler.first, id);
135     responseHandler.second(std::move(id), std::move(result));
136   } else {
137     Logger::error(
138         "received a reply with ID {0}, but there was no such outgoing request",
139         id);
140     if (!result)
141       llvm::consumeError(result.takeError());
142   }
143   return true;
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // JSONTransport
148 //===----------------------------------------------------------------------===//
149 
150 /// Encode the given error as a JSON object.
151 static llvm::json::Object encodeError(llvm::Error error) {
152   std::string message;
153   ErrorCode code = ErrorCode::UnknownErrorCode;
154   auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
155     message = lspError.message;
156     code = lspError.code;
157     return llvm::Error::success();
158   };
159   if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
160     message = llvm::toString(std::move(unhandled));
161 
162   return llvm::json::Object{
163       {"message", std::move(message)},
164       {"code", int64_t(code)},
165   };
166 }
167 
168 /// Decode the given JSON object into an error.
169 llvm::Error decodeError(const llvm::json::Object &o) {
170   StringRef msg = o.getString("message").value_or("Unspecified error");
171   if (std::optional<int64_t> code = o.getInteger("code"))
172     return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
173   return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
174                                              msg.str());
175 }
176 
177 void JSONTransport::notify(StringRef method, llvm::json::Value params) {
178   sendMessage(llvm::json::Object{
179       {"jsonrpc", "2.0"},
180       {"method", method},
181       {"params", std::move(params)},
182   });
183 }
184 void JSONTransport::call(StringRef method, llvm::json::Value params,
185                          llvm::json::Value id) {
186   sendMessage(llvm::json::Object{
187       {"jsonrpc", "2.0"},
188       {"id", std::move(id)},
189       {"method", method},
190       {"params", std::move(params)},
191   });
192 }
193 void JSONTransport::reply(llvm::json::Value id,
194                           llvm::Expected<llvm::json::Value> result) {
195   if (result) {
196     return sendMessage(llvm::json::Object{
197         {"jsonrpc", "2.0"},
198         {"id", std::move(id)},
199         {"result", std::move(*result)},
200     });
201   }
202 
203   sendMessage(llvm::json::Object{
204       {"jsonrpc", "2.0"},
205       {"id", std::move(id)},
206       {"error", encodeError(result.takeError())},
207   });
208 }
209 
210 llvm::Error JSONTransport::run(MessageHandler &handler) {
211   std::string json;
212   while (!feof(in)) {
213     if (ferror(in)) {
214       return llvm::errorCodeToError(
215           std::error_code(errno, std::system_category()));
216     }
217 
218     if (succeeded(readMessage(json))) {
219       if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
220         if (!handleMessage(std::move(*doc), handler))
221           return llvm::Error::success();
222       } else {
223         Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
224       }
225     }
226   }
227   return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
228 }
229 
230 void JSONTransport::sendMessage(llvm::json::Value msg) {
231   outputBuffer.clear();
232   llvm::raw_svector_ostream os(outputBuffer);
233   os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
234   out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
235       << outputBuffer;
236   out.flush();
237   Logger::debug(">>> {0}\n", outputBuffer);
238 }
239 
240 bool JSONTransport::handleMessage(llvm::json::Value msg,
241                                   MessageHandler &handler) {
242   // Message must be an object with "jsonrpc":"2.0".
243   llvm::json::Object *object = msg.getAsObject();
244   if (!object ||
245       object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
246     return false;
247 
248   // `id` may be any JSON value. If absent, this is a notification.
249   std::optional<llvm::json::Value> id;
250   if (llvm::json::Value *i = object->get("id"))
251     id = std::move(*i);
252   std::optional<StringRef> method = object->getString("method");
253 
254   // This is a response.
255   if (!method) {
256     if (!id)
257       return false;
258     if (auto *err = object->getObject("error"))
259       return handler.onReply(std::move(*id), decodeError(*err));
260     // result should be given, use null if not.
261     llvm::json::Value result = nullptr;
262     if (llvm::json::Value *r = object->get("result"))
263       result = std::move(*r);
264     return handler.onReply(std::move(*id), std::move(result));
265   }
266 
267   // Params should be given, use null if not.
268   llvm::json::Value params = nullptr;
269   if (llvm::json::Value *p = object->get("params"))
270     params = std::move(*p);
271 
272   if (id)
273     return handler.onCall(*method, std::move(params), std::move(*id));
274   return handler.onNotify(*method, std::move(params));
275 }
276 
277 /// Tries to read a line up to and including \n.
278 /// If failing, feof(), ferror(), or shutdownRequested() will be set.
279 LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
280   // Big enough to hold any reasonable header line. May not fit content lines
281   // in delimited mode, but performance doesn't matter for that mode.
282   static constexpr int bufSize = 128;
283   size_t size = 0;
284   out.clear();
285   for (;;) {
286     out.resize_for_overwrite(size + bufSize);
287     if (!std::fgets(&out[size], bufSize, in))
288       return failure();
289 
290     clearerr(in);
291 
292     // If the line contained null bytes, anything after it (including \n) will
293     // be ignored. Fortunately this is not a legal header or JSON.
294     size_t read = std::strlen(&out[size]);
295     if (read > 0 && out[size + read - 1] == '\n') {
296       out.resize(size + read);
297       return success();
298     }
299     size += read;
300   }
301 }
302 
303 // Returns std::nullopt when:
304 //  - ferror(), feof(), or shutdownRequested() are set.
305 //  - Content-Length is missing or empty (protocol error)
306 LogicalResult JSONTransport::readStandardMessage(std::string &json) {
307   // A Language Server Protocol message starts with a set of HTTP headers,
308   // delimited  by \r\n, and terminated by an empty line (\r\n).
309   unsigned long long contentLength = 0;
310   llvm::SmallString<128> line;
311   while (true) {
312     if (feof(in) || ferror(in) || failed(readLine(in, line)))
313       return failure();
314 
315     // Content-Length is a mandatory header, and the only one we handle.
316     StringRef lineRef = line;
317     if (lineRef.consume_front("Content-Length: ")) {
318       llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
319     } else if (!lineRef.trim().empty()) {
320       // It's another header, ignore it.
321       continue;
322     } else {
323       // An empty line indicates the end of headers. Go ahead and read the JSON.
324       break;
325     }
326   }
327 
328   // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
329   if (contentLength == 0 || contentLength > 1 << 30)
330     return failure();
331 
332   json.resize(contentLength);
333   for (size_t pos = 0, read; pos < contentLength; pos += read) {
334     read = std::fread(&json[pos], 1, contentLength - pos, in);
335     if (read == 0)
336       return failure();
337 
338     // If we're done, the error was transient. If we're not done, either it was
339     // transient or we'll see it again on retry.
340     clearerr(in);
341     pos += read;
342   }
343   return success();
344 }
345 
346 /// For lit tests we support a simplified syntax:
347 /// - messages are delimited by '// -----' on a line by itself
348 /// - lines starting with // are ignored.
349 /// This is a testing path, so favor simplicity over performance here.
350 /// When returning failure: feof(), ferror(), or shutdownRequested() will be
351 /// set.
352 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
353   json.clear();
354   llvm::SmallString<128> line;
355   while (succeeded(readLine(in, line))) {
356     StringRef lineRef = line.str().trim();
357     if (lineRef.starts_with("//")) {
358       // Found a delimiter for the message.
359       if (lineRef == kDefaultSplitMarker)
360         break;
361       continue;
362     }
363 
364     json += line;
365   }
366 
367   return failure(ferror(in));
368 }
369