1 //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/lsp-server-support/Transport.h" 10 #include "mlir/Support/ToolUtilities.h" 11 #include "mlir/Tools/lsp-server-support/Logging.h" 12 #include "mlir/Tools/lsp-server-support/Protocol.h" 13 #include "llvm/ADT/SmallString.h" 14 #include "llvm/Support/Errno.h" 15 #include "llvm/Support/Error.h" 16 #include <optional> 17 #include <system_error> 18 #include <utility> 19 20 using namespace mlir; 21 using namespace mlir::lsp; 22 23 //===----------------------------------------------------------------------===// 24 // Reply 25 //===----------------------------------------------------------------------===// 26 27 namespace { 28 /// Function object to reply to an LSP call. 29 /// Each instance must be called exactly once, otherwise: 30 /// - if there was no reply, an error reply is sent 31 /// - if there were multiple replies, only the first is sent 32 class Reply { 33 public: 34 Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, 35 std::mutex &transportOutputMutex); 36 Reply(Reply &&other); 37 Reply &operator=(Reply &&) = delete; 38 Reply(const Reply &) = delete; 39 Reply &operator=(const Reply &) = delete; 40 41 void operator()(llvm::Expected<llvm::json::Value> reply); 42 43 private: 44 std::string method; 45 std::atomic<bool> replied = {false}; 46 llvm::json::Value id; 47 JSONTransport *transport; 48 std::mutex &transportOutputMutex; 49 }; 50 } // namespace 51 52 Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, 53 JSONTransport &transport, std::mutex &transportOutputMutex) 54 : method(method), id(id), transport(&transport), 55 transportOutputMutex(transportOutputMutex) {} 56 57 Reply::Reply(Reply &&other) 58 : method(other.method), replied(other.replied.load()), 59 id(std::move(other.id)), transport(other.transport), 60 transportOutputMutex(other.transportOutputMutex) { 61 other.transport = nullptr; 62 } 63 64 void Reply::operator()(llvm::Expected<llvm::json::Value> reply) { 65 if (replied.exchange(true)) { 66 Logger::error("Replied twice to message {0}({1})", method, id); 67 assert(false && "must reply to each call only once!"); 68 return; 69 } 70 assert(transport && "expected valid transport to reply to"); 71 72 std::lock_guard<std::mutex> transportLock(transportOutputMutex); 73 if (reply) { 74 Logger::info("--> reply:{0}({1})", method, id); 75 transport->reply(std::move(id), std::move(reply)); 76 } else { 77 llvm::Error error = reply.takeError(); 78 Logger::info("--> reply:{0}({1}): {2}", method, id, error); 79 transport->reply(std::move(id), std::move(error)); 80 } 81 } 82 83 //===----------------------------------------------------------------------===// 84 // MessageHandler 85 //===----------------------------------------------------------------------===// 86 87 bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { 88 Logger::info("--> {0}", method); 89 90 if (method == "exit") 91 return false; 92 if (method == "$cancel") { 93 // TODO: Add support for cancelling requests. 94 } else { 95 auto it = notificationHandlers.find(method); 96 if (it != notificationHandlers.end()) 97 it->second(std::move(value)); 98 } 99 return true; 100 } 101 102 bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, 103 llvm::json::Value id) { 104 Logger::info("--> {0}({1})", method, id); 105 106 Reply reply(id, method, transport, transportOutputMutex); 107 108 auto it = methodHandlers.find(method); 109 if (it != methodHandlers.end()) { 110 it->second(std::move(params), std::move(reply)); 111 } else { 112 reply(llvm::make_error<LSPError>("method not found: " + method.str(), 113 ErrorCode::MethodNotFound)); 114 } 115 return true; 116 } 117 118 bool MessageHandler::onReply(llvm::json::Value id, 119 llvm::Expected<llvm::json::Value> result) { 120 // Find the response handler in the mapping. If it exists, move it out of the 121 // mapping and erase it. 122 ResponseHandlerTy responseHandler; 123 { 124 std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex); 125 auto it = responseHandlers.find(debugString(id)); 126 if (it != responseHandlers.end()) { 127 responseHandler = std::move(it->second); 128 responseHandlers.erase(it); 129 } 130 } 131 132 // If we found a response handler, invoke it. Otherwise, log an error. 133 if (responseHandler.second) { 134 Logger::info("--> reply:{0}({1})", responseHandler.first, id); 135 responseHandler.second(std::move(id), std::move(result)); 136 } else { 137 Logger::error( 138 "received a reply with ID {0}, but there was no such outgoing request", 139 id); 140 if (!result) 141 llvm::consumeError(result.takeError()); 142 } 143 return true; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // JSONTransport 148 //===----------------------------------------------------------------------===// 149 150 /// Encode the given error as a JSON object. 151 static llvm::json::Object encodeError(llvm::Error error) { 152 std::string message; 153 ErrorCode code = ErrorCode::UnknownErrorCode; 154 auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { 155 message = lspError.message; 156 code = lspError.code; 157 return llvm::Error::success(); 158 }; 159 if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn)) 160 message = llvm::toString(std::move(unhandled)); 161 162 return llvm::json::Object{ 163 {"message", std::move(message)}, 164 {"code", int64_t(code)}, 165 }; 166 } 167 168 /// Decode the given JSON object into an error. 169 llvm::Error decodeError(const llvm::json::Object &o) { 170 StringRef msg = o.getString("message").value_or("Unspecified error"); 171 if (std::optional<int64_t> code = o.getInteger("code")) 172 return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code)); 173 return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(), 174 msg.str()); 175 } 176 177 void JSONTransport::notify(StringRef method, llvm::json::Value params) { 178 sendMessage(llvm::json::Object{ 179 {"jsonrpc", "2.0"}, 180 {"method", method}, 181 {"params", std::move(params)}, 182 }); 183 } 184 void JSONTransport::call(StringRef method, llvm::json::Value params, 185 llvm::json::Value id) { 186 sendMessage(llvm::json::Object{ 187 {"jsonrpc", "2.0"}, 188 {"id", std::move(id)}, 189 {"method", method}, 190 {"params", std::move(params)}, 191 }); 192 } 193 void JSONTransport::reply(llvm::json::Value id, 194 llvm::Expected<llvm::json::Value> result) { 195 if (result) { 196 return sendMessage(llvm::json::Object{ 197 {"jsonrpc", "2.0"}, 198 {"id", std::move(id)}, 199 {"result", std::move(*result)}, 200 }); 201 } 202 203 sendMessage(llvm::json::Object{ 204 {"jsonrpc", "2.0"}, 205 {"id", std::move(id)}, 206 {"error", encodeError(result.takeError())}, 207 }); 208 } 209 210 llvm::Error JSONTransport::run(MessageHandler &handler) { 211 std::string json; 212 while (!feof(in)) { 213 if (ferror(in)) { 214 return llvm::errorCodeToError( 215 std::error_code(errno, std::system_category())); 216 } 217 218 if (succeeded(readMessage(json))) { 219 if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) { 220 if (!handleMessage(std::move(*doc), handler)) 221 return llvm::Error::success(); 222 } else { 223 Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError())); 224 } 225 } 226 } 227 return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); 228 } 229 230 void JSONTransport::sendMessage(llvm::json::Value msg) { 231 outputBuffer.clear(); 232 llvm::raw_svector_ostream os(outputBuffer); 233 os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg); 234 out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" 235 << outputBuffer; 236 out.flush(); 237 Logger::debug(">>> {0}\n", outputBuffer); 238 } 239 240 bool JSONTransport::handleMessage(llvm::json::Value msg, 241 MessageHandler &handler) { 242 // Message must be an object with "jsonrpc":"2.0". 243 llvm::json::Object *object = msg.getAsObject(); 244 if (!object || 245 object->getString("jsonrpc") != std::optional<StringRef>("2.0")) 246 return false; 247 248 // `id` may be any JSON value. If absent, this is a notification. 249 std::optional<llvm::json::Value> id; 250 if (llvm::json::Value *i = object->get("id")) 251 id = std::move(*i); 252 std::optional<StringRef> method = object->getString("method"); 253 254 // This is a response. 255 if (!method) { 256 if (!id) 257 return false; 258 if (auto *err = object->getObject("error")) 259 return handler.onReply(std::move(*id), decodeError(*err)); 260 // result should be given, use null if not. 261 llvm::json::Value result = nullptr; 262 if (llvm::json::Value *r = object->get("result")) 263 result = std::move(*r); 264 return handler.onReply(std::move(*id), std::move(result)); 265 } 266 267 // Params should be given, use null if not. 268 llvm::json::Value params = nullptr; 269 if (llvm::json::Value *p = object->get("params")) 270 params = std::move(*p); 271 272 if (id) 273 return handler.onCall(*method, std::move(params), std::move(*id)); 274 return handler.onNotify(*method, std::move(params)); 275 } 276 277 /// Tries to read a line up to and including \n. 278 /// If failing, feof(), ferror(), or shutdownRequested() will be set. 279 LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) { 280 // Big enough to hold any reasonable header line. May not fit content lines 281 // in delimited mode, but performance doesn't matter for that mode. 282 static constexpr int bufSize = 128; 283 size_t size = 0; 284 out.clear(); 285 for (;;) { 286 out.resize_for_overwrite(size + bufSize); 287 if (!std::fgets(&out[size], bufSize, in)) 288 return failure(); 289 290 clearerr(in); 291 292 // If the line contained null bytes, anything after it (including \n) will 293 // be ignored. Fortunately this is not a legal header or JSON. 294 size_t read = std::strlen(&out[size]); 295 if (read > 0 && out[size + read - 1] == '\n') { 296 out.resize(size + read); 297 return success(); 298 } 299 size += read; 300 } 301 } 302 303 // Returns std::nullopt when: 304 // - ferror(), feof(), or shutdownRequested() are set. 305 // - Content-Length is missing or empty (protocol error) 306 LogicalResult JSONTransport::readStandardMessage(std::string &json) { 307 // A Language Server Protocol message starts with a set of HTTP headers, 308 // delimited by \r\n, and terminated by an empty line (\r\n). 309 unsigned long long contentLength = 0; 310 llvm::SmallString<128> line; 311 while (true) { 312 if (feof(in) || ferror(in) || failed(readLine(in, line))) 313 return failure(); 314 315 // Content-Length is a mandatory header, and the only one we handle. 316 StringRef lineRef = line; 317 if (lineRef.consume_front("Content-Length: ")) { 318 llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength); 319 } else if (!lineRef.trim().empty()) { 320 // It's another header, ignore it. 321 continue; 322 } else { 323 // An empty line indicates the end of headers. Go ahead and read the JSON. 324 break; 325 } 326 } 327 328 // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" 329 if (contentLength == 0 || contentLength > 1 << 30) 330 return failure(); 331 332 json.resize(contentLength); 333 for (size_t pos = 0, read; pos < contentLength; pos += read) { 334 read = std::fread(&json[pos], 1, contentLength - pos, in); 335 if (read == 0) 336 return failure(); 337 338 // If we're done, the error was transient. If we're not done, either it was 339 // transient or we'll see it again on retry. 340 clearerr(in); 341 pos += read; 342 } 343 return success(); 344 } 345 346 /// For lit tests we support a simplified syntax: 347 /// - messages are delimited by '// -----' on a line by itself 348 /// - lines starting with // are ignored. 349 /// This is a testing path, so favor simplicity over performance here. 350 /// When returning failure: feof(), ferror(), or shutdownRequested() will be 351 /// set. 352 LogicalResult JSONTransport::readDelimitedMessage(std::string &json) { 353 json.clear(); 354 llvm::SmallString<128> line; 355 while (succeeded(readLine(in, line))) { 356 StringRef lineRef = line.str().trim(); 357 if (lineRef.starts_with("//")) { 358 // Found a delimiter for the message. 359 if (lineRef == kDefaultSplitMarker) 360 break; 361 continue; 362 } 363 364 json += line; 365 } 366 367 return failure(ferror(in)); 368 } 369