xref: /llvm-project/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp (revision 4f4e2abb1a5ff1225d32410fd02b732d077aa056)
1 //===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PDLLServer.h"
10 
11 #include "Protocol.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Support/ToolUtilities.h"
14 #include "mlir/Tools/PDLL/AST/Context.h"
15 #include "mlir/Tools/PDLL/AST/Nodes.h"
16 #include "mlir/Tools/PDLL/AST/Types.h"
17 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
18 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
19 #include "mlir/Tools/PDLL/ODS/Constraint.h"
20 #include "mlir/Tools/PDLL/ODS/Context.h"
21 #include "mlir/Tools/PDLL/ODS/Dialect.h"
22 #include "mlir/Tools/PDLL/ODS/Operation.h"
23 #include "mlir/Tools/PDLL/Parser/CodeComplete.h"
24 #include "mlir/Tools/PDLL/Parser/Parser.h"
25 #include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
26 #include "mlir/Tools/lsp-server-support/Logging.h"
27 #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
28 #include "llvm/ADT/IntervalMap.h"
29 #include "llvm/ADT/StringMap.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/FileSystem.h"
33 #include "llvm/Support/Path.h"
34 #include <optional>
35 
36 using namespace mlir;
37 using namespace mlir::pdll;
38 
39 /// Returns a language server uri for the given source location. `mainFileURI`
40 /// corresponds to the uri for the main file of the source manager.
41 static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
42                                      const lsp::URIForFile &mainFileURI) {
43   int bufferId = mgr.FindBufferContainingLoc(loc.Start);
44   if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
45     return mainFileURI;
46   llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
47       mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
48   if (fileForLoc)
49     return *fileForLoc;
50   lsp::Logger::error("Failed to create URI for include file: {0}",
51                      llvm::toString(fileForLoc.takeError()));
52   return mainFileURI;
53 }
54 
55 /// Returns true if the given location is in the main file of the source
56 /// manager.
57 static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
58   return mgr.FindBufferContainingLoc(loc.Start) == mgr.getMainFileID();
59 }
60 
61 /// Returns a language server location from the given source range.
62 static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
63                                         const lsp::URIForFile &uri) {
64   return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range));
65 }
66 
67 /// Convert the given MLIR diagnostic to the LSP form.
68 static std::optional<lsp::Diagnostic>
69 getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
70                         const lsp::URIForFile &uri) {
71   lsp::Diagnostic lspDiag;
72   lspDiag.source = "pdll";
73 
74   // FIXME: Right now all of the diagnostics are treated as parser issues, but
75   // some are parser and some are verifier.
76   lspDiag.category = "Parse Error";
77 
78   // Try to grab a file location for this diagnostic.
79   lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
80   lspDiag.range = loc.range;
81 
82   // Skip diagnostics that weren't emitted within the main file.
83   if (loc.uri != uri)
84     return std::nullopt;
85 
86   // Convert the severity for the diagnostic.
87   switch (diag.getSeverity()) {
88   case ast::Diagnostic::Severity::DK_Note:
89     llvm_unreachable("expected notes to be handled separately");
90   case ast::Diagnostic::Severity::DK_Warning:
91     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
92     break;
93   case ast::Diagnostic::Severity::DK_Error:
94     lspDiag.severity = lsp::DiagnosticSeverity::Error;
95     break;
96   case ast::Diagnostic::Severity::DK_Remark:
97     lspDiag.severity = lsp::DiagnosticSeverity::Information;
98     break;
99   }
100   lspDiag.message = diag.getMessage().str();
101 
102   // Attach any notes to the main diagnostic as related information.
103   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
104   for (const ast::Diagnostic &note : diag.getNotes()) {
105     relatedDiags.emplace_back(
106         getLocationFromLoc(sourceMgr, note.getLocation(), uri),
107         note.getMessage().str());
108   }
109   if (!relatedDiags.empty())
110     lspDiag.relatedInformation = std::move(relatedDiags);
111 
112   return lspDiag;
113 }
114 
115 /// Get or extract the documentation for the given decl.
116 static std::optional<std::string>
117 getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
118   // If the decl already had documentation set, use it.
119   if (std::optional<StringRef> doc = decl->getDocComment())
120     return doc->str();
121 
122   // If the decl doesn't yet have documentation, try to extract it from the
123   // source file.
124   return lsp::extractSourceDocComment(sourceMgr, decl->getLoc().Start);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // PDLIndex
129 //===----------------------------------------------------------------------===//
130 
131 namespace {
132 struct PDLIndexSymbol {
133   explicit PDLIndexSymbol(const ast::Decl *definition)
134       : definition(definition) {}
135   explicit PDLIndexSymbol(const ods::Operation *definition)
136       : definition(definition) {}
137 
138   /// Return the location of the definition of this symbol.
139   SMRange getDefLoc() const {
140     if (const ast::Decl *decl =
141             llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
142       const ast::Name *declName = decl->getName();
143       return declName ? declName->getLoc() : decl->getLoc();
144     }
145     return cast<const ods::Operation *>(definition)->getLoc();
146   }
147 
148   /// The main definition of the symbol.
149   PointerUnion<const ast::Decl *, const ods::Operation *> definition;
150   /// The set of references to the symbol.
151   std::vector<SMRange> references;
152 };
153 
154 /// This class provides an index for definitions/uses within a PDL document.
155 /// It provides efficient lookup of a definition given an input source range.
156 class PDLIndex {
157 public:
158   PDLIndex() : intervalMap(allocator) {}
159 
160   /// Initialize the index with the given ast::Module.
161   void initialize(const ast::Module &module, const ods::Context &odsContext);
162 
163   /// Lookup a symbol for the given location. Returns nullptr if no symbol could
164   /// be found. If provided, `overlappedRange` is set to the range that the
165   /// provided `loc` overlapped with.
166   const PDLIndexSymbol *lookup(SMLoc loc,
167                                SMRange *overlappedRange = nullptr) const;
168 
169 private:
170   /// The type of interval map used to store source references. SMRange is
171   /// half-open, so we also need to use a half-open interval map.
172   using MapT =
173       llvm::IntervalMap<const char *, const PDLIndexSymbol *,
174                         llvm::IntervalMapImpl::NodeSizer<
175                             const char *, const PDLIndexSymbol *>::LeafSize,
176                         llvm::IntervalMapHalfOpenInfo<const char *>>;
177 
178   /// An allocator for the interval map.
179   MapT::Allocator allocator;
180 
181   /// An interval map containing a corresponding definition mapped to a source
182   /// interval.
183   MapT intervalMap;
184 
185   /// A mapping between definitions and their corresponding symbol.
186   DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
187 };
188 } // namespace
189 
190 void PDLIndex::initialize(const ast::Module &module,
191                           const ods::Context &odsContext) {
192   auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
193     auto it = defToSymbol.try_emplace(def, nullptr);
194     if (it.second)
195       it.first->second = std::make_unique<PDLIndexSymbol>(def);
196     return &*it.first->second;
197   };
198   auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
199                            bool isDef = false) {
200     const char *startLoc = refLoc.Start.getPointer();
201     const char *endLoc = refLoc.End.getPointer();
202     if (!intervalMap.overlaps(startLoc, endLoc)) {
203       intervalMap.insert(startLoc, endLoc, sym);
204       if (!isDef)
205         sym->references.push_back(refLoc);
206     }
207   };
208   auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
209     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
210     if (!odsOp)
211       return;
212 
213     PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
214     insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
215     insertDeclRef(symbol, refLoc);
216   };
217 
218   module.walk([&](const ast::Node *node) {
219     // Handle references to PDL decls.
220     if (const auto *decl = dyn_cast<ast::OpNameDecl>(node)) {
221       if (std::optional<StringRef> name = decl->getName())
222         insertODSOpRef(*name, decl->getLoc());
223     } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(node)) {
224       const ast::Name *name = decl->getName();
225       if (!name)
226         return;
227       PDLIndexSymbol *declSym = getOrInsertDef(decl);
228       insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
229 
230       if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) {
231         // Record references to any constraints.
232         for (const auto &it : varDecl->getConstraints())
233           insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
234       }
235     } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(node)) {
236       insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
237     }
238   });
239 }
240 
241 const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
242                                        SMRange *overlappedRange) const {
243   auto it = intervalMap.find(loc.getPointer());
244   if (!it.valid() || loc.getPointer() < it.start())
245     return nullptr;
246 
247   if (overlappedRange) {
248     *overlappedRange = SMRange(SMLoc::getFromPointer(it.start()),
249                                SMLoc::getFromPointer(it.stop()));
250   }
251   return it.value();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // PDLDocument
256 //===----------------------------------------------------------------------===//
257 
258 namespace {
259 /// This class represents all of the information pertaining to a specific PDL
260 /// document.
261 struct PDLDocument {
262   PDLDocument(const lsp::URIForFile &uri, StringRef contents,
263               const std::vector<std::string> &extraDirs,
264               std::vector<lsp::Diagnostic> &diagnostics);
265   PDLDocument(const PDLDocument &) = delete;
266   PDLDocument &operator=(const PDLDocument &) = delete;
267 
268   //===--------------------------------------------------------------------===//
269   // Definitions and References
270   //===--------------------------------------------------------------------===//
271 
272   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
273                       std::vector<lsp::Location> &locations);
274   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
275                         std::vector<lsp::Location> &references);
276 
277   //===--------------------------------------------------------------------===//
278   // Document Links
279   //===--------------------------------------------------------------------===//
280 
281   void getDocumentLinks(const lsp::URIForFile &uri,
282                         std::vector<lsp::DocumentLink> &links);
283 
284   //===--------------------------------------------------------------------===//
285   // Hover
286   //===--------------------------------------------------------------------===//
287 
288   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
289                                       const lsp::Position &hoverPos);
290   std::optional<lsp::Hover> findHover(const ast::Decl *decl,
291                                       const SMRange &hoverRange);
292   lsp::Hover buildHoverForOpName(const ods::Operation *op,
293                                  const SMRange &hoverRange);
294   lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
295                                    const SMRange &hoverRange);
296   lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
297                                   const SMRange &hoverRange);
298   lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
299                                          const SMRange &hoverRange);
300   template <typename T>
301   lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
302                                                   const T *decl,
303                                                   const SMRange &hoverRange);
304 
305   //===--------------------------------------------------------------------===//
306   // Document Symbols
307   //===--------------------------------------------------------------------===//
308 
309   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
310 
311   //===--------------------------------------------------------------------===//
312   // Code Completion
313   //===--------------------------------------------------------------------===//
314 
315   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
316                                         const lsp::Position &completePos);
317 
318   //===--------------------------------------------------------------------===//
319   // Signature Help
320   //===--------------------------------------------------------------------===//
321 
322   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
323                                       const lsp::Position &helpPos);
324 
325   //===--------------------------------------------------------------------===//
326   // Inlay Hints
327   //===--------------------------------------------------------------------===//
328 
329   void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
330                      std::vector<lsp::InlayHint> &inlayHints);
331   void getInlayHintsFor(const ast::VariableDecl *decl,
332                         const lsp::URIForFile &uri,
333                         std::vector<lsp::InlayHint> &inlayHints);
334   void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
335                         std::vector<lsp::InlayHint> &inlayHints);
336   void getInlayHintsFor(const ast::OperationExpr *expr,
337                         const lsp::URIForFile &uri,
338                         std::vector<lsp::InlayHint> &inlayHints);
339 
340   /// Add a parameter hint for the given expression using `label`.
341   void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
342                            const ast::Expr *expr, StringRef label);
343 
344   //===--------------------------------------------------------------------===//
345   // PDLL ViewOutput
346   //===--------------------------------------------------------------------===//
347 
348   void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
349 
350   //===--------------------------------------------------------------------===//
351   // Fields
352   //===--------------------------------------------------------------------===//
353 
354   /// The include directories for this file.
355   std::vector<std::string> includeDirs;
356 
357   /// The source manager containing the contents of the input file.
358   llvm::SourceMgr sourceMgr;
359 
360   /// The ODS and AST contexts.
361   ods::Context odsContext;
362   ast::Context astContext;
363 
364   /// The parsed AST module, or failure if the file wasn't valid.
365   FailureOr<ast::Module *> astModule;
366 
367   /// The index of the parsed module.
368   PDLIndex index;
369 
370   /// The set of includes of the parsed module.
371   SmallVector<lsp::SourceMgrInclude> parsedIncludes;
372 };
373 } // namespace
374 
375 PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
376                          const std::vector<std::string> &extraDirs,
377                          std::vector<lsp::Diagnostic> &diagnostics)
378     : astContext(odsContext) {
379   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
380   if (!memBuffer) {
381     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
382     return;
383   }
384 
385   // Build the set of include directories for this file.
386   llvm::SmallString<32> uriDirectory(uri.file());
387   llvm::sys::path::remove_filename(uriDirectory);
388   includeDirs.push_back(uriDirectory.str().str());
389   includeDirs.insert(includeDirs.end(), extraDirs.begin(), extraDirs.end());
390 
391   sourceMgr.setIncludeDirs(includeDirs);
392   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
393 
394   astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
395     if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
396       diagnostics.push_back(std::move(*lspDiag));
397   });
398   astModule = parsePDLLAST(astContext, sourceMgr, /*enableDocumentation=*/true);
399 
400   // Initialize the set of parsed includes.
401   lsp::gatherIncludeFiles(sourceMgr, parsedIncludes);
402 
403   // If we failed to parse the module, there is nothing left to initialize.
404   if (failed(astModule))
405     return;
406 
407   // Prepare the AST index with the parsed module.
408   index.initialize(**astModule, odsContext);
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // PDLDocument: Definitions and References
413 //===----------------------------------------------------------------------===//
414 
415 void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
416                                  const lsp::Position &defPos,
417                                  std::vector<lsp::Location> &locations) {
418   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
419   const PDLIndexSymbol *symbol = index.lookup(posLoc);
420   if (!symbol)
421     return;
422 
423   locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
424 }
425 
426 void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
427                                    const lsp::Position &pos,
428                                    std::vector<lsp::Location> &references) {
429   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
430   const PDLIndexSymbol *symbol = index.lookup(posLoc);
431   if (!symbol)
432     return;
433 
434   references.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
435   for (SMRange refLoc : symbol->references)
436     references.push_back(getLocationFromLoc(sourceMgr, refLoc, uri));
437 }
438 
439 //===--------------------------------------------------------------------===//
440 // PDLDocument: Document Links
441 //===--------------------------------------------------------------------===//
442 
443 void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
444                                    std::vector<lsp::DocumentLink> &links) {
445   for (const lsp::SourceMgrInclude &include : parsedIncludes)
446     links.emplace_back(include.range, include.uri);
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // PDLDocument: Hover
451 //===----------------------------------------------------------------------===//
452 
453 std::optional<lsp::Hover>
454 PDLDocument::findHover(const lsp::URIForFile &uri,
455                        const lsp::Position &hoverPos) {
456   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
457 
458   // Check for a reference to an include.
459   for (const lsp::SourceMgrInclude &include : parsedIncludes)
460     if (include.range.contains(hoverPos))
461       return include.buildHover();
462 
463   // Find the symbol at the given location.
464   SMRange hoverRange;
465   const PDLIndexSymbol *symbol = index.lookup(posLoc, &hoverRange);
466   if (!symbol)
467     return std::nullopt;
468 
469   // Add hover for operation names.
470   if (const auto *op =
471           llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
472     return buildHoverForOpName(op, hoverRange);
473   const auto *decl = cast<const ast::Decl *>(symbol->definition);
474   return findHover(decl, hoverRange);
475 }
476 
477 std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
478                                                  const SMRange &hoverRange) {
479   // Add hover for variables.
480   if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
481     return buildHoverForVariable(varDecl, hoverRange);
482 
483   // Add hover for patterns.
484   if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl))
485     return buildHoverForPattern(patternDecl, hoverRange);
486 
487   // Add hover for core constraints.
488   if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(decl))
489     return buildHoverForCoreConstraint(cst, hoverRange);
490 
491   // Add hover for user constraints.
492   if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl))
493     return buildHoverForUserConstraintOrRewrite("Constraint", cst, hoverRange);
494 
495   // Add hover for user rewrites.
496   if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(decl))
497     return buildHoverForUserConstraintOrRewrite("Rewrite", rewrite, hoverRange);
498 
499   return std::nullopt;
500 }
501 
502 lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
503                                             const SMRange &hoverRange) {
504   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
505   {
506     llvm::raw_string_ostream hoverOS(hover.contents.value);
507     hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
508             << op->getSummary() << "\n***\n"
509             << op->getDescription();
510   }
511   return hover;
512 }
513 
514 lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
515                                               const SMRange &hoverRange) {
516   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
517   {
518     llvm::raw_string_ostream hoverOS(hover.contents.value);
519     hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
520             << "Type: `" << varDecl->getType() << "`\n";
521   }
522   return hover;
523 }
524 
525 lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
526                                              const SMRange &hoverRange) {
527   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
528   {
529     llvm::raw_string_ostream hoverOS(hover.contents.value);
530     hoverOS << "**Pattern**";
531     if (const ast::Name *name = decl->getName())
532       hoverOS << ": `" << name->getName() << "`";
533     hoverOS << "\n***\n";
534     if (std::optional<uint16_t> benefit = decl->getBenefit())
535       hoverOS << "Benefit: " << *benefit << "\n";
536     if (decl->hasBoundedRewriteRecursion())
537       hoverOS << "HasBoundedRewriteRecursion\n";
538     hoverOS << "RootOp: `"
539             << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
540 
541     // Format the documentation for the decl.
542     if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
543       hoverOS << "\n" << *doc << "\n";
544   }
545   return hover;
546 }
547 
548 lsp::Hover
549 PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
550                                          const SMRange &hoverRange) {
551   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
552   {
553     llvm::raw_string_ostream hoverOS(hover.contents.value);
554     hoverOS << "**Constraint**: `";
555     TypeSwitch<const ast::Decl *>(decl)
556         .Case([&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
557         .Case([&](const ast::OpConstraintDecl *opCst) {
558           hoverOS << "Op";
559           if (std::optional<StringRef> name = opCst->getName())
560             hoverOS << "<" << *name << ">";
561         })
562         .Case([&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
563         .Case([&](const ast::TypeRangeConstraintDecl *) {
564           hoverOS << "TypeRange";
565         })
566         .Case([&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
567         .Case([&](const ast::ValueRangeConstraintDecl *) {
568           hoverOS << "ValueRange";
569         });
570     hoverOS << "`\n";
571   }
572   return hover;
573 }
574 
575 template <typename T>
576 lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
577     StringRef typeName, const T *decl, const SMRange &hoverRange) {
578   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579   {
580     llvm::raw_string_ostream hoverOS(hover.contents.value);
581     hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
582             << "`\n***\n";
583     ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
584     if (!inputs.empty()) {
585       hoverOS << "Parameters:\n";
586       for (const ast::VariableDecl *input : inputs)
587         hoverOS << "* " << input->getName().getName() << ": `"
588                 << input->getType() << "`\n";
589       hoverOS << "***\n";
590     }
591     ast::Type resultType = decl->getResultType();
592     if (auto resultTupleTy = dyn_cast<ast::TupleType>(resultType)) {
593       if (!resultTupleTy.empty()) {
594         hoverOS << "Results:\n";
595         for (auto it : llvm::zip(resultTupleTy.getElementNames(),
596                                  resultTupleTy.getElementTypes())) {
597           StringRef name = std::get<0>(it);
598           hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
599                   << std::get<1>(it) << "`\n";
600         }
601         hoverOS << "***\n";
602       }
603     } else {
604       hoverOS << "Results:\n* `" << resultType << "`\n";
605       hoverOS << "***\n";
606     }
607 
608     // Format the documentation for the decl.
609     if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
610       hoverOS << "\n" << *doc << "\n";
611   }
612   return hover;
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // PDLDocument: Document Symbols
617 //===----------------------------------------------------------------------===//
618 
619 void PDLDocument::findDocumentSymbols(
620     std::vector<lsp::DocumentSymbol> &symbols) {
621   if (failed(astModule))
622     return;
623 
624   for (const ast::Decl *decl : (*astModule)->getChildren()) {
625     if (!isMainFileLoc(sourceMgr, decl->getLoc()))
626       continue;
627 
628     if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(decl)) {
629       const ast::Name *name = patternDecl->getName();
630 
631       SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
632       SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
633 
634       symbols.emplace_back(
635           name ? name->getName() : "<pattern>", lsp::SymbolKind::Class,
636           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
637     } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
638       // TODO: Add source information for the code block body.
639       SMRange nameLoc = cDecl->getName().getLoc();
640       SMRange bodyLoc = nameLoc;
641 
642       symbols.emplace_back(
643           cDecl->getName().getName(), lsp::SymbolKind::Function,
644           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
645     } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
646       // TODO: Add source information for the code block body.
647       SMRange nameLoc = cDecl->getName().getLoc();
648       SMRange bodyLoc = nameLoc;
649 
650       symbols.emplace_back(
651           cDecl->getName().getName(), lsp::SymbolKind::Function,
652           lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
653     }
654   }
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // PDLDocument: Code Completion
659 //===----------------------------------------------------------------------===//
660 
661 namespace {
662 class LSPCodeCompleteContext : public CodeCompleteContext {
663 public:
664   LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
665                          lsp::CompletionList &completionList,
666                          ods::Context &odsContext,
667                          ArrayRef<std::string> includeDirs)
668       : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
669         completionList(completionList), odsContext(odsContext),
670         includeDirs(includeDirs) {}
671 
672   void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
673     ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
674     ArrayRef<StringRef> elementNames = tupleType.getElementNames();
675     for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
676       // Push back a completion item that uses the result index.
677       lsp::CompletionItem item;
678       item.label = llvm::formatv("{0} (field #{0})", i).str();
679       item.insertText = Twine(i).str();
680       item.filterText = item.sortText = item.insertText;
681       item.kind = lsp::CompletionItemKind::Field;
682       item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
683       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
684       completionList.items.emplace_back(item);
685 
686       // If the element has a name, push back a completion item with that name.
687       if (!elementNames[i].empty()) {
688         item.label =
689             llvm::formatv("{1} (field #{0})", i, elementNames[i]).str();
690         item.filterText = item.label;
691         item.insertText = elementNames[i].str();
692         completionList.items.emplace_back(item);
693       }
694     }
695   }
696 
697   void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
698     const ods::Operation *odsOp = opType.getODSOperation();
699     if (!odsOp)
700       return;
701 
702     ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
703     for (const auto &it : llvm::enumerate(results)) {
704       const ods::OperandOrResult &result = it.value();
705       const ods::TypeConstraint &constraint = result.getConstraint();
706 
707       // Push back a completion item that uses the result index.
708       lsp::CompletionItem item;
709       item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
710       item.insertText = Twine(it.index()).str();
711       item.filterText = item.sortText = item.insertText;
712       item.kind = lsp::CompletionItemKind::Field;
713       switch (result.getVariableLengthKind()) {
714       case ods::VariableLengthKind::Single:
715         item.detail = llvm::formatv("{0}: Value", it.index()).str();
716         break;
717       case ods::VariableLengthKind::Optional:
718         item.detail = llvm::formatv("{0}: Value?", it.index()).str();
719         break;
720       case ods::VariableLengthKind::Variadic:
721         item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
722         break;
723       }
724       item.documentation = lsp::MarkupContent{
725           lsp::MarkupKind::Markdown,
726           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
727                         constraint.getCppClass())
728               .str()};
729       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
730       completionList.items.emplace_back(item);
731 
732       // If the result has a name, push back a completion item with the result
733       // name.
734       if (!result.getName().empty()) {
735         item.label =
736             llvm::formatv("{1} (field #{0})", it.index(), result.getName())
737                 .str();
738         item.filterText = item.label;
739         item.insertText = result.getName().str();
740         completionList.items.emplace_back(item);
741       }
742     }
743   }
744 
745   void codeCompleteOperationAttributeName(StringRef opName) final {
746     const ods::Operation *odsOp = odsContext.lookupOperation(opName);
747     if (!odsOp)
748       return;
749 
750     for (const ods::Attribute &attr : odsOp->getAttributes()) {
751       const ods::AttributeConstraint &constraint = attr.getConstraint();
752 
753       lsp::CompletionItem item;
754       item.label = attr.getName().str();
755       item.kind = lsp::CompletionItemKind::Field;
756       item.detail = attr.isOptional() ? "optional" : "";
757       item.documentation = lsp::MarkupContent{
758           lsp::MarkupKind::Markdown,
759           llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
760                         constraint.getCppClass())
761               .str()};
762       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
763       completionList.items.emplace_back(item);
764     }
765   }
766 
767   void codeCompleteConstraintName(ast::Type currentType,
768                                   bool allowInlineTypeConstraints,
769                                   const ast::DeclScope *scope) final {
770     auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
771                                  StringRef snippetText = "") {
772       lsp::CompletionItem item;
773       item.label = constraint.str();
774       item.kind = lsp::CompletionItemKind::Class;
775       item.detail = (constraint + " constraint").str();
776       item.documentation = lsp::MarkupContent{
777           lsp::MarkupKind::Markdown,
778           ("A single entity core constraint of type `" + mlirType + "`").str()};
779       item.sortText = "0";
780       item.insertText = snippetText.str();
781       item.insertTextFormat = snippetText.empty()
782                                   ? lsp::InsertTextFormat::PlainText
783                                   : lsp::InsertTextFormat::Snippet;
784       completionList.items.emplace_back(item);
785     };
786 
787     // Insert completions for the core constraints. Some core constraints have
788     // additional characteristics, so we may add then even if a type has been
789     // inferred.
790     if (!currentType) {
791       addCoreConstraint("Attr", "mlir::Attribute");
792       addCoreConstraint("Op", "mlir::Operation *");
793       addCoreConstraint("Value", "mlir::Value");
794       addCoreConstraint("ValueRange", "mlir::ValueRange");
795       addCoreConstraint("Type", "mlir::Type");
796       addCoreConstraint("TypeRange", "mlir::TypeRange");
797     }
798     if (allowInlineTypeConstraints) {
799       /// Attr<Type>.
800       if (!currentType || isa<ast::AttributeType>(currentType))
801         addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
802       /// Value<Type>.
803       if (!currentType || isa<ast::ValueType>(currentType))
804         addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
805       /// ValueRange<TypeRange>.
806       if (!currentType || isa<ast::ValueRangeType>(currentType))
807         addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
808                           "ValueRange<$1>");
809     }
810 
811     // If a scope was provided, check it for potential constraints.
812     while (scope) {
813       for (const ast::Decl *decl : scope->getDecls()) {
814         if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
815           lsp::CompletionItem item;
816           item.label = cst->getName().getName().str();
817           item.kind = lsp::CompletionItemKind::Interface;
818           item.sortText = "2_" + item.label;
819 
820           // Skip constraints that are not single-arg. We currently only
821           // complete variable constraints.
822           if (cst->getInputs().size() != 1)
823             continue;
824 
825           // Ensure the input type matched the given type.
826           ast::Type constraintType = cst->getInputs()[0]->getType();
827           if (currentType && !currentType.refineWith(constraintType))
828             continue;
829 
830           // Format the constraint signature.
831           {
832             llvm::raw_string_ostream strOS(item.detail);
833             strOS << "(";
834             llvm::interleaveComma(
835                 cst->getInputs(), strOS, [&](const ast::VariableDecl *var) {
836                   strOS << var->getName().getName() << ": " << var->getType();
837                 });
838             strOS << ") -> " << cst->getResultType();
839           }
840 
841           // Format the documentation for the constraint.
842           if (std::optional<std::string> doc =
843                   getDocumentationFor(sourceMgr, cst)) {
844             item.documentation =
845                 lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)};
846           }
847 
848           completionList.items.emplace_back(item);
849         }
850       }
851 
852       scope = scope->getParentScope();
853     }
854   }
855 
856   void codeCompleteDialectName() final {
857     // Code complete known dialects.
858     for (const ods::Dialect &dialect : odsContext.getDialects()) {
859       lsp::CompletionItem item;
860       item.label = dialect.getName().str();
861       item.kind = lsp::CompletionItemKind::Class;
862       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
863       completionList.items.emplace_back(item);
864     }
865   }
866 
867   void codeCompleteOperationName(StringRef dialectName) final {
868     const ods::Dialect *dialect = odsContext.lookupDialect(dialectName);
869     if (!dialect)
870       return;
871 
872     for (const auto &it : dialect->getOperations()) {
873       const ods::Operation &op = *it.second;
874 
875       lsp::CompletionItem item;
876       item.label = op.getName().drop_front(dialectName.size() + 1).str();
877       item.kind = lsp::CompletionItemKind::Field;
878       item.insertTextFormat = lsp::InsertTextFormat::PlainText;
879       completionList.items.emplace_back(item);
880     }
881   }
882 
883   void codeCompletePatternMetadata() final {
884     auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
885                                    StringRef snippetText = "") {
886       lsp::CompletionItem item;
887       item.label = constraint.str();
888       item.kind = lsp::CompletionItemKind::Class;
889       item.detail = "pattern metadata";
890       item.documentation =
891           lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
892       item.insertText = snippetText.str();
893       item.insertTextFormat = snippetText.empty()
894                                   ? lsp::InsertTextFormat::PlainText
895                                   : lsp::InsertTextFormat::Snippet;
896       completionList.items.emplace_back(item);
897     };
898 
899     addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
900                         "benefit($1)");
901     addSimpleConstraint("recursion",
902                         "The pattern properly handles recursive application.");
903   }
904 
905   void codeCompleteIncludeFilename(StringRef curPath) final {
906     // Normalize the path to allow for interacting with the file system
907     // utilities.
908     SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(curPath));
909     llvm::sys::path::native(nativeRelDir);
910 
911     // Set of already included completion paths.
912     StringSet<> seenResults;
913 
914     // Functor used to add a single include completion item.
915     auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
916       lsp::CompletionItem item;
917       item.label = path.str();
918       item.kind = isDirectory ? lsp::CompletionItemKind::Folder
919                               : lsp::CompletionItemKind::File;
920       if (seenResults.insert(item.label).second)
921         completionList.items.emplace_back(item);
922     };
923 
924     // Process the include directories for this file, adding any potential
925     // nested include files or directories.
926     for (StringRef includeDir : includeDirs) {
927       llvm::SmallString<128> dir = includeDir;
928       if (!nativeRelDir.empty())
929         llvm::sys::path::append(dir, nativeRelDir);
930 
931       std::error_code errorCode;
932       for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
933                 e = llvm::sys::fs::directory_iterator();
934            !errorCode && it != e; it.increment(errorCode)) {
935         StringRef filename = llvm::sys::path::filename(it->path());
936 
937         // To know whether a symlink should be treated as file or a directory,
938         // we have to stat it. This should be cheap enough as there shouldn't be
939         // many symlinks.
940         llvm::sys::fs::file_type fileType = it->type();
941         if (fileType == llvm::sys::fs::file_type::symlink_file) {
942           if (auto fileStatus = it->status())
943             fileType = fileStatus->type();
944         }
945 
946         switch (fileType) {
947         case llvm::sys::fs::file_type::directory_file:
948           addIncludeCompletion(filename, /*isDirectory=*/true);
949           break;
950         case llvm::sys::fs::file_type::regular_file: {
951           // Only consider concrete files that can actually be included by PDLL.
952           if (filename.ends_with(".pdll") || filename.ends_with(".td"))
953             addIncludeCompletion(filename, /*isDirectory=*/false);
954           break;
955         }
956         default:
957           break;
958         }
959       }
960     }
961 
962     // Sort the completion results to make sure the output is deterministic in
963     // the face of different iteration schemes for different platforms.
964     llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs,
965                                         const lsp::CompletionItem &rhs) {
966       return lhs.label < rhs.label;
967     });
968   }
969 
970 private:
971   llvm::SourceMgr &sourceMgr;
972   lsp::CompletionList &completionList;
973   ods::Context &odsContext;
974   ArrayRef<std::string> includeDirs;
975 };
976 } // namespace
977 
978 lsp::CompletionList
979 PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
980                                const lsp::Position &completePos) {
981   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
982   if (!posLoc.isValid())
983     return lsp::CompletionList();
984 
985   // To perform code completion, we run another parse of the module with the
986   // code completion context provided.
987   ods::Context tmpODSContext;
988   lsp::CompletionList completionList;
989   LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
990                                             tmpODSContext,
991                                             sourceMgr.getIncludeDirs());
992 
993   ast::Context tmpContext(tmpODSContext);
994   (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
995                      &lspCompleteContext);
996 
997   return completionList;
998 }
999 
1000 //===----------------------------------------------------------------------===//
1001 // PDLDocument: Signature Help
1002 //===----------------------------------------------------------------------===//
1003 
1004 namespace {
1005 class LSPSignatureHelpContext : public CodeCompleteContext {
1006 public:
1007   LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1008                           lsp::SignatureHelp &signatureHelp,
1009                           ods::Context &odsContext)
1010       : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1011         signatureHelp(signatureHelp), odsContext(odsContext) {}
1012 
1013   void codeCompleteCallSignature(const ast::CallableDecl *callable,
1014                                  unsigned currentNumArgs) final {
1015     signatureHelp.activeParameter = currentNumArgs;
1016 
1017     lsp::SignatureInformation signatureInfo;
1018     {
1019       llvm::raw_string_ostream strOS(signatureInfo.label);
1020       strOS << callable->getName()->getName() << "(";
1021       auto formatParamFn = [&](const ast::VariableDecl *var) {
1022         unsigned paramStart = strOS.str().size();
1023         strOS << var->getName().getName() << ": " << var->getType();
1024         unsigned paramEnd = strOS.str().size();
1025         signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1026             StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1027             std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
1028       };
1029       llvm::interleaveComma(callable->getInputs(), strOS, formatParamFn);
1030       strOS << ") -> " << callable->getResultType();
1031     }
1032 
1033     // Format the documentation for the callable.
1034     if (std::optional<std::string> doc =
1035             getDocumentationFor(sourceMgr, callable))
1036       signatureInfo.documentation = std::move(*doc);
1037 
1038     signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1039   }
1040 
1041   void
1042   codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1043                                          unsigned currentNumOperands) final {
1044     const ods::Operation *odsOp =
1045         opName ? odsContext.lookupOperation(*opName) : nullptr;
1046     codeCompleteOperationOperandOrResultSignature(
1047         opName, odsOp, odsOp ? odsOp->getOperands() : std::nullopt,
1048         currentNumOperands, "operand", "Value");
1049   }
1050 
1051   void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1052                                              unsigned currentNumResults) final {
1053     const ods::Operation *odsOp =
1054         opName ? odsContext.lookupOperation(*opName) : nullptr;
1055     codeCompleteOperationOperandOrResultSignature(
1056         opName, odsOp, odsOp ? odsOp->getResults() : std::nullopt,
1057         currentNumResults, "result", "Type");
1058   }
1059 
1060   void codeCompleteOperationOperandOrResultSignature(
1061       std::optional<StringRef> opName, const ods::Operation *odsOp,
1062       ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1063       StringRef label, StringRef dataType) {
1064     signatureHelp.activeParameter = currentValue;
1065 
1066     // If we have ODS information for the operation, add in the ODS signature
1067     // for the operation. We also verify that the current number of values is
1068     // not more than what is defined in ODS, as this will result in an error
1069     // anyways.
1070     if (odsOp && currentValue < values.size()) {
1071       lsp::SignatureInformation signatureInfo;
1072 
1073       // Build the signature label.
1074       {
1075         llvm::raw_string_ostream strOS(signatureInfo.label);
1076         strOS << "(";
1077         auto formatFn = [&](const ods::OperandOrResult &value) {
1078           unsigned paramStart = strOS.str().size();
1079 
1080           strOS << value.getName() << ": ";
1081 
1082           StringRef constraintDoc = value.getConstraint().getSummary();
1083           std::string paramDoc;
1084           switch (value.getVariableLengthKind()) {
1085           case ods::VariableLengthKind::Single:
1086             strOS << dataType;
1087             paramDoc = constraintDoc.str();
1088             break;
1089           case ods::VariableLengthKind::Optional:
1090             strOS << dataType << "?";
1091             paramDoc = ("optional: " + constraintDoc).str();
1092             break;
1093           case ods::VariableLengthKind::Variadic:
1094             strOS << dataType << "Range";
1095             paramDoc = ("variadic: " + constraintDoc).str();
1096             break;
1097           }
1098 
1099           unsigned paramEnd = strOS.str().size();
1100           signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1101               StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
1102               std::make_pair(paramStart, paramEnd), paramDoc});
1103         };
1104         llvm::interleaveComma(values, strOS, formatFn);
1105         strOS << ")";
1106       }
1107       signatureInfo.documentation =
1108           llvm::formatv("`op<{0}>` ODS {1} specification", *opName, label)
1109               .str();
1110       signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1111     }
1112 
1113     // If there aren't any arguments yet, we also add the generic signature.
1114     if (currentValue == 0 && (!odsOp || !values.empty())) {
1115       lsp::SignatureInformation signatureInfo;
1116       signatureInfo.label =
1117           llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
1118       signatureInfo.documentation =
1119           ("Generic operation " + label + " specification").str();
1120       signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
1121           StringRef(signatureInfo.label).drop_front().drop_back().str(),
1122           std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1123           ("All of the " + label + "s of the operation.").str()});
1124       signatureHelp.signatures.emplace_back(std::move(signatureInfo));
1125     }
1126   }
1127 
1128 private:
1129   llvm::SourceMgr &sourceMgr;
1130   lsp::SignatureHelp &signatureHelp;
1131   ods::Context &odsContext;
1132 };
1133 } // namespace
1134 
1135 lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
1136                                                  const lsp::Position &helpPos) {
1137   SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
1138   if (!posLoc.isValid())
1139     return lsp::SignatureHelp();
1140 
1141   // To perform code completion, we run another parse of the module with the
1142   // code completion context provided.
1143   ods::Context tmpODSContext;
1144   lsp::SignatureHelp signatureHelp;
1145   LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1146                                           tmpODSContext);
1147 
1148   ast::Context tmpContext(tmpODSContext);
1149   (void)parsePDLLAST(tmpContext, sourceMgr, /*enableDocumentation=*/true,
1150                      &completeContext);
1151 
1152   return signatureHelp;
1153 }
1154 
1155 //===----------------------------------------------------------------------===//
1156 // PDLDocument: Inlay Hints
1157 //===----------------------------------------------------------------------===//
1158 
1159 /// Returns true if the given name should be added as a hint for `expr`.
1160 static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1161   if (name.empty())
1162     return false;
1163 
1164   // If the argument is a reference of the same name, don't add it as a hint.
1165   if (auto *ref = dyn_cast<ast::DeclRefExpr>(expr)) {
1166     const ast::Name *declName = ref->getDecl()->getName();
1167     if (declName && declName->getName() == name)
1168       return false;
1169   }
1170 
1171   return true;
1172 }
1173 
1174 void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1175                                 const lsp::Range &range,
1176                                 std::vector<lsp::InlayHint> &inlayHints) {
1177   if (failed(astModule))
1178     return;
1179   SMRange rangeLoc = range.getAsSMRange(sourceMgr);
1180   if (!rangeLoc.isValid())
1181     return;
1182   (*astModule)->walk([&](const ast::Node *node) {
1183     SMRange loc = node->getLoc();
1184 
1185     // Check that the location of this node is within the input range.
1186     if (!lsp::contains(rangeLoc, loc.Start) &&
1187         !lsp::contains(rangeLoc, loc.End))
1188       return;
1189 
1190     // Handle hints for various types of nodes.
1191     llvm::TypeSwitch<const ast::Node *>(node)
1192         .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1193             [&](const auto *node) {
1194               this->getInlayHintsFor(node, uri, inlayHints);
1195             });
1196   });
1197 }
1198 
1199 void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1200                                    const lsp::URIForFile &uri,
1201                                    std::vector<lsp::InlayHint> &inlayHints) {
1202   // Check to see if the variable has a constraint list, if it does we don't
1203   // provide initializer hints.
1204   if (!decl->getConstraints().empty())
1205     return;
1206 
1207   // Check to see if the variable has an initializer.
1208   if (const ast::Expr *expr = decl->getInitExpr()) {
1209     // Don't add hints for operation expression initialized variables given that
1210     // the type of the variable is easily inferred by the expression operation
1211     // name.
1212     if (isa<ast::OperationExpr>(expr))
1213       return;
1214   }
1215 
1216   lsp::InlayHint hint(lsp::InlayHintKind::Type,
1217                       lsp::Position(sourceMgr, decl->getLoc().End));
1218   {
1219     llvm::raw_string_ostream labelOS(hint.label);
1220     labelOS << ": " << decl->getType();
1221   }
1222 
1223   inlayHints.emplace_back(std::move(hint));
1224 }
1225 
1226 void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1227                                    const lsp::URIForFile &uri,
1228                                    std::vector<lsp::InlayHint> &inlayHints) {
1229   // Try to extract the callable of this call.
1230   const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
1231   const auto *callable =
1232       callableRef ? dyn_cast<ast::CallableDecl>(callableRef->getDecl())
1233                   : nullptr;
1234   if (!callable)
1235     return;
1236 
1237   // Add hints for the arguments to the call.
1238   for (const auto &it : llvm::zip(expr->getArguments(), callable->getInputs()))
1239     addParameterHintFor(inlayHints, std::get<0>(it),
1240                         std::get<1>(it)->getName().getName());
1241 }
1242 
1243 void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1244                                    const lsp::URIForFile &uri,
1245                                    std::vector<lsp::InlayHint> &inlayHints) {
1246   // Check for ODS information.
1247   ast::OperationType opType = dyn_cast<ast::OperationType>(expr->getType());
1248   const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1249 
1250   auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1251     // If the value expression used the same location as the operation, don't
1252     // add a hint. This expression was materialized during parsing.
1253     if (expr->getLoc().Start == valueExpr->getLoc().Start)
1254       return;
1255     addParameterHintFor(inlayHints, valueExpr, label);
1256   };
1257 
1258   // Functor used to process hints for the operands and results of the
1259   // operation. They effectively have the same format, and thus can be processed
1260   // using the same logic.
1261   auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1262                                      ArrayRef<ods::OperandOrResult> odsValues,
1263                                      StringRef allValuesName) {
1264     if (values.empty())
1265       return;
1266 
1267     // The values should either map to a single range, or be equivalent to the
1268     // ODS values.
1269     if (values.size() != odsValues.size()) {
1270       // Handle the case of a single element that covers the full range.
1271       if (values.size() == 1)
1272         return addOpHint(values.front(), allValuesName);
1273       return;
1274     }
1275 
1276     for (const auto &it : llvm::zip(values, odsValues))
1277       addOpHint(std::get<0>(it), std::get<1>(it).getName());
1278   };
1279 
1280   // Add hints for the operands and results of the operation.
1281   addOperandOrResultHints(expr->getOperands(),
1282                           odsOp ? odsOp->getOperands()
1283                                 : ArrayRef<ods::OperandOrResult>(),
1284                           "operands");
1285   addOperandOrResultHints(expr->getResultTypes(),
1286                           odsOp ? odsOp->getResults()
1287                                 : ArrayRef<ods::OperandOrResult>(),
1288                           "results");
1289 }
1290 
1291 void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1292                                       const ast::Expr *expr, StringRef label) {
1293   if (!shouldAddHintFor(expr, label))
1294     return;
1295 
1296   lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1297                       lsp::Position(sourceMgr, expr->getLoc().Start));
1298   hint.label = (label + ":").str();
1299   hint.paddingRight = true;
1300   inlayHints.emplace_back(std::move(hint));
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // PDLL ViewOutput
1305 //===----------------------------------------------------------------------===//
1306 
1307 void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1308                                     lsp::PDLLViewOutputKind kind) {
1309   if (failed(astModule))
1310     return;
1311   if (kind == lsp::PDLLViewOutputKind::AST) {
1312     (*astModule)->print(os);
1313     return;
1314   }
1315 
1316   // Generate the MLIR for the ast module. We also capture diagnostics here to
1317   // show to the user, which may be useful if PDLL isn't capturing constraints
1318   // expected by PDL.
1319   MLIRContext mlirContext;
1320   SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1321   OwningOpRef<ModuleOp> pdlModule =
1322       codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1323   if (!pdlModule)
1324     return;
1325   if (kind == lsp::PDLLViewOutputKind::MLIR) {
1326     pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1327     return;
1328   }
1329 
1330   // Otherwise, generate the output for C++.
1331   assert(kind == lsp::PDLLViewOutputKind::CPP &&
1332          "unexpected PDLLViewOutputKind");
1333   codegenPDLLToCPP(**astModule, *pdlModule, os);
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // PDLTextFileChunk
1338 //===----------------------------------------------------------------------===//
1339 
1340 namespace {
1341 /// This class represents a single chunk of an PDL text file.
1342 struct PDLTextFileChunk {
1343   PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1344                    StringRef contents,
1345                    const std::vector<std::string> &extraDirs,
1346                    std::vector<lsp::Diagnostic> &diagnostics)
1347       : lineOffset(lineOffset),
1348         document(uri, contents, extraDirs, diagnostics) {}
1349 
1350   /// Adjust the line number of the given range to anchor at the beginning of
1351   /// the file, instead of the beginning of this chunk.
1352   void adjustLocForChunkOffset(lsp::Range &range) {
1353     adjustLocForChunkOffset(range.start);
1354     adjustLocForChunkOffset(range.end);
1355   }
1356   /// Adjust the line number of the given position to anchor at the beginning of
1357   /// the file, instead of the beginning of this chunk.
1358   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1359 
1360   /// The line offset of this chunk from the beginning of the file.
1361   uint64_t lineOffset;
1362   /// The document referred to by this chunk.
1363   PDLDocument document;
1364 };
1365 } // namespace
1366 
1367 //===----------------------------------------------------------------------===//
1368 // PDLTextFile
1369 //===----------------------------------------------------------------------===//
1370 
1371 namespace {
1372 /// This class represents a text file containing one or more PDL documents.
1373 class PDLTextFile {
1374 public:
1375   PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1376               int64_t version, const std::vector<std::string> &extraDirs,
1377               std::vector<lsp::Diagnostic> &diagnostics);
1378 
1379   /// Return the current version of this text file.
1380   int64_t getVersion() const { return version; }
1381 
1382   /// Update the file to the new version using the provided set of content
1383   /// changes. Returns failure if the update was unsuccessful.
1384   LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1385                        ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1386                        std::vector<lsp::Diagnostic> &diagnostics);
1387 
1388   //===--------------------------------------------------------------------===//
1389   // LSP Queries
1390   //===--------------------------------------------------------------------===//
1391 
1392   void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1393                       std::vector<lsp::Location> &locations);
1394   void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1395                         std::vector<lsp::Location> &references);
1396   void getDocumentLinks(const lsp::URIForFile &uri,
1397                         std::vector<lsp::DocumentLink> &links);
1398   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1399                                       lsp::Position hoverPos);
1400   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1401   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1402                                         lsp::Position completePos);
1403   lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1404                                       lsp::Position helpPos);
1405   void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1406                      std::vector<lsp::InlayHint> &inlayHints);
1407   lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1408 
1409 private:
1410   using ChunkIterator = llvm::pointee_iterator<
1411       std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1412 
1413   /// Initialize the text file from the given file contents.
1414   void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1415                   std::vector<lsp::Diagnostic> &diagnostics);
1416 
1417   /// Find the PDL document that contains the given position, and update the
1418   /// position to be anchored at the start of the found chunk instead of the
1419   /// beginning of the file.
1420   ChunkIterator getChunkItFor(lsp::Position &pos);
1421   PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1422     return *getChunkItFor(pos);
1423   }
1424 
1425   /// The full string contents of the file.
1426   std::string contents;
1427 
1428   /// The version of this file.
1429   int64_t version = 0;
1430 
1431   /// The number of lines in the file.
1432   int64_t totalNumLines = 0;
1433 
1434   /// The chunks of this file. The order of these chunks is the order in which
1435   /// they appear in the text file.
1436   std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1437 
1438   /// The extra set of include directories for this file.
1439   std::vector<std::string> extraIncludeDirs;
1440 };
1441 } // namespace
1442 
1443 PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1444                          int64_t version,
1445                          const std::vector<std::string> &extraDirs,
1446                          std::vector<lsp::Diagnostic> &diagnostics)
1447     : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1448   initialize(uri, version, diagnostics);
1449 }
1450 
1451 LogicalResult
1452 PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1453                     ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1454                     std::vector<lsp::Diagnostic> &diagnostics) {
1455   if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1456     lsp::Logger::error("Failed to update contents of {0}", uri.file());
1457     return failure();
1458   }
1459 
1460   // If the file contents were properly changed, reinitialize the text file.
1461   initialize(uri, newVersion, diagnostics);
1462   return success();
1463 }
1464 
1465 void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1466                                  lsp::Position defPos,
1467                                  std::vector<lsp::Location> &locations) {
1468   PDLTextFileChunk &chunk = getChunkFor(defPos);
1469   chunk.document.getLocationsOf(uri, defPos, locations);
1470 
1471   // Adjust any locations within this file for the offset of this chunk.
1472   if (chunk.lineOffset == 0)
1473     return;
1474   for (lsp::Location &loc : locations)
1475     if (loc.uri == uri)
1476       chunk.adjustLocForChunkOffset(loc.range);
1477 }
1478 
1479 void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1480                                    lsp::Position pos,
1481                                    std::vector<lsp::Location> &references) {
1482   PDLTextFileChunk &chunk = getChunkFor(pos);
1483   chunk.document.findReferencesOf(uri, pos, references);
1484 
1485   // Adjust any locations within this file for the offset of this chunk.
1486   if (chunk.lineOffset == 0)
1487     return;
1488   for (lsp::Location &loc : references)
1489     if (loc.uri == uri)
1490       chunk.adjustLocForChunkOffset(loc.range);
1491 }
1492 
1493 void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1494                                    std::vector<lsp::DocumentLink> &links) {
1495   chunks.front()->document.getDocumentLinks(uri, links);
1496   for (const auto &it : llvm::drop_begin(chunks)) {
1497     size_t currentNumLinks = links.size();
1498     it->document.getDocumentLinks(uri, links);
1499 
1500     // Adjust any links within this file to account for the offset of this
1501     // chunk.
1502     for (auto &link : llvm::drop_begin(links, currentNumLinks))
1503       it->adjustLocForChunkOffset(link.range);
1504   }
1505 }
1506 
1507 std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1508                                                  lsp::Position hoverPos) {
1509   PDLTextFileChunk &chunk = getChunkFor(hoverPos);
1510   std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1511 
1512   // Adjust any locations within this file for the offset of this chunk.
1513   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1514     chunk.adjustLocForChunkOffset(*hoverInfo->range);
1515   return hoverInfo;
1516 }
1517 
1518 void PDLTextFile::findDocumentSymbols(
1519     std::vector<lsp::DocumentSymbol> &symbols) {
1520   if (chunks.size() == 1)
1521     return chunks.front()->document.findDocumentSymbols(symbols);
1522 
1523   // If there are multiple chunks in this file, we create top-level symbols for
1524   // each chunk.
1525   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1526     PDLTextFileChunk &chunk = *chunks[i];
1527     lsp::Position startPos(chunk.lineOffset);
1528     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1529                                       : chunks[i + 1]->lineOffset);
1530     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1531                                lsp::SymbolKind::Namespace,
1532                                /*range=*/lsp::Range(startPos, endPos),
1533                                /*selectionRange=*/lsp::Range(startPos));
1534     chunk.document.findDocumentSymbols(symbol.children);
1535 
1536     // Fixup the locations of document symbols within this chunk.
1537     if (i != 0) {
1538       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1539       for (lsp::DocumentSymbol &childSymbol : symbol.children)
1540         symbolsToFix.push_back(&childSymbol);
1541 
1542       while (!symbolsToFix.empty()) {
1543         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1544         chunk.adjustLocForChunkOffset(symbol->range);
1545         chunk.adjustLocForChunkOffset(symbol->selectionRange);
1546 
1547         for (lsp::DocumentSymbol &childSymbol : symbol->children)
1548           symbolsToFix.push_back(&childSymbol);
1549       }
1550     }
1551 
1552     // Push the symbol for this chunk.
1553     symbols.emplace_back(std::move(symbol));
1554   }
1555 }
1556 
1557 lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1558                                                    lsp::Position completePos) {
1559   PDLTextFileChunk &chunk = getChunkFor(completePos);
1560   lsp::CompletionList completionList =
1561       chunk.document.getCodeCompletion(uri, completePos);
1562 
1563   // Adjust any completion locations.
1564   for (lsp::CompletionItem &item : completionList.items) {
1565     if (item.textEdit)
1566       chunk.adjustLocForChunkOffset(item.textEdit->range);
1567     for (lsp::TextEdit &edit : item.additionalTextEdits)
1568       chunk.adjustLocForChunkOffset(edit.range);
1569   }
1570   return completionList;
1571 }
1572 
1573 lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
1574                                                  lsp::Position helpPos) {
1575   return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
1576 }
1577 
1578 void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1579                                 std::vector<lsp::InlayHint> &inlayHints) {
1580   auto startIt = getChunkItFor(range.start);
1581   auto endIt = getChunkItFor(range.end);
1582 
1583   // Functor used to get the chunks for a given file, and fixup any locations
1584   auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1585     size_t currentNumHints = inlayHints.size();
1586     chunkIt->document.getInlayHints(uri, range, inlayHints);
1587 
1588     // If this isn't the first chunk, update any positions to account for line
1589     // number differences.
1590     if (&*chunkIt != &*chunks.front()) {
1591       for (auto &hint : llvm::drop_begin(inlayHints, currentNumHints))
1592         chunkIt->adjustLocForChunkOffset(hint.position);
1593     }
1594   };
1595   // Returns the number of lines held by a given chunk.
1596   auto getNumLines = [](ChunkIterator chunkIt) {
1597     return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1598   };
1599 
1600   // Check if the range is fully within a single chunk.
1601   if (startIt == endIt)
1602     return getHintsForChunk(startIt, range);
1603 
1604   // Otherwise, the range is split between multiple chunks. The first chunk
1605   // has the correct range start, but covers the total document.
1606   getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1607 
1608   // Every chunk in between uses the full document.
1609   for (++startIt; startIt != endIt; ++startIt)
1610     getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1611 
1612   // The range for the last chunk starts at the beginning of the document, up
1613   // through the end of the input range.
1614   getHintsForChunk(startIt, lsp::Range(0, range.end));
1615 }
1616 
1617 lsp::PDLLViewOutputResult
1618 PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1619   lsp::PDLLViewOutputResult result;
1620   {
1621     llvm::raw_string_ostream outputOS(result.output);
1622     llvm::interleave(
1623         llvm::make_pointee_range(chunks),
1624         [&](PDLTextFileChunk &chunk) {
1625           chunk.document.getPDLLViewOutput(outputOS, kind);
1626         },
1627         [&] { outputOS << "\n"
1628                        << kDefaultSplitMarker << "\n\n"; });
1629   }
1630   return result;
1631 }
1632 
1633 void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1634                              std::vector<lsp::Diagnostic> &diagnostics) {
1635   version = newVersion;
1636   chunks.clear();
1637 
1638   // Split the file into separate PDL documents.
1639   SmallVector<StringRef, 8> subContents;
1640   StringRef(contents).split(subContents, kDefaultSplitMarker);
1641   chunks.emplace_back(std::make_unique<PDLTextFileChunk>(
1642       /*lineOffset=*/0, uri, subContents.front(), extraIncludeDirs,
1643       diagnostics));
1644 
1645   uint64_t lineOffset = subContents.front().count('\n');
1646   for (StringRef docContents : llvm::drop_begin(subContents)) {
1647     unsigned currentNumDiags = diagnostics.size();
1648     auto chunk = std::make_unique<PDLTextFileChunk>(
1649         lineOffset, uri, docContents, extraIncludeDirs, diagnostics);
1650     lineOffset += docContents.count('\n');
1651 
1652     // Adjust locations used in diagnostics to account for the offset from the
1653     // beginning of the file.
1654     for (lsp::Diagnostic &diag :
1655          llvm::drop_begin(diagnostics, currentNumDiags)) {
1656       chunk->adjustLocForChunkOffset(diag.range);
1657 
1658       if (!diag.relatedInformation)
1659         continue;
1660       for (auto &it : *diag.relatedInformation)
1661         if (it.location.uri == uri)
1662           chunk->adjustLocForChunkOffset(it.location.range);
1663     }
1664     chunks.emplace_back(std::move(chunk));
1665   }
1666   totalNumLines = lineOffset;
1667 }
1668 
1669 PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1670   if (chunks.size() == 1)
1671     return chunks.begin();
1672 
1673   // Search for the first chunk with a greater line offset, the previous chunk
1674   // is the one that contains `pos`.
1675   auto it = llvm::upper_bound(
1676       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1677         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1678       });
1679   ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1680   pos.line -= chunkIt->lineOffset;
1681   return chunkIt;
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // PDLLServer::Impl
1686 //===----------------------------------------------------------------------===//
1687 
1688 struct lsp::PDLLServer::Impl {
1689   explicit Impl(const Options &options)
1690       : options(options), compilationDatabase(options.compilationDatabases) {}
1691 
1692   /// PDLL LSP options.
1693   const Options &options;
1694 
1695   /// The compilation database containing additional information for files
1696   /// passed to the server.
1697   lsp::CompilationDatabase compilationDatabase;
1698 
1699   /// The files held by the server, mapped by their URI file name.
1700   llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1701 };
1702 
1703 //===----------------------------------------------------------------------===//
1704 // PDLLServer
1705 //===----------------------------------------------------------------------===//
1706 
1707 lsp::PDLLServer::PDLLServer(const Options &options)
1708     : impl(std::make_unique<Impl>(options)) {}
1709 lsp::PDLLServer::~PDLLServer() = default;
1710 
1711 void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1712                                   int64_t version,
1713                                   std::vector<Diagnostic> &diagnostics) {
1714   // Build the set of additional include directories.
1715   std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1716   const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
1717   llvm::append_range(additionalIncludeDirs, fileInfo.includeDirs);
1718 
1719   impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1720       uri, contents, version, additionalIncludeDirs, diagnostics);
1721 }
1722 
1723 void lsp::PDLLServer::updateDocument(
1724     const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1725     int64_t version, std::vector<Diagnostic> &diagnostics) {
1726   // Check that we actually have a document for this uri.
1727   auto it = impl->files.find(uri.file());
1728   if (it == impl->files.end())
1729     return;
1730 
1731   // Try to update the document. If we fail, erase the file from the server. A
1732   // failed updated generally means we've fallen out of sync somewhere.
1733   if (failed(it->second->update(uri, version, changes, diagnostics)))
1734     impl->files.erase(it);
1735 }
1736 
1737 std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1738   auto it = impl->files.find(uri.file());
1739   if (it == impl->files.end())
1740     return std::nullopt;
1741 
1742   int64_t version = it->second->getVersion();
1743   impl->files.erase(it);
1744   return version;
1745 }
1746 
1747 void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1748                                      const Position &defPos,
1749                                      std::vector<Location> &locations) {
1750   auto fileIt = impl->files.find(uri.file());
1751   if (fileIt != impl->files.end())
1752     fileIt->second->getLocationsOf(uri, defPos, locations);
1753 }
1754 
1755 void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1756                                        const Position &pos,
1757                                        std::vector<Location> &references) {
1758   auto fileIt = impl->files.find(uri.file());
1759   if (fileIt != impl->files.end())
1760     fileIt->second->findReferencesOf(uri, pos, references);
1761 }
1762 
1763 void lsp::PDLLServer::getDocumentLinks(
1764     const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1765   auto fileIt = impl->files.find(uri.file());
1766   if (fileIt != impl->files.end())
1767     return fileIt->second->getDocumentLinks(uri, documentLinks);
1768 }
1769 
1770 std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1771                                                      const Position &hoverPos) {
1772   auto fileIt = impl->files.find(uri.file());
1773   if (fileIt != impl->files.end())
1774     return fileIt->second->findHover(uri, hoverPos);
1775   return std::nullopt;
1776 }
1777 
1778 void lsp::PDLLServer::findDocumentSymbols(
1779     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1780   auto fileIt = impl->files.find(uri.file());
1781   if (fileIt != impl->files.end())
1782     fileIt->second->findDocumentSymbols(symbols);
1783 }
1784 
1785 lsp::CompletionList
1786 lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1787                                    const Position &completePos) {
1788   auto fileIt = impl->files.find(uri.file());
1789   if (fileIt != impl->files.end())
1790     return fileIt->second->getCodeCompletion(uri, completePos);
1791   return CompletionList();
1792 }
1793 
1794 lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1795                                                      const Position &helpPos) {
1796   auto fileIt = impl->files.find(uri.file());
1797   if (fileIt != impl->files.end())
1798     return fileIt->second->getSignatureHelp(uri, helpPos);
1799   return SignatureHelp();
1800 }
1801 
1802 void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1803                                     std::vector<InlayHint> &inlayHints) {
1804   auto fileIt = impl->files.find(uri.file());
1805   if (fileIt == impl->files.end())
1806     return;
1807   fileIt->second->getInlayHints(uri, range, inlayHints);
1808 
1809   // Drop any duplicated hints that may have cropped up.
1810   llvm::sort(inlayHints);
1811   inlayHints.erase(llvm::unique(inlayHints), inlayHints.end());
1812 }
1813 
1814 std::optional<lsp::PDLLViewOutputResult>
1815 lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1816                                    PDLLViewOutputKind kind) {
1817   auto fileIt = impl->files.find(uri.file());
1818   if (fileIt != impl->files.end())
1819     return fileIt->second->getPDLLViewOutput(kind);
1820   return std::nullopt;
1821 }
1822