1 //===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The language server protocol is usually implemented by writing messages as 10 // JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON 11 // transport interface that handles this communication. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H 16 #define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H 17 18 #include "mlir/Support/DebugStringHelper.h" 19 #include "mlir/Support/LLVM.h" 20 #include "mlir/Tools/lsp-server-support/Logging.h" 21 #include "mlir/Tools/lsp-server-support/Protocol.h" 22 #include "llvm/ADT/FunctionExtras.h" 23 #include "llvm/ADT/StringMap.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/FormatAdapters.h" 26 #include "llvm/Support/JSON.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <atomic> 29 30 namespace mlir { 31 namespace lsp { 32 class MessageHandler; 33 34 //===----------------------------------------------------------------------===// 35 // JSONTransport 36 //===----------------------------------------------------------------------===// 37 38 /// The encoding style of the JSON-RPC messages (both input and output). 39 enum JSONStreamStyle { 40 /// Encoding per the LSP specification, with mandatory Content-Length header. 41 Standard, 42 /// Messages are delimited by a '// -----' line. Comment lines start with //. 43 Delimited 44 }; 45 46 /// A transport class that performs the JSON-RPC communication with the LSP 47 /// client. 48 class JSONTransport { 49 public: 50 JSONTransport(std::FILE *in, raw_ostream &out, 51 JSONStreamStyle style = JSONStreamStyle::Standard, 52 bool prettyOutput = false) in(in)53 : in(in), out(out), style(style), prettyOutput(prettyOutput) {} 54 55 /// The following methods are used to send a message to the LSP client. 56 void notify(StringRef method, llvm::json::Value params); 57 void call(StringRef method, llvm::json::Value params, llvm::json::Value id); 58 void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result); 59 60 /// Start executing the JSON-RPC transport. 61 llvm::Error run(MessageHandler &handler); 62 63 private: 64 /// Dispatches the given incoming json message to the message handler. 65 bool handleMessage(llvm::json::Value msg, MessageHandler &handler); 66 /// Writes the given message to the output stream. 67 void sendMessage(llvm::json::Value msg); 68 69 /// Read in a message from the input stream. readMessage(std::string & json)70 LogicalResult readMessage(std::string &json) { 71 return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) 72 : readStandardMessage(json); 73 } 74 LogicalResult readDelimitedMessage(std::string &json); 75 LogicalResult readStandardMessage(std::string &json); 76 77 /// An output buffer used when building output messages. 78 SmallVector<char, 0> outputBuffer; 79 /// The input file stream. 80 std::FILE *in; 81 /// The output file stream. 82 raw_ostream &out; 83 /// The JSON stream style to use. 84 JSONStreamStyle style; 85 /// If the output JSON should be formatted for easier readability. 86 bool prettyOutput; 87 }; 88 89 //===----------------------------------------------------------------------===// 90 // MessageHandler 91 //===----------------------------------------------------------------------===// 92 93 /// A Callback<T> is a void function that accepts Expected<T>. This is 94 /// accepted by functions that logically return T. 95 template <typename T> 96 using Callback = llvm::unique_function<void(llvm::Expected<T>)>; 97 98 /// An OutgoingNotification<T> is a function used for outgoing notifications 99 /// send to the client. 100 template <typename T> 101 using OutgoingNotification = llvm::unique_function<void(const T &)>; 102 103 /// An OutgoingRequest<T> is a function used for outgoing requests to send to 104 /// the client. 105 template <typename T> 106 using OutgoingRequest = 107 llvm::unique_function<void(const T &, llvm::json::Value id)>; 108 109 /// An `OutgoingRequestCallback` is invoked when an outgoing request to the 110 /// client receives a response in turn. It is passed the original request's ID, 111 /// as well as the response result. 112 template <typename T> 113 using OutgoingRequestCallback = 114 std::function<void(llvm::json::Value, llvm::Expected<T>)>; 115 116 /// A handler used to process the incoming transport messages. 117 class MessageHandler { 118 public: MessageHandler(JSONTransport & transport)119 MessageHandler(JSONTransport &transport) : transport(transport) {} 120 121 bool onNotify(StringRef method, llvm::json::Value value); 122 bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id); 123 bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result); 124 125 template <typename T> parse(const llvm::json::Value & raw,StringRef payloadName,StringRef payloadKind)126 static llvm::Expected<T> parse(const llvm::json::Value &raw, 127 StringRef payloadName, StringRef payloadKind) { 128 T result; 129 llvm::json::Path::Root root; 130 if (fromJSON(raw, result, root)) 131 return std::move(result); 132 133 // Dump the relevant parts of the broken message. 134 std::string context; 135 llvm::raw_string_ostream os(context); 136 root.printErrorContext(raw, os); 137 138 // Report the error (e.g. to the client). 139 return llvm::make_error<LSPError>( 140 llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind, 141 fmt_consume(root.getError())), 142 ErrorCode::InvalidParams); 143 } 144 145 template <typename Param, typename Result, typename ThisT> method(llvm::StringLiteral method,ThisT * thisPtr,void (ThisT::* handler)(const Param &,Callback<Result>))146 void method(llvm::StringLiteral method, ThisT *thisPtr, 147 void (ThisT::*handler)(const Param &, Callback<Result>)) { 148 methodHandlers[method] = [method, handler, 149 thisPtr](llvm::json::Value rawParams, 150 Callback<llvm::json::Value> reply) { 151 llvm::Expected<Param> param = parse<Param>(rawParams, method, "request"); 152 if (!param) 153 return reply(param.takeError()); 154 (thisPtr->*handler)(*param, std::move(reply)); 155 }; 156 } 157 158 template <typename Param, typename ThisT> notification(llvm::StringLiteral method,ThisT * thisPtr,void (ThisT::* handler)(const Param &))159 void notification(llvm::StringLiteral method, ThisT *thisPtr, 160 void (ThisT::*handler)(const Param &)) { 161 notificationHandlers[method] = [method, handler, 162 thisPtr](llvm::json::Value rawParams) { 163 llvm::Expected<Param> param = 164 parse<Param>(rawParams, method, "notification"); 165 if (!param) { 166 return llvm::consumeError( 167 llvm::handleErrors(param.takeError(), [](const LSPError &lspError) { 168 Logger::error("JSON parsing error: {0}", 169 lspError.message.c_str()); 170 })); 171 } 172 (thisPtr->*handler)(*param); 173 }; 174 } 175 176 /// Create an OutgoingNotification object used for the given method. 177 template <typename T> outgoingNotification(llvm::StringLiteral method)178 OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) { 179 return [&, method](const T ¶ms) { 180 std::lock_guard<std::mutex> transportLock(transportOutputMutex); 181 Logger::info("--> {0}", method); 182 transport.notify(method, llvm::json::Value(params)); 183 }; 184 } 185 186 /// Create an OutgoingRequest function that, when called, sends a request with 187 /// the given method via the transport. Should the outgoing request be 188 /// met with a response, the result JSON is parsed and the response callback 189 /// is invoked. 190 template <typename Param, typename Result> 191 OutgoingRequest<Param> outgoingRequest(llvm::StringLiteral method,OutgoingRequestCallback<Result> callback)192 outgoingRequest(llvm::StringLiteral method, 193 OutgoingRequestCallback<Result> callback) { 194 return [&, method, callback](const Param ¶m, llvm::json::Value id) { 195 auto callbackWrapper = [method, callback = std::move(callback)]( 196 llvm::json::Value id, 197 llvm::Expected<llvm::json::Value> value) { 198 if (!value) 199 return callback(std::move(id), value.takeError()); 200 201 std::string responseName = llvm::formatv("reply:{0}({1})", method, id); 202 llvm::Expected<Result> result = 203 parse<Result>(*value, responseName, "response"); 204 if (!result) 205 return callback(std::move(id), result.takeError()); 206 207 return callback(std::move(id), *result); 208 }; 209 210 { 211 std::lock_guard<std::mutex> lock(responseHandlersMutex); 212 responseHandlers.insert( 213 {debugString(id), std::make_pair(method.str(), callbackWrapper)}); 214 } 215 216 std::lock_guard<std::mutex> transportLock(transportOutputMutex); 217 Logger::info("--> {0}({1})", method, id); 218 transport.call(method, llvm::json::Value(param), id); 219 }; 220 } 221 222 private: 223 template <typename HandlerT> 224 using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>; 225 226 HandlerMap<void(llvm::json::Value)> notificationHandlers; 227 HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)> 228 methodHandlers; 229 230 /// A pair of (1) the original request's method name, and (2) the callback 231 /// function to be invoked for responses. 232 using ResponseHandlerTy = 233 std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>; 234 /// A mapping from request/response ID to response handler. 235 llvm::StringMap<ResponseHandlerTy> responseHandlers; 236 /// Mutex to guard insertion into the response handler map. 237 std::mutex responseHandlersMutex; 238 239 JSONTransport &transport; 240 241 /// Mutex to guard sending output messages to the transport. 242 std::mutex transportOutputMutex; 243 }; 244 245 } // namespace lsp 246 } // namespace mlir 247 248 #endif 249