// xref: /llvm-project/mlir/include/mlir/Tools/lsp-server-support/Transport.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The language server protocol is usually implemented by writing messages as
// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON
// transport interface that handles this communication.
//
//===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
16 #define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
17 
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <atomic>
#include <cstdio>
#include <functional>
#include <mutex>
#include <string>
#include <utility>
29 
30 namespace mlir {
31 namespace lsp {
class MessageHandler;

//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//

/// Framing used for the JSON-RPC messages; it applies to input and output
/// alike.
enum JSONStreamStyle {
  /// Framing defined by the LSP specification: each message carries a
  /// mandatory Content-Length header.
  Standard = 0,
  /// Messages are separated by a '// -----' line; lines beginning with // are
  /// comments.
  Delimited = 1
};
45 
46 /// A transport class that performs the JSON-RPC communication with the LSP
47 /// client.
48 class JSONTransport {
49 public:
50   JSONTransport(std::FILE *in, raw_ostream &out,
51                 JSONStreamStyle style = JSONStreamStyle::Standard,
52                 bool prettyOutput = false)
in(in)53       : in(in), out(out), style(style), prettyOutput(prettyOutput) {}
54 
55   /// The following methods are used to send a message to the LSP client.
56   void notify(StringRef method, llvm::json::Value params);
57   void call(StringRef method, llvm::json::Value params, llvm::json::Value id);
58   void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
59 
60   /// Start executing the JSON-RPC transport.
61   llvm::Error run(MessageHandler &handler);
62 
63 private:
64   /// Dispatches the given incoming json message to the message handler.
65   bool handleMessage(llvm::json::Value msg, MessageHandler &handler);
66   /// Writes the given message to the output stream.
67   void sendMessage(llvm::json::Value msg);
68 
69   /// Read in a message from the input stream.
readMessage(std::string & json)70   LogicalResult readMessage(std::string &json) {
71     return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json)
72                                                : readStandardMessage(json);
73   }
74   LogicalResult readDelimitedMessage(std::string &json);
75   LogicalResult readStandardMessage(std::string &json);
76 
77   /// An output buffer used when building output messages.
78   SmallVector<char, 0> outputBuffer;
79   /// The input file stream.
80   std::FILE *in;
81   /// The output file stream.
82   raw_ostream &out;
83   /// The JSON stream style to use.
84   JSONStreamStyle style;
85   /// If the output JSON should be formatted for easier readability.
86   bool prettyOutput;
87 };
88 
89 //===----------------------------------------------------------------------===//
90 // MessageHandler
91 //===----------------------------------------------------------------------===//
92 
93 /// A Callback<T> is a void function that accepts Expected<T>. This is
94 /// accepted by functions that logically return T.
95 template <typename T>
96 using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
97 
98 /// An OutgoingNotification<T> is a function used for outgoing notifications
99 /// send to the client.
100 template <typename T>
101 using OutgoingNotification = llvm::unique_function<void(const T &)>;
102 
103 /// An OutgoingRequest<T> is a function used for outgoing requests to send to
104 /// the client.
105 template <typename T>
106 using OutgoingRequest =
107     llvm::unique_function<void(const T &, llvm::json::Value id)>;
108 
109 /// An `OutgoingRequestCallback` is invoked when an outgoing request to the
110 /// client receives a response in turn. It is passed the original request's ID,
111 /// as well as the response result.
112 template <typename T>
113 using OutgoingRequestCallback =
114     std::function<void(llvm::json::Value, llvm::Expected<T>)>;
115 
116 /// A handler used to process the incoming transport messages.
117 class MessageHandler {
118 public:
MessageHandler(JSONTransport & transport)119   MessageHandler(JSONTransport &transport) : transport(transport) {}
120 
121   bool onNotify(StringRef method, llvm::json::Value value);
122   bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id);
123   bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
124 
125   template <typename T>
parse(const llvm::json::Value & raw,StringRef payloadName,StringRef payloadKind)126   static llvm::Expected<T> parse(const llvm::json::Value &raw,
127                                  StringRef payloadName, StringRef payloadKind) {
128     T result;
129     llvm::json::Path::Root root;
130     if (fromJSON(raw, result, root))
131       return std::move(result);
132 
133     // Dump the relevant parts of the broken message.
134     std::string context;
135     llvm::raw_string_ostream os(context);
136     root.printErrorContext(raw, os);
137 
138     // Report the error (e.g. to the client).
139     return llvm::make_error<LSPError>(
140         llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind,
141                       fmt_consume(root.getError())),
142         ErrorCode::InvalidParams);
143   }
144 
145   template <typename Param, typename Result, typename ThisT>
method(llvm::StringLiteral method,ThisT * thisPtr,void (ThisT::* handler)(const Param &,Callback<Result>))146   void method(llvm::StringLiteral method, ThisT *thisPtr,
147               void (ThisT::*handler)(const Param &, Callback<Result>)) {
148     methodHandlers[method] = [method, handler,
149                               thisPtr](llvm::json::Value rawParams,
150                                        Callback<llvm::json::Value> reply) {
151       llvm::Expected<Param> param = parse<Param>(rawParams, method, "request");
152       if (!param)
153         return reply(param.takeError());
154       (thisPtr->*handler)(*param, std::move(reply));
155     };
156   }
157 
158   template <typename Param, typename ThisT>
notification(llvm::StringLiteral method,ThisT * thisPtr,void (ThisT::* handler)(const Param &))159   void notification(llvm::StringLiteral method, ThisT *thisPtr,
160                     void (ThisT::*handler)(const Param &)) {
161     notificationHandlers[method] = [method, handler,
162                                     thisPtr](llvm::json::Value rawParams) {
163       llvm::Expected<Param> param =
164           parse<Param>(rawParams, method, "notification");
165       if (!param) {
166         return llvm::consumeError(
167             llvm::handleErrors(param.takeError(), [](const LSPError &lspError) {
168               Logger::error("JSON parsing error: {0}",
169                             lspError.message.c_str());
170             }));
171       }
172       (thisPtr->*handler)(*param);
173     };
174   }
175 
176   /// Create an OutgoingNotification object used for the given method.
177   template <typename T>
outgoingNotification(llvm::StringLiteral method)178   OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) {
179     return [&, method](const T &params) {
180       std::lock_guard<std::mutex> transportLock(transportOutputMutex);
181       Logger::info("--> {0}", method);
182       transport.notify(method, llvm::json::Value(params));
183     };
184   }
185 
186   /// Create an OutgoingRequest function that, when called, sends a request with
187   /// the given method via the transport. Should the outgoing request be
188   /// met with a response, the result JSON is parsed and the response callback
189   /// is invoked.
190   template <typename Param, typename Result>
191   OutgoingRequest<Param>
outgoingRequest(llvm::StringLiteral method,OutgoingRequestCallback<Result> callback)192   outgoingRequest(llvm::StringLiteral method,
193                   OutgoingRequestCallback<Result> callback) {
194     return [&, method, callback](const Param &param, llvm::json::Value id) {
195       auto callbackWrapper = [method, callback = std::move(callback)](
196                                  llvm::json::Value id,
197                                  llvm::Expected<llvm::json::Value> value) {
198         if (!value)
199           return callback(std::move(id), value.takeError());
200 
201         std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
202         llvm::Expected<Result> result =
203             parse<Result>(*value, responseName, "response");
204         if (!result)
205           return callback(std::move(id), result.takeError());
206 
207         return callback(std::move(id), *result);
208       };
209 
210       {
211         std::lock_guard<std::mutex> lock(responseHandlersMutex);
212         responseHandlers.insert(
213             {debugString(id), std::make_pair(method.str(), callbackWrapper)});
214       }
215 
216       std::lock_guard<std::mutex> transportLock(transportOutputMutex);
217       Logger::info("--> {0}({1})", method, id);
218       transport.call(method, llvm::json::Value(param), id);
219     };
220   }
221 
222 private:
223   template <typename HandlerT>
224   using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
225 
226   HandlerMap<void(llvm::json::Value)> notificationHandlers;
227   HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
228       methodHandlers;
229 
230   /// A pair of (1) the original request's method name, and (2) the callback
231   /// function to be invoked for responses.
232   using ResponseHandlerTy =
233       std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
234   /// A mapping from request/response ID to response handler.
235   llvm::StringMap<ResponseHandlerTy> responseHandlers;
236   /// Mutex to guard insertion into the response handler map.
237   std::mutex responseHandlersMutex;
238 
239   JSONTransport &transport;
240 
241   /// Mutex to guard sending output messages to the transport.
242   std::mutex transportOutputMutex;
243 };
244 
245 } // namespace lsp
246 } // namespace mlir
247 
248 #endif
249