xref: /llvm-project/clang-tools-extra/clangd/unittests/LSPClient.cpp (revision a45df47375e50914900dcc07abd2fa67bfa0dd3b)
1 //===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "LSPClient.h"
10 #include "Protocol.h"
11 #include "TestFS.h"
12 #include "Transport.h"
13 #include "support/Logger.h"
14 #include "support/Threading.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/StringMap.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Support/Error.h"
19 #include "llvm/Support/JSON.h"
20 #include "llvm/Support/Path.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "gtest/gtest.h"
23 #include <condition_variable>
24 #include <cstddef>
25 #include <cstdint>
26 #include <deque>
27 #include <functional>
28 #include <memory>
29 #include <mutex>
30 #include <optional>
31 #include <queue>
32 #include <string>
33 #include <utility>
34 #include <vector>
35 
36 namespace clang {
37 namespace clangd {
38 
take()39 llvm::Expected<llvm::json::Value> clang::clangd::LSPClient::CallResult::take() {
40   std::unique_lock<std::mutex> Lock(Mu);
41   static constexpr size_t TimeoutSecs = 60;
42   if (!clangd::wait(Lock, CV, timeoutSeconds(TimeoutSecs),
43                     [this] { return Value.has_value(); })) {
44     ADD_FAILURE() << "No result from call after " << TimeoutSecs << " seconds!";
45     return llvm::json::Value(nullptr);
46   }
47   auto Res = std::move(*Value);
48   Value.reset();
49   return Res;
50 }
51 
takeValue()52 llvm::json::Value LSPClient::CallResult::takeValue() {
53   auto ExpValue = take();
54   if (!ExpValue) {
55     ADD_FAILURE() << "takeValue(): " << llvm::toString(ExpValue.takeError());
56     return llvm::json::Value(nullptr);
57   }
58   return std::move(*ExpValue);
59 }
60 
set(llvm::Expected<llvm::json::Value> V)61 void LSPClient::CallResult::set(llvm::Expected<llvm::json::Value> V) {
62   std::lock_guard<std::mutex> Lock(Mu);
63   if (Value) {
64     ADD_FAILURE() << "Multiple replies";
65     llvm::consumeError(V.takeError());
66     return;
67   }
68   Value = std::move(V);
69   CV.notify_all();
70 }
71 
~CallResult()72 LSPClient::CallResult::~CallResult() {
73   if (Value && !*Value) {
74     ADD_FAILURE() << llvm::toString(Value->takeError());
75   }
76 }
77 
logBody(llvm::StringRef Method,llvm::json::Value V,bool Send)78 static void logBody(llvm::StringRef Method, llvm::json::Value V, bool Send) {
79   // We invert <<< and >>> as the combined log is from the server's viewpoint.
80   vlog("{0} {1}: {2:2}", Send ? "<<<" : ">>>", Method, V);
81 }
82 
83 class LSPClient::TransportImpl : public Transport {
84 public:
addCallSlot()85   std::pair<llvm::json::Value, CallResult *> addCallSlot() {
86     std::lock_guard<std::mutex> Lock(Mu);
87     unsigned ID = CallResults.size();
88     CallResults.emplace_back();
89     return {ID, &CallResults.back()};
90   }
91 
92   // A null action causes the transport to shut down.
enqueue(std::function<void (MessageHandler &)> Action)93   void enqueue(std::function<void(MessageHandler &)> Action) {
94     std::lock_guard<std::mutex> Lock(Mu);
95     Actions.push(std::move(Action));
96     CV.notify_all();
97   }
98 
takeNotifications(llvm::StringRef Method)99   std::vector<llvm::json::Value> takeNotifications(llvm::StringRef Method) {
100     std::vector<llvm::json::Value> Result;
101     {
102       std::lock_guard<std::mutex> Lock(Mu);
103       std::swap(Result, Notifications[Method]);
104     }
105     return Result;
106   }
107 
expectCall(llvm::StringRef Method)108   void expectCall(llvm::StringRef Method) {
109     std::lock_guard<std::mutex> Lock(Mu);
110     Calls[Method] = {};
111   }
112 
takeCallParams(llvm::StringRef Method)113   std::vector<llvm::json::Value> takeCallParams(llvm::StringRef Method) {
114     std::vector<llvm::json::Value> Result;
115     {
116       std::lock_guard<std::mutex> Lock(Mu);
117       std::swap(Result, Calls[Method]);
118     }
119     return Result;
120   }
121 
122 private:
reply(llvm::json::Value ID,llvm::Expected<llvm::json::Value> V)123   void reply(llvm::json::Value ID,
124              llvm::Expected<llvm::json::Value> V) override {
125     if (V) // Nothing additional to log for error.
126       logBody("reply", *V, /*Send=*/false);
127     std::lock_guard<std::mutex> Lock(Mu);
128     if (auto I = ID.getAsInteger()) {
129       if (*I >= 0 && *I < static_cast<int64_t>(CallResults.size())) {
130         CallResults[*I].set(std::move(V));
131         return;
132       }
133     }
134     ADD_FAILURE() << "Invalid reply to ID " << ID;
135     llvm::consumeError(std::move(V).takeError());
136   }
137 
notify(llvm::StringRef Method,llvm::json::Value V)138   void notify(llvm::StringRef Method, llvm::json::Value V) override {
139     logBody(Method, V, /*Send=*/false);
140     std::lock_guard<std::mutex> Lock(Mu);
141     Notifications[Method].push_back(std::move(V));
142   }
143 
call(llvm::StringRef Method,llvm::json::Value Params,llvm::json::Value ID)144   void call(llvm::StringRef Method, llvm::json::Value Params,
145             llvm::json::Value ID) override {
146     logBody(Method, Params, /*Send=*/false);
147     std::lock_guard<std::mutex> Lock(Mu);
148     if (Calls.contains(Method)) {
149       Calls[Method].push_back(std::move(Params));
150     } else {
151       ADD_FAILURE() << "Unexpected server->client call " << Method;
152     }
153   }
154 
loop(MessageHandler & H)155   llvm::Error loop(MessageHandler &H) override {
156     std::unique_lock<std::mutex> Lock(Mu);
157     while (true) {
158       CV.wait(Lock, [&] { return !Actions.empty(); });
159       if (!Actions.front()) // Stop!
160         return llvm::Error::success();
161       auto Action = std::move(Actions.front());
162       Actions.pop();
163       Lock.unlock();
164       Action(H);
165       Lock.lock();
166     }
167   }
168 
169   std::mutex Mu;
170   std::deque<CallResult> CallResults;
171   std::queue<std::function<void(Transport::MessageHandler &)>> Actions;
172   std::condition_variable CV;
173   llvm::StringMap<std::vector<llvm::json::Value>> Notifications;
174   llvm::StringMap<std::vector<llvm::json::Value>> Calls;
175 };
176 
LSPClient()177 LSPClient::LSPClient() : T(std::make_unique<TransportImpl>()) {}
178 LSPClient::~LSPClient() = default;
179 
call(llvm::StringRef Method,llvm::json::Value Params)180 LSPClient::CallResult &LSPClient::call(llvm::StringRef Method,
181                                        llvm::json::Value Params) {
182   auto Slot = T->addCallSlot();
183   T->enqueue([ID(Slot.first), Method(Method.str()),
184               Params(std::move(Params))](Transport::MessageHandler &H) {
185     logBody(Method, Params, /*Send=*/true);
186     H.onCall(Method, std::move(Params), ID);
187   });
188   return *Slot.second;
189 }
190 
expectServerCall(llvm::StringRef Method)191 void LSPClient::expectServerCall(llvm::StringRef Method) {
192   T->expectCall(Method);
193 }
194 
notify(llvm::StringRef Method,llvm::json::Value Params)195 void LSPClient::notify(llvm::StringRef Method, llvm::json::Value Params) {
196   T->enqueue([Method(Method.str()),
197               Params(std::move(Params))](Transport::MessageHandler &H) {
198     logBody(Method, Params, /*Send=*/true);
199     H.onNotify(Method, std::move(Params));
200   });
201 }
202 
203 std::vector<llvm::json::Value>
takeNotifications(llvm::StringRef Method)204 LSPClient::takeNotifications(llvm::StringRef Method) {
205   return T->takeNotifications(Method);
206 }
207 
208 std::vector<llvm::json::Value>
takeCallParams(llvm::StringRef Method)209 LSPClient::takeCallParams(llvm::StringRef Method) {
210   return T->takeCallParams(Method);
211 }
212 
stop()213 void LSPClient::stop() { T->enqueue(nullptr); }
214 
transport()215 Transport &LSPClient::transport() { return *T; }
216 
217 using Obj = llvm::json::Object;
218 
uri(llvm::StringRef Path)219 llvm::json::Value LSPClient::uri(llvm::StringRef Path) {
220   std::string Storage;
221   if (!llvm::sys::path::is_absolute(Path))
222     Path = Storage = testPath(Path);
223   return toJSON(URIForFile::canonicalize(Path, Path));
224 }
documentID(llvm::StringRef Path)225 llvm::json::Value LSPClient::documentID(llvm::StringRef Path) {
226   return Obj{{"uri", uri(Path)}};
227 }
228 
didOpen(llvm::StringRef Path,llvm::StringRef Content)229 void LSPClient::didOpen(llvm::StringRef Path, llvm::StringRef Content) {
230   notify(
231       "textDocument/didOpen",
232       Obj{{"textDocument",
233            Obj{{"uri", uri(Path)}, {"text", Content}, {"languageId", "cpp"}}}});
234 }
didChange(llvm::StringRef Path,llvm::StringRef Content)235 void LSPClient::didChange(llvm::StringRef Path, llvm::StringRef Content) {
236   notify("textDocument/didChange",
237          Obj{{"textDocument", documentID(Path)},
238              {"contentChanges", llvm::json::Array{Obj{{"text", Content}}}}});
239 }
didClose(llvm::StringRef Path)240 void LSPClient::didClose(llvm::StringRef Path) {
241   notify("textDocument/didClose", Obj{{"textDocument", documentID(Path)}});
242 }
243 
sync()244 void LSPClient::sync() { call("sync", nullptr).takeValue(); }
245 
246 std::optional<std::vector<llvm::json::Value>>
diagnostics(llvm::StringRef Path)247 LSPClient::diagnostics(llvm::StringRef Path) {
248   sync();
249   auto Notifications = takeNotifications("textDocument/publishDiagnostics");
250   for (const auto &Notification : llvm::reverse(Notifications)) {
251     if (const auto *PubDiagsParams = Notification.getAsObject()) {
252       auto U = PubDiagsParams->getString("uri");
253       auto *D = PubDiagsParams->getArray("diagnostics");
254       if (!U || !D) {
255         ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams;
256         continue;
257       }
258       if (*U == uri(Path))
259         return std::vector<llvm::json::Value>(D->begin(), D->end());
260     }
261   }
262   return {};
263 }
264 
265 } // namespace clangd
266 } // namespace clang
267