1 //===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "LSPServer.h"
10 #include "MLIRServer.h"
11 #include "Protocol.h"
12 #include "mlir/Tools/lsp-server-support/Logging.h"
13 #include "mlir/Tools/lsp-server-support/Transport.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/ADT/StringMap.h"
16 #include <optional>
17
18 #define DEBUG_TYPE "mlir-lsp-server"
19
20 using namespace mlir;
21 using namespace mlir::lsp;
22
23 //===----------------------------------------------------------------------===//
24 // LSPServer
25 //===----------------------------------------------------------------------===//
26
27 namespace {
28 struct LSPServer {
LSPServer__anona0905ae50111::LSPServer29 LSPServer(MLIRServer &server) : server(server) {}
30
31 //===--------------------------------------------------------------------===//
32 // Initialization
33
34 void onInitialize(const InitializeParams ¶ms,
35 Callback<llvm::json::Value> reply);
36 void onInitialized(const InitializedParams ¶ms);
37 void onShutdown(const NoParams ¶ms, Callback<std::nullptr_t> reply);
38
39 //===--------------------------------------------------------------------===//
40 // Document Change
41
42 void onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms);
43 void onDocumentDidClose(const DidCloseTextDocumentParams ¶ms);
44 void onDocumentDidChange(const DidChangeTextDocumentParams ¶ms);
45
46 //===--------------------------------------------------------------------===//
47 // Definitions and References
48
49 void onGoToDefinition(const TextDocumentPositionParams ¶ms,
50 Callback<std::vector<Location>> reply);
51 void onReference(const ReferenceParams ¶ms,
52 Callback<std::vector<Location>> reply);
53
54 //===--------------------------------------------------------------------===//
55 // Hover
56
57 void onHover(const TextDocumentPositionParams ¶ms,
58 Callback<std::optional<Hover>> reply);
59
60 //===--------------------------------------------------------------------===//
61 // Document Symbols
62
63 void onDocumentSymbol(const DocumentSymbolParams ¶ms,
64 Callback<std::vector<DocumentSymbol>> reply);
65
66 //===--------------------------------------------------------------------===//
67 // Code Completion
68
69 void onCompletion(const CompletionParams ¶ms,
70 Callback<CompletionList> reply);
71
72 //===--------------------------------------------------------------------===//
73 // Code Action
74
75 void onCodeAction(const CodeActionParams ¶ms,
76 Callback<llvm::json::Value> reply);
77
78 //===--------------------------------------------------------------------===//
79 // Bytecode
80
81 void onConvertFromBytecode(const MLIRConvertBytecodeParams ¶ms,
82 Callback<MLIRConvertBytecodeResult> reply);
83 void onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms,
84 Callback<MLIRConvertBytecodeResult> reply);
85
86 //===--------------------------------------------------------------------===//
87 // Fields
88 //===--------------------------------------------------------------------===//
89
90 MLIRServer &server;
91
92 /// An outgoing notification used to send diagnostics to the client when they
93 /// are ready to be processed.
94 OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
95
96 /// Used to indicate that the 'shutdown' request was received from the
97 /// Language Server client.
98 bool shutdownRequestReceived = false;
99 };
100 } // namespace
101
102 //===----------------------------------------------------------------------===//
103 // Initialization
104
onInitialize(const InitializeParams & params,Callback<llvm::json::Value> reply)105 void LSPServer::onInitialize(const InitializeParams ¶ms,
106 Callback<llvm::json::Value> reply) {
107 // Send a response with the capabilities of this server.
108 llvm::json::Object serverCaps{
109 {"textDocumentSync",
110 llvm::json::Object{
111 {"openClose", true},
112 {"change", (int)TextDocumentSyncKind::Full},
113 {"save", true},
114 }},
115 {"completionProvider",
116 llvm::json::Object{
117 {"allCommitCharacters",
118 {
119 "\t",
120 ";",
121 ",",
122 ".",
123 "=",
124 }},
125 {"resolveProvider", false},
126 {"triggerCharacters",
127 {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
128 }},
129 {"definitionProvider", true},
130 {"referencesProvider", true},
131 {"hoverProvider", true},
132
133 // For now we only support documenting symbols when the client supports
134 // hierarchical symbols.
135 {"documentSymbolProvider",
136 params.capabilities.hierarchicalDocumentSymbol},
137 };
138
139 // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
140 // CodeActionOptions is only valid if the client supports action literal
141 // via textDocument.codeAction.codeActionLiteralSupport.
142 serverCaps["codeActionProvider"] =
143 params.capabilities.codeActionStructure
144 ? llvm::json::Object{{"codeActionKinds",
145 {CodeAction::kQuickFix, CodeAction::kRefactor,
146 CodeAction::kInfo}}}
147 : llvm::json::Value(true);
148
149 llvm::json::Object result{
150 {{"serverInfo",
151 llvm::json::Object{{"name", "mlir-lsp-server"}, {"version", "0.0.0"}}},
152 {"capabilities", std::move(serverCaps)}}};
153 reply(std::move(result));
154 }
onInitialized(const InitializedParams &)155 void LSPServer::onInitialized(const InitializedParams &) {}
onShutdown(const NoParams &,Callback<std::nullptr_t> reply)156 void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
157 shutdownRequestReceived = true;
158 reply(nullptr);
159 }
160
161 //===----------------------------------------------------------------------===//
162 // Document Change
163
onDocumentDidOpen(const DidOpenTextDocumentParams & params)164 void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms) {
165 PublishDiagnosticsParams diagParams(params.textDocument.uri,
166 params.textDocument.version);
167 server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
168 params.textDocument.version,
169 diagParams.diagnostics);
170
171 // Publish any recorded diagnostics.
172 publishDiagnostics(diagParams);
173 }
onDocumentDidClose(const DidCloseTextDocumentParams & params)174 void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams ¶ms) {
175 std::optional<int64_t> version =
176 server.removeDocument(params.textDocument.uri);
177 if (!version)
178 return;
179
180 // Empty out the diagnostics shown for this document. This will clear out
181 // anything currently displayed by the client for this document (e.g. in the
182 // "Problems" pane of VSCode).
183 publishDiagnostics(
184 PublishDiagnosticsParams(params.textDocument.uri, *version));
185 }
onDocumentDidChange(const DidChangeTextDocumentParams & params)186 void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams ¶ms) {
187 // TODO: We currently only support full document updates, we should refactor
188 // to avoid this.
189 if (params.contentChanges.size() != 1)
190 return;
191 PublishDiagnosticsParams diagParams(params.textDocument.uri,
192 params.textDocument.version);
193 server.addOrUpdateDocument(
194 params.textDocument.uri, params.contentChanges.front().text,
195 params.textDocument.version, diagParams.diagnostics);
196
197 // Publish any recorded diagnostics.
198 publishDiagnostics(diagParams);
199 }
200
201 //===----------------------------------------------------------------------===//
202 // Definitions and References
203
onGoToDefinition(const TextDocumentPositionParams & params,Callback<std::vector<Location>> reply)204 void LSPServer::onGoToDefinition(const TextDocumentPositionParams ¶ms,
205 Callback<std::vector<Location>> reply) {
206 std::vector<Location> locations;
207 server.getLocationsOf(params.textDocument.uri, params.position, locations);
208 reply(std::move(locations));
209 }
210
onReference(const ReferenceParams & params,Callback<std::vector<Location>> reply)211 void LSPServer::onReference(const ReferenceParams ¶ms,
212 Callback<std::vector<Location>> reply) {
213 std::vector<Location> locations;
214 server.findReferencesOf(params.textDocument.uri, params.position, locations);
215 reply(std::move(locations));
216 }
217
218 //===----------------------------------------------------------------------===//
219 // Hover
220
onHover(const TextDocumentPositionParams & params,Callback<std::optional<Hover>> reply)221 void LSPServer::onHover(const TextDocumentPositionParams ¶ms,
222 Callback<std::optional<Hover>> reply) {
223 reply(server.findHover(params.textDocument.uri, params.position));
224 }
225
226 //===----------------------------------------------------------------------===//
227 // Document Symbols
228
onDocumentSymbol(const DocumentSymbolParams & params,Callback<std::vector<DocumentSymbol>> reply)229 void LSPServer::onDocumentSymbol(const DocumentSymbolParams ¶ms,
230 Callback<std::vector<DocumentSymbol>> reply) {
231 std::vector<DocumentSymbol> symbols;
232 server.findDocumentSymbols(params.textDocument.uri, symbols);
233 reply(std::move(symbols));
234 }
235
236 //===----------------------------------------------------------------------===//
237 // Code Completion
238
onCompletion(const CompletionParams & params,Callback<CompletionList> reply)239 void LSPServer::onCompletion(const CompletionParams ¶ms,
240 Callback<CompletionList> reply) {
241 reply(server.getCodeCompletion(params.textDocument.uri, params.position));
242 }
243
244 //===----------------------------------------------------------------------===//
245 // Code Action
246
onCodeAction(const CodeActionParams & params,Callback<llvm::json::Value> reply)247 void LSPServer::onCodeAction(const CodeActionParams ¶ms,
248 Callback<llvm::json::Value> reply) {
249 URIForFile uri = params.textDocument.uri;
250
251 // Check whether a particular CodeActionKind is included in the response.
252 auto isKindAllowed = [only(params.context.only)](StringRef kind) {
253 if (only.empty())
254 return true;
255 return llvm::any_of(only, [&](StringRef base) {
256 return kind.consume_front(base) &&
257 (kind.empty() || kind.starts_with("."));
258 });
259 };
260
261 // We provide a code action for fixes on the specified diagnostics.
262 std::vector<CodeAction> actions;
263 if (isKindAllowed(CodeAction::kQuickFix))
264 server.getCodeActions(uri, params.range.start, params.context, actions);
265 reply(std::move(actions));
266 }
267
268 //===----------------------------------------------------------------------===//
269 // Bytecode
270
onConvertFromBytecode(const MLIRConvertBytecodeParams & params,Callback<MLIRConvertBytecodeResult> reply)271 void LSPServer::onConvertFromBytecode(
272 const MLIRConvertBytecodeParams ¶ms,
273 Callback<MLIRConvertBytecodeResult> reply) {
274 reply(server.convertFromBytecode(params.uri));
275 }
276
onConvertToBytecode(const MLIRConvertBytecodeParams & params,Callback<MLIRConvertBytecodeResult> reply)277 void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms,
278 Callback<MLIRConvertBytecodeResult> reply) {
279 reply(server.convertToBytecode(params.uri));
280 }
281
282 //===----------------------------------------------------------------------===//
283 // Entry point
284 //===----------------------------------------------------------------------===//
285
runMlirLSPServer(MLIRServer & server,JSONTransport & transport)286 LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
287 JSONTransport &transport) {
288 LSPServer lspServer(server);
289 MessageHandler messageHandler(transport);
290
291 // Initialization
292 messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
293 messageHandler.notification("initialized", &lspServer,
294 &LSPServer::onInitialized);
295 messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);
296
297 // Document Changes
298 messageHandler.notification("textDocument/didOpen", &lspServer,
299 &LSPServer::onDocumentDidOpen);
300 messageHandler.notification("textDocument/didClose", &lspServer,
301 &LSPServer::onDocumentDidClose);
302 messageHandler.notification("textDocument/didChange", &lspServer,
303 &LSPServer::onDocumentDidChange);
304
305 // Definitions and References
306 messageHandler.method("textDocument/definition", &lspServer,
307 &LSPServer::onGoToDefinition);
308 messageHandler.method("textDocument/references", &lspServer,
309 &LSPServer::onReference);
310
311 // Hover
312 messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);
313
314 // Document Symbols
315 messageHandler.method("textDocument/documentSymbol", &lspServer,
316 &LSPServer::onDocumentSymbol);
317
318 // Code Completion
319 messageHandler.method("textDocument/completion", &lspServer,
320 &LSPServer::onCompletion);
321
322 // Code Action
323 messageHandler.method("textDocument/codeAction", &lspServer,
324 &LSPServer::onCodeAction);
325
326 // Bytecode
327 messageHandler.method("mlir/convertFromBytecode", &lspServer,
328 &LSPServer::onConvertFromBytecode);
329 messageHandler.method("mlir/convertToBytecode", &lspServer,
330 &LSPServer::onConvertToBytecode);
331
332 // Diagnostics
333 lspServer.publishDiagnostics =
334 messageHandler.outgoingNotification<PublishDiagnosticsParams>(
335 "textDocument/publishDiagnostics");
336
337 // Run the main loop of the transport.
338 LogicalResult result = success();
339 if (llvm::Error error = transport.run(messageHandler)) {
340 Logger::error("Transport error: {0}", error);
341 llvm::consumeError(std::move(error));
342 result = failure();
343 } else {
344 result = success(lspServer.shutdownRequestReceived);
345 }
346 return result;
347 }
348