xref: /llvm-project/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp (revision a8c88de0dff865d1969b654bb5c0df181aa2df6a)
1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPServer.h"
10 #include "MLIRServer.h"
11 #include "Protocol.h"
12 #include "mlir/Tools/lsp-server-support/Logging.h"
13 #include "mlir/Tools/lsp-server-support/Transport.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 #include <optional>
17 
18 #define DEBUG_TYPE "mlir-lsp-server"
19 
20 using namespace mlir;
21 using namespace mlir::lsp;
22 
23 //===----------------------------------------------------------------------===//
24 // LSPServer
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 struct LSPServer {
LSPServer__anona0905ae50111::LSPServer29   LSPServer(MLIRServer &server) : server(server) {}
30 
31   //===--------------------------------------------------------------------===//
32   // Initialization
33 
34   void onInitialize(const InitializeParams &params,
35                     Callback<llvm::json::Value> reply);
36   void onInitialized(const InitializedParams &params);
37   void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
38 
39   //===--------------------------------------------------------------------===//
40   // Document Change
41 
42   void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
43   void onDocumentDidClose(const DidCloseTextDocumentParams &params);
44   void onDocumentDidChange(const DidChangeTextDocumentParams &params);
45 
46   //===--------------------------------------------------------------------===//
47   // Definitions and References
48 
49   void onGoToDefinition(const TextDocumentPositionParams &params,
50                         Callback<std::vector<Location>> reply);
51   void onReference(const ReferenceParams &params,
52                    Callback<std::vector<Location>> reply);
53 
54   //===--------------------------------------------------------------------===//
55   // Hover
56 
57   void onHover(const TextDocumentPositionParams &params,
58                Callback<std::optional<Hover>> reply);
59 
60   //===--------------------------------------------------------------------===//
61   // Document Symbols
62 
63   void onDocumentSymbol(const DocumentSymbolParams &params,
64                         Callback<std::vector<DocumentSymbol>> reply);
65 
66   //===--------------------------------------------------------------------===//
67   // Code Completion
68 
69   void onCompletion(const CompletionParams &params,
70                     Callback<CompletionList> reply);
71 
72   //===--------------------------------------------------------------------===//
73   // Code Action
74 
75   void onCodeAction(const CodeActionParams &params,
76                     Callback<llvm::json::Value> reply);
77 
78   //===--------------------------------------------------------------------===//
79   // Bytecode
80 
81   void onConvertFromBytecode(const MLIRConvertBytecodeParams &params,
82                              Callback<MLIRConvertBytecodeResult> reply);
83   void onConvertToBytecode(const MLIRConvertBytecodeParams &params,
84                            Callback<MLIRConvertBytecodeResult> reply);
85 
86   //===--------------------------------------------------------------------===//
87   // Fields
88   //===--------------------------------------------------------------------===//
89 
90   MLIRServer &server;
91 
92   /// An outgoing notification used to send diagnostics to the client when they
93   /// are ready to be processed.
94   OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
95 
96   /// Used to indicate that the 'shutdown' request was received from the
97   /// Language Server client.
98   bool shutdownRequestReceived = false;
99 };
100 } // namespace
101 
102 //===----------------------------------------------------------------------===//
103 // Initialization
104 
onInitialize(const InitializeParams & params,Callback<llvm::json::Value> reply)105 void LSPServer::onInitialize(const InitializeParams &params,
106                              Callback<llvm::json::Value> reply) {
107   // Send a response with the capabilities of this server.
108   llvm::json::Object serverCaps{
109       {"textDocumentSync",
110        llvm::json::Object{
111            {"openClose", true},
112            {"change", (int)TextDocumentSyncKind::Full},
113            {"save", true},
114        }},
115       {"completionProvider",
116        llvm::json::Object{
117            {"allCommitCharacters",
118             {
119                 "\t",
120                 ";",
121                 ",",
122                 ".",
123                 "=",
124             }},
125            {"resolveProvider", false},
126            {"triggerCharacters",
127             {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
128        }},
129       {"definitionProvider", true},
130       {"referencesProvider", true},
131       {"hoverProvider", true},
132 
133       // For now we only support documenting symbols when the client supports
134       // hierarchical symbols.
135       {"documentSymbolProvider",
136        params.capabilities.hierarchicalDocumentSymbol},
137   };
138 
139   // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
140   // CodeActionOptions is only valid if the client supports action literal
141   // via textDocument.codeAction.codeActionLiteralSupport.
142   serverCaps["codeActionProvider"] =
143       params.capabilities.codeActionStructure
144           ? llvm::json::Object{{"codeActionKinds",
145                                 {CodeAction::kQuickFix, CodeAction::kRefactor,
146                                  CodeAction::kInfo}}}
147           : llvm::json::Value(true);
148 
149   llvm::json::Object result{
150       {{"serverInfo",
151         llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
152        {"capabilities", std::move(serverCaps)}}};
153   reply(std::move(result));
154 }
onInitialized(const InitializedParams &)155 void LSPServer::onInitialized(const InitializedParams &) {}
onShutdown(const NoParams &,Callback<std::nullptr_t> reply)156 void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
157   shutdownRequestReceived = true;
158   reply(nullptr);
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // Document Change
163 
onDocumentDidOpen(const DidOpenTextDocumentParams & params)164 void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams &params) {
165   PublishDiagnosticsParams diagParams(params.textDocument.uri,
166                                       params.textDocument.version);
167   server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
168                              params.textDocument.version,
169                              diagParams.diagnostics);
170 
171   // Publish any recorded diagnostics.
172   publishDiagnostics(diagParams);
173 }
onDocumentDidClose(const DidCloseTextDocumentParams & params)174 void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams &params) {
175   std::optional<int64_t> version =
176       server.removeDocument(params.textDocument.uri);
177   if (!version)
178     return;
179 
180   // Empty out the diagnostics shown for this document. This will clear out
181   // anything currently displayed by the client for this document (e.g. in the
182   // "Problems" pane of VSCode).
183   publishDiagnostics(
184       PublishDiagnosticsParams(params.textDocument.uri, *version));
185 }
onDocumentDidChange(const DidChangeTextDocumentParams & params)186 void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams &params) {
187   // TODO: We currently only support full document updates, we should refactor
188   // to avoid this.
189   if (params.contentChanges.size() != 1)
190     return;
191   PublishDiagnosticsParams diagParams(params.textDocument.uri,
192                                       params.textDocument.version);
193   server.addOrUpdateDocument(
194       params.textDocument.uri, params.contentChanges.front().text,
195       params.textDocument.version, diagParams.diagnostics);
196 
197   // Publish any recorded diagnostics.
198   publishDiagnostics(diagParams);
199 }
200 
201 //===----------------------------------------------------------------------===//
202 // Definitions and References
203 
onGoToDefinition(const TextDocumentPositionParams & params,Callback<std::vector<Location>> reply)204 void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params,
205                                  Callback<std::vector<Location>> reply) {
206   std::vector<Location> locations;
207   server.getLocationsOf(params.textDocument.uri, params.position, locations);
208   reply(std::move(locations));
209 }
210 
onReference(const ReferenceParams & params,Callback<std::vector<Location>> reply)211 void LSPServer::onReference(const ReferenceParams &params,
212                             Callback<std::vector<Location>> reply) {
213   std::vector<Location> locations;
214   server.findReferencesOf(params.textDocument.uri, params.position, locations);
215   reply(std::move(locations));
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // Hover
220 
onHover(const TextDocumentPositionParams & params,Callback<std::optional<Hover>> reply)221 void LSPServer::onHover(const TextDocumentPositionParams &params,
222                         Callback<std::optional<Hover>> reply) {
223   reply(server.findHover(params.textDocument.uri, params.position));
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // Document Symbols
228 
onDocumentSymbol(const DocumentSymbolParams & params,Callback<std::vector<DocumentSymbol>> reply)229 void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
230                                  Callback<std::vector<DocumentSymbol>> reply) {
231   std::vector<DocumentSymbol> symbols;
232   server.findDocumentSymbols(params.textDocument.uri, symbols);
233   reply(std::move(symbols));
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // Code Completion
238 
onCompletion(const CompletionParams & params,Callback<CompletionList> reply)239 void LSPServer::onCompletion(const CompletionParams &params,
240                              Callback<CompletionList> reply) {
241   reply(server.getCodeCompletion(params.textDocument.uri, params.position));
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Code Action
246 
onCodeAction(const CodeActionParams & params,Callback<llvm::json::Value> reply)247 void LSPServer::onCodeAction(const CodeActionParams &params,
248                              Callback<llvm::json::Value> reply) {
249   URIForFile uri = params.textDocument.uri;
250 
251   // Check whether a particular CodeActionKind is included in the response.
252   auto isKindAllowed = [only(params.context.only)](StringRef kind) {
253     if (only.empty())
254       return true;
255     return llvm::any_of(only, [&](StringRef base) {
256       return kind.consume_front(base) &&
257              (kind.empty() || kind.starts_with("."));
258     });
259   };
260 
261   // We provide a code action for fixes on the specified diagnostics.
262   std::vector<CodeAction> actions;
263   if (isKindAllowed(CodeAction::kQuickFix))
264     server.getCodeActions(uri, params.range.start, params.context, actions);
265   reply(std::move(actions));
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Bytecode
270 
onConvertFromBytecode(const MLIRConvertBytecodeParams & params,Callback<MLIRConvertBytecodeResult> reply)271 void LSPServer::onConvertFromBytecode(
272     const MLIRConvertBytecodeParams &params,
273     Callback<MLIRConvertBytecodeResult> reply) {
274   reply(server.convertFromBytecode(params.uri));
275 }
276 
onConvertToBytecode(const MLIRConvertBytecodeParams & params,Callback<MLIRConvertBytecodeResult> reply)277 void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams &params,
278                                     Callback<MLIRConvertBytecodeResult> reply) {
279   reply(server.convertToBytecode(params.uri));
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // Entry point
284 //===----------------------------------------------------------------------===//
285 
runMlirLSPServer(MLIRServer & server,JSONTransport & transport)286 LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
287                                     JSONTransport &transport) {
288   LSPServer lspServer(server);
289   MessageHandler messageHandler(transport);
290 
291   // Initialization
292   messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
293   messageHandler.notification("initialized", &lspServer,
294                               &LSPServer::onInitialized);
295   messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);
296 
297   // Document Changes
298   messageHandler.notification("textDocument/didOpen", &lspServer,
299                               &LSPServer::onDocumentDidOpen);
300   messageHandler.notification("textDocument/didClose", &lspServer,
301                               &LSPServer::onDocumentDidClose);
302   messageHandler.notification("textDocument/didChange", &lspServer,
303                               &LSPServer::onDocumentDidChange);
304 
305   // Definitions and References
306   messageHandler.method("textDocument/definition", &lspServer,
307                         &LSPServer::onGoToDefinition);
308   messageHandler.method("textDocument/references", &lspServer,
309                         &LSPServer::onReference);
310 
311   // Hover
312   messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);
313 
314   // Document Symbols
315   messageHandler.method("textDocument/documentSymbol", &lspServer,
316                         &LSPServer::onDocumentSymbol);
317 
318   // Code Completion
319   messageHandler.method("textDocument/completion", &lspServer,
320                         &LSPServer::onCompletion);
321 
322   // Code Action
323   messageHandler.method("textDocument/codeAction", &lspServer,
324                         &LSPServer::onCodeAction);
325 
326   // Bytecode
327   messageHandler.method("mlir/convertFromBytecode", &lspServer,
328                         &LSPServer::onConvertFromBytecode);
329   messageHandler.method("mlir/convertToBytecode", &lspServer,
330                         &LSPServer::onConvertToBytecode);
331 
332   // Diagnostics
333   lspServer.publishDiagnostics =
334       messageHandler.outgoingNotification<PublishDiagnosticsParams>(
335           "textDocument/publishDiagnostics");
336 
337   // Run the main loop of the transport.
338   LogicalResult result = success();
339   if (llvm::Error error = transport.run(messageHandler)) {
340     Logger::error("Transport error: {0}", error);
341     llvm::consumeError(std::move(error));
342     result = failure();
343   } else {
344     result = success(lspServer.shutdownRequestReceived);
345   }
346   return result;
347 }
348