xref: /llvm-project/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp (revision e411c88b7223d87520af819fdc012c9dbb46e575)
1 //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "MLIRServer.h"
10 #include "Protocol.h"
11 #include "mlir/AsmParser/AsmParser.h"
12 #include "mlir/AsmParser/AsmParserState.h"
13 #include "mlir/AsmParser/CodeComplete.h"
14 #include "mlir/Bytecode/BytecodeWriter.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Interfaces/FunctionInterfaces.h"
17 #include "mlir/Parser/Parser.h"
18 #include "mlir/Support/ToolUtilities.h"
19 #include "mlir/Tools/lsp-server-support/Logging.h"
20 #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Base64.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include <optional>
25 
26 using namespace mlir;
27 
28 /// Returns the range of a lexical token given a SMLoc corresponding to the
29 /// start of an token location. The range is computed heuristically, and
30 /// supports identifier-like tokens, strings, etc.
convertTokenLocToRange(SMLoc loc)31 static SMRange convertTokenLocToRange(SMLoc loc) {
32   return lsp::convertTokenLocToRange(loc, "$-.");
33 }
34 
35 /// Returns a language server location from the given MLIR file location.
36 /// `uriScheme` is the scheme to use when building new uris.
getLocationFromLoc(StringRef uriScheme,FileLineColLoc loc)37 static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38                                                        FileLineColLoc loc) {
39   llvm::Expected<lsp::URIForFile> sourceURI =
40       lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41   if (!sourceURI) {
42     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43                        loc.getFilename(),
44                        llvm::toString(sourceURI.takeError()));
45     return std::nullopt;
46   }
47 
48   lsp::Position position;
49   position.line = loc.getLine() - 1;
50   position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51   return lsp::Location{*sourceURI, lsp::Range(position)};
52 }
53 
54 /// Returns a language server location from the given MLIR location, or
55 /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56 /// when building new uris. `uri` is an optional additional filter that, when
57 /// present, is used to filter sub locations that do not share the same uri.
58 static std::optional<lsp::Location>
getLocationFromLoc(llvm::SourceMgr & sourceMgr,Location loc,StringRef uriScheme,const lsp::URIForFile * uri=nullptr)59 getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60                    StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61   std::optional<lsp::Location> location;
62   loc->walk([&](Location nestedLoc) {
63     FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64     if (!fileLoc)
65       return WalkResult::advance();
66 
67     std::optional<lsp::Location> sourceLoc =
68         getLocationFromLoc(uriScheme, fileLoc);
69     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70       location = *sourceLoc;
71       SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73 
74       // Use range of potential identifier starting at location, else length 1
75       // range.
76       location->range.end.character += 1;
77       if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78         auto lineCol = sourceMgr.getLineAndColumn(range->End);
79         location->range.end.character =
80             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81       }
82       return WalkResult::interrupt();
83     }
84     return WalkResult::advance();
85   });
86   return location;
87 }
88 
89 /// Collect all of the locations from the given MLIR location that are not
90 /// contained within the given URI.
collectLocationsFromLoc(Location loc,std::vector<lsp::Location> & locations,const lsp::URIForFile & uri)91 static void collectLocationsFromLoc(Location loc,
92                                     std::vector<lsp::Location> &locations,
93                                     const lsp::URIForFile &uri) {
94   SetVector<Location> visitedLocs;
95   loc->walk([&](Location nestedLoc) {
96     FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97     if (!fileLoc || !visitedLocs.insert(nestedLoc))
98       return WalkResult::advance();
99 
100     std::optional<lsp::Location> sourceLoc =
101         getLocationFromLoc(uri.scheme(), fileLoc);
102     if (sourceLoc && sourceLoc->uri != uri)
103       locations.push_back(*sourceLoc);
104     return WalkResult::advance();
105   });
106 }
107 
108 /// Returns true if the given range contains the given source location. Note
109 /// that this has slightly different behavior than SMRange because it is
110 /// inclusive of the end location.
contains(SMRange range,SMLoc loc)111 static bool contains(SMRange range, SMLoc loc) {
112   return range.Start.getPointer() <= loc.getPointer() &&
113          loc.getPointer() <= range.End.getPointer();
114 }
115 
116 /// Returns true if the given location is contained by the definition or one of
117 /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118 /// the range within `def` that the provided `loc` overlapped with.
isDefOrUse(const AsmParserState::SMDefinition & def,SMLoc loc,SMRange * overlappedRange=nullptr)119 static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120                        SMRange *overlappedRange = nullptr) {
121   // Check the main definition.
122   if (contains(def.loc, loc)) {
123     if (overlappedRange)
124       *overlappedRange = def.loc;
125     return true;
126   }
127 
128   // Check the uses.
129   const auto *useIt = llvm::find_if(
130       def.uses, [&](const SMRange &range) { return contains(range, loc); });
131   if (useIt != def.uses.end()) {
132     if (overlappedRange)
133       *overlappedRange = *useIt;
134     return true;
135   }
136   return false;
137 }
138 
139 /// Given a location pointing to a result, return the result number it refers
140 /// to or std::nullopt if it refers to all of the results.
getResultNumberFromLoc(SMLoc loc)141 static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142   // Skip all of the identifier characters.
143   auto isIdentifierChar = [](char c) {
144     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145            c == '-';
146   };
147   const char *curPtr = loc.getPointer();
148   while (isIdentifierChar(*curPtr))
149     ++curPtr;
150 
151   // Check to see if this location indexes into the result group, via `#`. If it
152   // doesn't, we can't extract a sub result number.
153   if (*curPtr != '#')
154     return std::nullopt;
155 
156   // Compute the sub result number from the remaining portion of the string.
157   const char *numberStart = ++curPtr;
158   while (llvm::isDigit(*curPtr))
159     ++curPtr;
160   StringRef numberStr(numberStart, curPtr - numberStart);
161   unsigned resultNumber = 0;
162   return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
163                                                     : resultNumber;
164 }
165 
166 /// Given a source location range, return the text covered by the given range.
167 /// If the range is invalid, returns std::nullopt.
getTextFromRange(SMRange range)168 static std::optional<StringRef> getTextFromRange(SMRange range) {
169   if (!range.isValid())
170     return std::nullopt;
171   const char *startPtr = range.Start.getPointer();
172   return StringRef(startPtr, range.End.getPointer() - startPtr);
173 }
174 
175 /// Given a block, return its position in its parent region.
getBlockNumber(Block * block)176 static unsigned getBlockNumber(Block *block) {
177   return std::distance(block->getParent()->begin(), block->getIterator());
178 }
179 
180 /// Given a block and source location, print the source name of the block to the
181 /// given output stream.
printDefBlockName(raw_ostream & os,Block * block,SMRange loc={})182 static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183   // Try to extract a name from the source location.
184   std::optional<StringRef> text = getTextFromRange(loc);
185   if (text && text->starts_with("^")) {
186     os << *text;
187     return;
188   }
189 
190   // Otherwise, we don't have a name so print the block number.
191   os << "<Block #" << getBlockNumber(block) << ">";
192 }
printDefBlockName(raw_ostream & os,const AsmParserState::BlockDefinition & def)193 static void printDefBlockName(raw_ostream &os,
194                               const AsmParserState::BlockDefinition &def) {
195   printDefBlockName(os, def.block, def.definition.loc);
196 }
197 
198 /// Convert the given MLIR diagnostic to the LSP form.
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,Diagnostic & diag,const lsp::URIForFile & uri)199 static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200                                                Diagnostic &diag,
201                                                const lsp::URIForFile &uri) {
202   lsp::Diagnostic lspDiag;
203   lspDiag.source = "mlir";
204 
205   // Note: Right now all of the diagnostics are treated as parser issues, but
206   // some are parser and some are verifier.
207   lspDiag.category = "Parse Error";
208 
209   // Try to grab a file location for this diagnostic.
210   // TODO: For simplicity, we just grab the first one. It may be likely that we
211   // will need a more interesting heuristic here.'
212   StringRef uriScheme = uri.scheme();
213   std::optional<lsp::Location> lspLocation =
214       getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215   if (lspLocation)
216     lspDiag.range = lspLocation->range;
217 
218   // Convert the severity for the diagnostic.
219   switch (diag.getSeverity()) {
220   case DiagnosticSeverity::Note:
221     llvm_unreachable("expected notes to be handled separately");
222   case DiagnosticSeverity::Warning:
223     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
224     break;
225   case DiagnosticSeverity::Error:
226     lspDiag.severity = lsp::DiagnosticSeverity::Error;
227     break;
228   case DiagnosticSeverity::Remark:
229     lspDiag.severity = lsp::DiagnosticSeverity::Information;
230     break;
231   }
232   lspDiag.message = diag.str();
233 
234   // Attach any notes to the main diagnostic as related information.
235   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
236   for (Diagnostic &note : diag.getNotes()) {
237     lsp::Location noteLoc;
238     if (std::optional<lsp::Location> loc =
239             getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240       noteLoc = *loc;
241     else
242       noteLoc.uri = uri;
243     relatedDiags.emplace_back(noteLoc, note.str());
244   }
245   if (!relatedDiags.empty())
246     lspDiag.relatedInformation = std::move(relatedDiags);
247 
248   return lspDiag;
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // MLIRDocument
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 /// This class represents all of the information pertaining to a specific MLIR
257 /// document.
258 struct MLIRDocument {
259   MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
260                StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261   MLIRDocument(const MLIRDocument &) = delete;
262   MLIRDocument &operator=(const MLIRDocument &) = delete;
263 
264   //===--------------------------------------------------------------------===//
265   // Definitions and References
266   //===--------------------------------------------------------------------===//
267 
268   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269                       std::vector<lsp::Location> &locations);
270   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271                         std::vector<lsp::Location> &references);
272 
273   //===--------------------------------------------------------------------===//
274   // Hover
275   //===--------------------------------------------------------------------===//
276 
277   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
278                                       const lsp::Position &hoverPos);
279   std::optional<lsp::Hover>
280   buildHoverForOperation(SMRange hoverRange,
281                          const AsmParserState::OperationDefinition &op);
282   lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283                                           unsigned resultStart,
284                                           unsigned resultEnd, SMLoc posLoc);
285   lsp::Hover buildHoverForBlock(SMRange hoverRange,
286                                 const AsmParserState::BlockDefinition &block);
287   lsp::Hover
288   buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289                              const AsmParserState::BlockDefinition &block);
290 
291   lsp::Hover buildHoverForAttributeAlias(
292       SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
293   lsp::Hover
294   buildHoverForTypeAlias(SMRange hoverRange,
295                          const AsmParserState::TypeAliasDefinition &type);
296 
297   //===--------------------------------------------------------------------===//
298   // Document Symbols
299   //===--------------------------------------------------------------------===//
300 
301   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302   void findDocumentSymbols(Operation *op,
303                            std::vector<lsp::DocumentSymbol> &symbols);
304 
305   //===--------------------------------------------------------------------===//
306   // Code Completion
307   //===--------------------------------------------------------------------===//
308 
309   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310                                         const lsp::Position &completePos,
311                                         const DialectRegistry &registry);
312 
313   //===--------------------------------------------------------------------===//
314   // Code Action
315   //===--------------------------------------------------------------------===//
316 
317   void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318                                   lsp::Position &pos, StringRef severity,
319                                   StringRef message,
320                                   std::vector<lsp::TextEdit> &edits);
321 
322   //===--------------------------------------------------------------------===//
323   // Bytecode
324   //===--------------------------------------------------------------------===//
325 
326   llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
327 
328   //===--------------------------------------------------------------------===//
329   // Fields
330   //===--------------------------------------------------------------------===//
331 
332   /// The high level parser state used to find definitions and references within
333   /// the source file.
334   AsmParserState asmState;
335 
336   /// The container for the IR parsed from the input file.
337   Block parsedIR;
338 
339   /// A collection of external resources, which we want to propagate up to the
340   /// user.
341   FallbackAsmResourceMap fallbackResourceMap;
342 
343   /// The source manager containing the contents of the input file.
344   llvm::SourceMgr sourceMgr;
345 };
346 } // namespace
347 
MLIRDocument(MLIRContext & context,const lsp::URIForFile & uri,StringRef contents,std::vector<lsp::Diagnostic> & diagnostics)348 MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349                            StringRef contents,
350                            std::vector<lsp::Diagnostic> &diagnostics) {
351   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353   });
354 
355   // Try to parsed the given IR string.
356   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357   if (!memBuffer) {
358     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
359     return;
360   }
361 
362   ParserConfig config(&context, /*verifyAfterParse=*/true,
363                       &fallbackResourceMap);
364   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
365   if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
366     // If parsing failed, clear out any of the current state.
367     parsedIR.clear();
368     asmState = AsmParserState();
369     fallbackResourceMap = FallbackAsmResourceMap();
370     return;
371   }
372 }
373 
374 //===----------------------------------------------------------------------===//
375 // MLIRDocument: Definitions and References
376 //===----------------------------------------------------------------------===//
377 
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)378 void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
379                                   const lsp::Position &defPos,
380                                   std::vector<lsp::Location> &locations) {
381   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
382 
383   // Functor used to check if an SM definition contains the position.
384   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
385     if (!isDefOrUse(def, posLoc))
386       return false;
387     locations.emplace_back(uri, sourceMgr, def.loc);
388     return true;
389   };
390 
391   // Check all definitions related to operations.
392   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393     if (contains(op.loc, posLoc))
394       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
395     for (const auto &result : op.resultGroups)
396       if (containsPosition(result.definition))
397         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
398     for (const auto &symUse : op.symbolUses) {
399       if (contains(symUse, posLoc)) {
400         locations.emplace_back(uri, sourceMgr, op.loc);
401         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
402       }
403     }
404   }
405 
406   // Check all definitions related to blocks.
407   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
408     if (containsPosition(block.definition))
409       return;
410     for (const AsmParserState::SMDefinition &arg : block.arguments)
411       if (containsPosition(arg))
412         return;
413   }
414 
415   // Check all alias definitions.
416   for (const AsmParserState::AttributeAliasDefinition &attr :
417        asmState.getAttributeAliasDefs()) {
418     if (containsPosition(attr.definition))
419       return;
420   }
421   for (const AsmParserState::TypeAliasDefinition &type :
422        asmState.getTypeAliasDefs()) {
423     if (containsPosition(type.definition))
424       return;
425   }
426 }
427 
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)428 void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
429                                     const lsp::Position &pos,
430                                     std::vector<lsp::Location> &references) {
431   // Functor used to append all of the definitions/uses of the given SM
432   // definition to the reference list.
433   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
434     references.emplace_back(uri, sourceMgr, def.loc);
435     for (const SMRange &use : def.uses)
436       references.emplace_back(uri, sourceMgr, use);
437   };
438 
439   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
440 
441   // Check all definitions related to operations.
442   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
443     if (contains(op.loc, posLoc)) {
444       for (const auto &result : op.resultGroups)
445         appendSMDef(result.definition);
446       for (const auto &symUse : op.symbolUses)
447         if (contains(symUse, posLoc))
448           references.emplace_back(uri, sourceMgr, symUse);
449       return;
450     }
451     for (const auto &result : op.resultGroups)
452       if (isDefOrUse(result.definition, posLoc))
453         return appendSMDef(result.definition);
454     for (const auto &symUse : op.symbolUses) {
455       if (!contains(symUse, posLoc))
456         continue;
457       for (const auto &symUse : op.symbolUses)
458         references.emplace_back(uri, sourceMgr, symUse);
459       return;
460     }
461   }
462 
463   // Check all definitions related to blocks.
464   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
465     if (isDefOrUse(block.definition, posLoc))
466       return appendSMDef(block.definition);
467 
468     for (const AsmParserState::SMDefinition &arg : block.arguments)
469       if (isDefOrUse(arg, posLoc))
470         return appendSMDef(arg);
471   }
472 
473   // Check all alias definitions.
474   for (const AsmParserState::AttributeAliasDefinition &attr :
475        asmState.getAttributeAliasDefs()) {
476     if (isDefOrUse(attr.definition, posLoc))
477       return appendSMDef(attr.definition);
478   }
479   for (const AsmParserState::TypeAliasDefinition &type :
480        asmState.getTypeAliasDefs()) {
481     if (isDefOrUse(type.definition, posLoc))
482       return appendSMDef(type.definition);
483   }
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // MLIRDocument: Hover
488 //===----------------------------------------------------------------------===//
489 
490 std::optional<lsp::Hover>
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)491 MLIRDocument::findHover(const lsp::URIForFile &uri,
492                         const lsp::Position &hoverPos) {
493   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
494   SMRange hoverRange;
495 
496   // Check for Hovers on operations and results.
497   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
498     // Check if the position points at this operation.
499     if (contains(op.loc, posLoc))
500       return buildHoverForOperation(op.loc, op);
501 
502     // Check if the position points at the symbol name.
503     for (auto &use : op.symbolUses)
504       if (contains(use, posLoc))
505         return buildHoverForOperation(use, op);
506 
507     // Check if the position points at a result group.
508     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
509       const auto &result = op.resultGroups[i];
510       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
511         continue;
512 
513       // Get the range of results covered by the over position.
514       unsigned resultStart = result.startIndex;
515       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
516                                         : op.resultGroups[i + 1].startIndex;
517       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
518                                           resultEnd, posLoc);
519     }
520   }
521 
522   // Check to see if the hover is over a block argument.
523   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
524     if (isDefOrUse(block.definition, posLoc, &hoverRange))
525       return buildHoverForBlock(hoverRange, block);
526 
527     for (const auto &arg : llvm::enumerate(block.arguments)) {
528       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
529         continue;
530 
531       return buildHoverForBlockArgument(
532           hoverRange, block.block->getArgument(arg.index()), block);
533     }
534   }
535 
536   // Check to see if the hover is over an alias.
537   for (const AsmParserState::AttributeAliasDefinition &attr :
538        asmState.getAttributeAliasDefs()) {
539     if (isDefOrUse(attr.definition, posLoc, &hoverRange))
540       return buildHoverForAttributeAlias(hoverRange, attr);
541   }
542   for (const AsmParserState::TypeAliasDefinition &type :
543        asmState.getTypeAliasDefs()) {
544     if (isDefOrUse(type.definition, posLoc, &hoverRange))
545       return buildHoverForTypeAlias(hoverRange, type);
546   }
547 
548   return std::nullopt;
549 }
550 
buildHoverForOperation(SMRange hoverRange,const AsmParserState::OperationDefinition & op)551 std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
552     SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
553   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
554   llvm::raw_string_ostream os(hover.contents.value);
555 
556   // Add the operation name to the hover.
557   os << "\"" << op.op->getName() << "\"";
558   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
559     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
560   os << "\n\n";
561 
562   os << "Generic Form:\n\n```mlir\n";
563 
564   op.op->print(os, OpPrintingFlags()
565                        .printGenericOpForm()
566                        .elideLargeElementsAttrs()
567                        .skipRegions());
568   os << "\n```\n";
569 
570   return hover;
571 }
572 
buildHoverForOperationResult(SMRange hoverRange,Operation * op,unsigned resultStart,unsigned resultEnd,SMLoc posLoc)573 lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
574                                                       Operation *op,
575                                                       unsigned resultStart,
576                                                       unsigned resultEnd,
577                                                       SMLoc posLoc) {
578   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579   llvm::raw_string_ostream os(hover.contents.value);
580 
581   // Add the parent operation name to the hover.
582   os << "Operation: \"" << op->getName() << "\"\n\n";
583 
584   // Check to see if the location points to a specific result within the
585   // group.
586   if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
587     if ((resultStart + *resultNumber) < resultEnd) {
588       resultStart += *resultNumber;
589       resultEnd = resultStart + 1;
590     }
591   }
592 
593   // Add the range of results and their types to the hover info.
594   if ((resultStart + 1) == resultEnd) {
595     os << "Result #" << resultStart << "\n\n"
596        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
597   } else {
598     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
599        << "Types: ";
600     llvm::interleaveComma(
601         op->getResults().slice(resultStart, resultEnd), os,
602         [&](Value result) { os << "`" << result.getType() << "`"; });
603   }
604 
605   return hover;
606 }
607 
608 lsp::Hover
buildHoverForBlock(SMRange hoverRange,const AsmParserState::BlockDefinition & block)609 MLIRDocument::buildHoverForBlock(SMRange hoverRange,
610                                  const AsmParserState::BlockDefinition &block) {
611   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
612   llvm::raw_string_ostream os(hover.contents.value);
613 
614   // Print the given block to the hover output stream.
615   auto printBlockToHover = [&](Block *newBlock) {
616     if (const auto *def = asmState.getBlockDef(newBlock))
617       printDefBlockName(os, *def);
618     else
619       printDefBlockName(os, newBlock);
620   };
621 
622   // Display the parent operation, block number, predecessors, and successors.
623   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
624      << "Block #" << getBlockNumber(block.block) << "\n\n";
625   if (!block.block->hasNoPredecessors()) {
626     os << "Predecessors: ";
627     llvm::interleaveComma(block.block->getPredecessors(), os,
628                           printBlockToHover);
629     os << "\n\n";
630   }
631   if (!block.block->hasNoSuccessors()) {
632     os << "Successors: ";
633     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
634     os << "\n\n";
635   }
636 
637   return hover;
638 }
639 
buildHoverForBlockArgument(SMRange hoverRange,BlockArgument arg,const AsmParserState::BlockDefinition & block)640 lsp::Hover MLIRDocument::buildHoverForBlockArgument(
641     SMRange hoverRange, BlockArgument arg,
642     const AsmParserState::BlockDefinition &block) {
643   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
644   llvm::raw_string_ostream os(hover.contents.value);
645 
646   // Display the parent operation, block, the argument number, and the type.
647   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
648      << "Block: ";
649   printDefBlockName(os, block);
650   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
651      << "Type: `" << arg.getType() << "`\n\n";
652 
653   return hover;
654 }
655 
buildHoverForAttributeAlias(SMRange hoverRange,const AsmParserState::AttributeAliasDefinition & attr)656 lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
657     SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
658   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
659   llvm::raw_string_ostream os(hover.contents.value);
660 
661   os << "Attribute Alias: \"" << attr.name << "\n\n";
662   os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
663 
664   return hover;
665 }
666 
buildHoverForTypeAlias(SMRange hoverRange,const AsmParserState::TypeAliasDefinition & type)667 lsp::Hover MLIRDocument::buildHoverForTypeAlias(
668     SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
669   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
670   llvm::raw_string_ostream os(hover.contents.value);
671 
672   os << "Type Alias: \"" << type.name << "\n\n";
673   os << "Value: ```mlir\n" << type.value << "\n```\n\n";
674 
675   return hover;
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // MLIRDocument: Document Symbols
680 //===----------------------------------------------------------------------===//
681 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)682 void MLIRDocument::findDocumentSymbols(
683     std::vector<lsp::DocumentSymbol> &symbols) {
684   for (Operation &op : parsedIR)
685     findDocumentSymbols(&op, symbols);
686 }
687 
findDocumentSymbols(Operation * op,std::vector<lsp::DocumentSymbol> & symbols)688 void MLIRDocument::findDocumentSymbols(
689     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
690   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
691 
692   // Check for the source information of this operation.
693   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
694     // If this operation defines a symbol, record it.
695     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
696       symbols.emplace_back(symbol.getName(),
697                            isa<FunctionOpInterface>(op)
698                                ? lsp::SymbolKind::Function
699                                : lsp::SymbolKind::Class,
700                            lsp::Range(sourceMgr, def->scopeLoc),
701                            lsp::Range(sourceMgr, def->loc));
702       childSymbols = &symbols.back().children;
703 
704     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
705       // Otherwise, if this is a symbol table push an anonymous document symbol.
706       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
707                            lsp::SymbolKind::Namespace,
708                            lsp::Range(sourceMgr, def->scopeLoc),
709                            lsp::Range(sourceMgr, def->loc));
710       childSymbols = &symbols.back().children;
711     }
712   }
713 
714   // Recurse into the regions of this operation.
715   if (!op->getNumRegions())
716     return;
717   for (Region &region : op->getRegions())
718     for (Operation &childOp : region.getOps())
719       findDocumentSymbols(&childOp, *childSymbols);
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // MLIRDocument: Code Completion
724 //===----------------------------------------------------------------------===//
725 
726 namespace {
727 class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
728 public:
LSPCodeCompleteContext(SMLoc completeLoc,lsp::CompletionList & completionList,MLIRContext * ctx)729   LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
730                          MLIRContext *ctx)
731       : AsmParserCodeCompleteContext(completeLoc),
732         completionList(completionList), ctx(ctx) {}
733 
734   /// Signal code completion for a dialect name, with an optional prefix.
completeDialectName(StringRef prefix)735   void completeDialectName(StringRef prefix) final {
736     for (StringRef dialect : ctx->getAvailableDialects()) {
737       lsp::CompletionItem item(prefix + dialect,
738                                lsp::CompletionItemKind::Module,
739                                /*sortText=*/"3");
740       item.detail = "dialect";
741       completionList.items.emplace_back(item);
742     }
743   }
744   using AsmParserCodeCompleteContext::completeDialectName;
745 
746   /// Signal code completion for an operation name within the given dialect.
completeOperationName(StringRef dialectName)747   void completeOperationName(StringRef dialectName) final {
748     Dialect *dialect = ctx->getOrLoadDialect(dialectName);
749     if (!dialect)
750       return;
751 
752     for (const auto &op : ctx->getRegisteredOperations()) {
753       if (&op.getDialect() != dialect)
754         continue;
755 
756       lsp::CompletionItem item(
757           op.getStringRef().drop_front(dialectName.size() + 1),
758           lsp::CompletionItemKind::Field,
759           /*sortText=*/"1");
760       item.detail = "operation";
761       completionList.items.emplace_back(item);
762     }
763   }
764 
765   /// Append the given SSA value as a code completion result for SSA value
766   /// completions.
appendSSAValueCompletion(StringRef name,std::string typeData)767   void appendSSAValueCompletion(StringRef name, std::string typeData) final {
768     // Check if we need to insert the `%` or not.
769     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
770 
771     lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
772     if (stripPrefix)
773       item.insertText = name.drop_front(1).str();
774     item.detail = std::move(typeData);
775     completionList.items.emplace_back(item);
776   }
777 
778   /// Append the given block as a code completion result for block name
779   /// completions.
appendBlockCompletion(StringRef name)780   void appendBlockCompletion(StringRef name) final {
781     // Check if we need to insert the `^` or not.
782     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
783 
784     lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
785     if (stripPrefix)
786       item.insertText = name.drop_front(1).str();
787     completionList.items.emplace_back(item);
788   }
789 
790   /// Signal a completion for the given expected token.
completeExpectedTokens(ArrayRef<StringRef> tokens,bool optional)791   void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
792     for (StringRef token : tokens) {
793       lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
794                                /*sortText=*/"0");
795       item.detail = optional ? "optional" : "";
796       completionList.items.emplace_back(item);
797     }
798   }
799 
800   /// Signal a completion for an attribute.
completeAttribute(const llvm::StringMap<Attribute> & aliases)801   void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
802     appendSimpleCompletions({"affine_set", "affine_map", "dense",
803                              "dense_resource", "false", "loc", "sparse", "true",
804                              "unit"},
805                             lsp::CompletionItemKind::Field,
806                             /*sortText=*/"1");
807 
808     completeDialectName("#");
809     completeAliases(aliases, "#");
810   }
completeDialectAttributeOrAlias(const llvm::StringMap<Attribute> & aliases)811   void completeDialectAttributeOrAlias(
812       const llvm::StringMap<Attribute> &aliases) override {
813     completeDialectName();
814     completeAliases(aliases);
815   }
816 
817   /// Signal a completion for a type.
completeType(const llvm::StringMap<Type> & aliases)818   void completeType(const llvm::StringMap<Type> &aliases) override {
819     // Handle the various builtin types.
820     appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
821                              "bf16", "f16", "f32", "f64", "f80", "f128",
822                              "index", "none"},
823                             lsp::CompletionItemKind::Field,
824                             /*sortText=*/"1");
825 
826     // Handle the builtin integer types.
827     for (StringRef type : {"i", "si", "ui"}) {
828       lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
829                                /*sortText=*/"1");
830       item.insertText = type.str();
831       completionList.items.emplace_back(item);
832     }
833 
834     // Insert completions for dialect types and aliases.
835     completeDialectName("!");
836     completeAliases(aliases, "!");
837   }
838   void
completeDialectTypeOrAlias(const llvm::StringMap<Type> & aliases)839   completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
840     completeDialectName();
841     completeAliases(aliases);
842   }
843 
844   /// Add completion results for the given set of aliases.
845   template <typename T>
completeAliases(const llvm::StringMap<T> & aliases,StringRef prefix="")846   void completeAliases(const llvm::StringMap<T> &aliases,
847                        StringRef prefix = "") {
848     for (const auto &alias : aliases) {
849       lsp::CompletionItem item(prefix + alias.getKey(),
850                                lsp::CompletionItemKind::Field,
851                                /*sortText=*/"2");
852       llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
853       completionList.items.emplace_back(item);
854     }
855   }
856 
857   /// Add a set of simple completions that all have the same kind.
appendSimpleCompletions(ArrayRef<StringRef> completions,lsp::CompletionItemKind kind,StringRef sortText="")858   void appendSimpleCompletions(ArrayRef<StringRef> completions,
859                                lsp::CompletionItemKind kind,
860                                StringRef sortText = "") {
861     for (StringRef completion : completions)
862       completionList.items.emplace_back(completion, kind, sortText);
863   }
864 
865 private:
866   lsp::CompletionList &completionList;
867   MLIRContext *ctx;
868 };
869 } // namespace
870 
871 lsp::CompletionList
getCodeCompletion(const lsp::URIForFile & uri,const lsp::Position & completePos,const DialectRegistry & registry)872 MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
873                                 const lsp::Position &completePos,
874                                 const DialectRegistry &registry) {
875   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
876   if (!posLoc.isValid())
877     return lsp::CompletionList();
878 
879   // To perform code completion, we run another parse of the module with the
880   // code completion context provided.
881   MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
882   tmpContext.allowUnregisteredDialects();
883   lsp::CompletionList completionList;
884   LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
885                                             &tmpContext);
886 
887   Block tmpIR;
888   AsmParserState tmpState;
889   (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
890                            &lspCompleteContext);
891   return completionList;
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // MLIRDocument: Code Action
896 //===----------------------------------------------------------------------===//
897 
getCodeActionForDiagnostic(const lsp::URIForFile & uri,lsp::Position & pos,StringRef severity,StringRef message,std::vector<lsp::TextEdit> & edits)898 void MLIRDocument::getCodeActionForDiagnostic(
899     const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
900     StringRef message, std::vector<lsp::TextEdit> &edits) {
901   // Ignore diagnostics that print the current operation. These are always
902   // enabled for the language server, but not generally during normal
903   // parsing/verification.
904   if (message.starts_with("see current operation: "))
905     return;
906 
907   // Get the start of the line containing the diagnostic.
908   const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
909   const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
910   if (!lineStart)
911     return;
912   StringRef line(lineStart, pos.character);
913 
914   // Add a text edit for adding an expected-* diagnostic check for this
915   // diagnostic.
916   lsp::TextEdit edit;
917   edit.range = lsp::Range(lsp::Position(pos.line, 0));
918 
919   // Use the indent of the current line for the expected-* diagnostic.
920   size_t indent = line.find_first_not_of(' ');
921   if (indent == StringRef::npos)
922     indent = line.size();
923 
924   edit.newText.append(indent, ' ');
925   llvm::raw_string_ostream(edit.newText)
926       << "// expected-" << severity << " @below {{" << message << "}}\n";
927   edits.emplace_back(std::move(edit));
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // MLIRDocument: Bytecode
932 //===----------------------------------------------------------------------===//
933 
934 llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode()935 MLIRDocument::convertToBytecode() {
936   // TODO: We currently require a single top-level operation, but this could
937   // conceptually be relaxed.
938   if (!llvm::hasSingleElement(parsedIR)) {
939     if (parsedIR.empty()) {
940       return llvm::make_error<lsp::LSPError>(
941           "expected a single and valid top-level operation, please ensure "
942           "there are no errors",
943           lsp::ErrorCode::RequestFailed);
944     }
945     return llvm::make_error<lsp::LSPError>(
946         "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
947   }
948 
949   lsp::MLIRConvertBytecodeResult result;
950   {
951     BytecodeWriterConfig writerConfig(fallbackResourceMap);
952 
953     std::string rawBytecodeBuffer;
954     llvm::raw_string_ostream os(rawBytecodeBuffer);
955     // No desired bytecode version set, so no need to check for error.
956     (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
957     result.output = llvm::encodeBase64(rawBytecodeBuffer);
958   }
959   return result;
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // MLIRTextFileChunk
964 //===----------------------------------------------------------------------===//
965 
966 namespace {
967 /// This class represents a single chunk of an MLIR text file.
968 struct MLIRTextFileChunk {
MLIRTextFileChunk__anon05235aca0c11::MLIRTextFileChunk969   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970                     const lsp::URIForFile &uri, StringRef contents,
971                     std::vector<lsp::Diagnostic> &diagnostics)
972       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
973 
974   /// Adjust the line number of the given range to anchor at the beginning of
975   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon05235aca0c11::MLIRTextFileChunk976   void adjustLocForChunkOffset(lsp::Range &range) {
977     adjustLocForChunkOffset(range.start);
978     adjustLocForChunkOffset(range.end);
979   }
980   /// Adjust the line number of the given position to anchor at the beginning of
981   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon05235aca0c11::MLIRTextFileChunk982   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
983 
984   /// The line offset of this chunk from the beginning of the file.
985   uint64_t lineOffset;
986   /// The document referred to by this chunk.
987   MLIRDocument document;
988 };
989 } // namespace
990 
991 //===----------------------------------------------------------------------===//
992 // MLIRTextFile
993 //===----------------------------------------------------------------------===//
994 
995 namespace {
996 /// This class represents a text file containing one or more MLIR documents.
997 class MLIRTextFile {
998 public:
999   MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000                int64_t version, DialectRegistry &registry,
1001                std::vector<lsp::Diagnostic> &diagnostics);
1002 
1003   /// Return the current version of this text file.
getVersion() const1004   int64_t getVersion() const { return version; }
1005 
1006   //===--------------------------------------------------------------------===//
1007   // LSP Queries
1008   //===--------------------------------------------------------------------===//
1009 
1010   void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1011                       std::vector<lsp::Location> &locations);
1012   void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1013                         std::vector<lsp::Location> &references);
1014   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1015                                       lsp::Position hoverPos);
1016   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018                                         lsp::Position completePos);
1019   void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020                       const lsp::CodeActionContext &context,
1021                       std::vector<lsp::CodeAction> &actions);
1022   llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1023 
1024 private:
1025   /// Find the MLIR document that contains the given position, and update the
1026   /// position to be anchored at the start of the found chunk instead of the
1027   /// beginning of the file.
1028   MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1029 
1030   /// The context used to hold the state contained by the parsed document.
1031   MLIRContext context;
1032 
1033   /// The full string contents of the file.
1034   std::string contents;
1035 
1036   /// The version of this file.
1037   int64_t version;
1038 
1039   /// The number of lines in the file.
1040   int64_t totalNumLines = 0;
1041 
1042   /// The chunks of this file. The order of these chunks is the order in which
1043   /// they appear in the text file.
1044   std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1045 };
1046 } // namespace
1047 
MLIRTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,DialectRegistry & registry,std::vector<lsp::Diagnostic> & diagnostics)1048 MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049                            int64_t version, DialectRegistry &registry,
1050                            std::vector<lsp::Diagnostic> &diagnostics)
1051     : context(registry, MLIRContext::Threading::DISABLED),
1052       contents(fileContents.str()), version(version) {
1053   context.allowUnregisteredDialects();
1054 
1055   // Split the file into separate MLIR documents.
1056   SmallVector<StringRef, 8> subContents;
1057   StringRef(contents).split(subContents, kDefaultSplitMarker);
1058   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
1059       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
1060 
1061   uint64_t lineOffset = subContents.front().count('\n');
1062   for (StringRef docContents : llvm::drop_begin(subContents)) {
1063     unsigned currentNumDiags = diagnostics.size();
1064     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
1065                                                      docContents, diagnostics);
1066     lineOffset += docContents.count('\n');
1067 
1068     // Adjust locations used in diagnostics to account for the offset from the
1069     // beginning of the file.
1070     for (lsp::Diagnostic &diag :
1071          llvm::drop_begin(diagnostics, currentNumDiags)) {
1072       chunk->adjustLocForChunkOffset(diag.range);
1073 
1074       if (!diag.relatedInformation)
1075         continue;
1076       for (auto &it : *diag.relatedInformation)
1077         if (it.location.uri == uri)
1078           chunk->adjustLocForChunkOffset(it.location.range);
1079     }
1080     chunks.emplace_back(std::move(chunk));
1081   }
1082   totalNumLines = lineOffset;
1083 }
1084 
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)1085 void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1086                                   lsp::Position defPos,
1087                                   std::vector<lsp::Location> &locations) {
1088   MLIRTextFileChunk &chunk = getChunkFor(defPos);
1089   chunk.document.getLocationsOf(uri, defPos, locations);
1090 
1091   // Adjust any locations within this file for the offset of this chunk.
1092   if (chunk.lineOffset == 0)
1093     return;
1094   for (lsp::Location &loc : locations)
1095     if (loc.uri == uri)
1096       chunk.adjustLocForChunkOffset(loc.range);
1097 }
1098 
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)1099 void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1100                                     lsp::Position pos,
1101                                     std::vector<lsp::Location> &references) {
1102   MLIRTextFileChunk &chunk = getChunkFor(pos);
1103   chunk.document.findReferencesOf(uri, pos, references);
1104 
1105   // Adjust any locations within this file for the offset of this chunk.
1106   if (chunk.lineOffset == 0)
1107     return;
1108   for (lsp::Location &loc : references)
1109     if (loc.uri == uri)
1110       chunk.adjustLocForChunkOffset(loc.range);
1111 }
1112 
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)1113 std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1114                                                   lsp::Position hoverPos) {
1115   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
1116   std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1117 
1118   // Adjust any locations within this file for the offset of this chunk.
1119   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120     chunk.adjustLocForChunkOffset(*hoverInfo->range);
1121   return hoverInfo;
1122 }
1123 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)1124 void MLIRTextFile::findDocumentSymbols(
1125     std::vector<lsp::DocumentSymbol> &symbols) {
1126   if (chunks.size() == 1)
1127     return chunks.front()->document.findDocumentSymbols(symbols);
1128 
1129   // If there are multiple chunks in this file, we create top-level symbols for
1130   // each chunk.
1131   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132     MLIRTextFileChunk &chunk = *chunks[i];
1133     lsp::Position startPos(chunk.lineOffset);
1134     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135                                       : chunks[i + 1]->lineOffset);
1136     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137                                lsp::SymbolKind::Namespace,
1138                                /*range=*/lsp::Range(startPos, endPos),
1139                                /*selectionRange=*/lsp::Range(startPos));
1140     chunk.document.findDocumentSymbols(symbol.children);
1141 
1142     // Fixup the locations of document symbols within this chunk.
1143     if (i != 0) {
1144       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145       for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146         symbolsToFix.push_back(&childSymbol);
1147 
1148       while (!symbolsToFix.empty()) {
1149         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150         chunk.adjustLocForChunkOffset(symbol->range);
1151         chunk.adjustLocForChunkOffset(symbol->selectionRange);
1152 
1153         for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154           symbolsToFix.push_back(&childSymbol);
1155       }
1156     }
1157 
1158     // Push the symbol for this chunk.
1159     symbols.emplace_back(std::move(symbol));
1160   }
1161 }
1162 
/// Return completion candidates at the file-relative `completePos`, with any
/// edit ranges they carry rebased onto file lines.
lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
                                                    lsp::Position completePos) {
  MLIRTextFileChunk &owningChunk = getChunkFor(completePos);
  const DialectRegistry &registry = context.getDialectRegistry();
  lsp::CompletionList result =
      owningChunk.document.getCodeCompletion(uri, completePos, registry);

  // Edit ranges attached to the completions are chunk-relative.
  for (lsp::CompletionItem &candidate : result.items) {
    if (candidate.textEdit)
      owningChunk.adjustLocForChunkOffset(candidate.textEdit->range);
    for (lsp::TextEdit &extraEdit : candidate.additionalTextEdits)
      owningChunk.adjustLocForChunkOffset(extraEdit.range);
  }
  return result;
}
1178 
getCodeActions(const lsp::URIForFile & uri,const lsp::Range & pos,const lsp::CodeActionContext & context,std::vector<lsp::CodeAction> & actions)1179 void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180                                   const lsp::Range &pos,
1181                                   const lsp::CodeActionContext &context,
1182                                   std::vector<lsp::CodeAction> &actions) {
1183   // Create actions for any diagnostics in this file.
1184   for (auto &diag : context.diagnostics) {
1185     if (diag.source != "mlir")
1186       continue;
1187     lsp::Position diagPos = diag.range.start;
1188     MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1189 
1190     // Add a new code action that inserts a "expected" diagnostic check.
1191     lsp::CodeAction action;
1192     action.title = "Add expected-* diagnostic checks";
1193     action.kind = lsp::CodeAction::kQuickFix.str();
1194 
1195     StringRef severity;
1196     switch (diag.severity) {
1197     case lsp::DiagnosticSeverity::Error:
1198       severity = "error";
1199       break;
1200     case lsp::DiagnosticSeverity::Warning:
1201       severity = "warning";
1202       break;
1203     default:
1204       continue;
1205     }
1206 
1207     // Get edits for the diagnostic.
1208     std::vector<lsp::TextEdit> edits;
1209     chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210                                               diag.message, edits);
1211 
1212     // Walk the related diagnostics, this is how we encode notes.
1213     if (diag.relatedInformation) {
1214       for (auto &noteDiag : *diag.relatedInformation) {
1215         if (noteDiag.location.uri != uri)
1216           continue;
1217         diagPos = noteDiag.location.range.start;
1218         diagPos.line -= chunk.lineOffset;
1219         chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1220                                                   noteDiag.message, edits);
1221       }
1222     }
1223     // Fixup the locations for any edits.
1224     for (lsp::TextEdit &edit : edits)
1225       chunk.adjustLocForChunkOffset(edit.range);
1226 
1227     action.edit.emplace();
1228     action.edit->changes[uri.uri().str()] = std::move(edits);
1229     action.diagnostics = {diag};
1230 
1231     actions.emplace_back(std::move(action));
1232   }
1233 }
1234 
1235 llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode()1236 MLIRTextFile::convertToBytecode() {
1237   // Bail out if there is more than one chunk, bytecode wants a single module.
1238   if (chunks.size() != 1) {
1239     return llvm::make_error<lsp::LSPError>(
1240         "unexpected split file, please remove all `// -----`",
1241         lsp::ErrorCode::RequestFailed);
1242   }
1243   return chunks.front()->document.convertToBytecode();
1244 }
1245 
getChunkFor(lsp::Position & pos)1246 MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247   if (chunks.size() == 1)
1248     return *chunks.front();
1249 
1250   // Search for the first chunk with a greater line offset, the previous chunk
1251   // is the one that contains `pos`.
1252   auto it = llvm::upper_bound(
1253       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
1254         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1255       });
1256   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257   pos.line -= chunk.lineOffset;
1258   return chunk;
1259 }
1260 
1261 //===----------------------------------------------------------------------===//
1262 // MLIRServer::Impl
1263 //===----------------------------------------------------------------------===//
1264 
1265 struct lsp::MLIRServer::Impl {
Impllsp::MLIRServer::Impl1266   Impl(DialectRegistry &registry) : registry(registry) {}
1267 
1268   /// The registry containing dialects that can be recognized in parsed .mlir
1269   /// files.
1270   DialectRegistry &registry;
1271 
1272   /// The files held by the server, mapped by their URI file name.
1273   llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1274 };
1275 
1276 //===----------------------------------------------------------------------===//
1277 // MLIRServer
1278 //===----------------------------------------------------------------------===//
1279 
MLIRServer(DialectRegistry & registry)1280 lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1281     : impl(std::make_unique<Impl>(registry)) {}
1282 lsp::MLIRServer::~MLIRServer() = default;
1283 
addOrUpdateDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)1284 void lsp::MLIRServer::addOrUpdateDocument(
1285     const URIForFile &uri, StringRef contents, int64_t version,
1286     std::vector<Diagnostic> &diagnostics) {
1287   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288       uri, contents, version, impl->registry, diagnostics);
1289 }
1290 
removeDocument(const URIForFile & uri)1291 std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1292   auto it = impl->files.find(uri.file());
1293   if (it == impl->files.end())
1294     return std::nullopt;
1295 
1296   int64_t version = it->second->getVersion();
1297   impl->files.erase(it);
1298   return version;
1299 }
1300 
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)1301 void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
1302                                      const Position &defPos,
1303                                      std::vector<Location> &locations) {
1304   auto fileIt = impl->files.find(uri.file());
1305   if (fileIt != impl->files.end())
1306     fileIt->second->getLocationsOf(uri, defPos, locations);
1307 }
1308 
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)1309 void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1310                                        const Position &pos,
1311                                        std::vector<Location> &references) {
1312   auto fileIt = impl->files.find(uri.file());
1313   if (fileIt != impl->files.end())
1314     fileIt->second->findReferencesOf(uri, pos, references);
1315 }
1316 
findHover(const URIForFile & uri,const Position & hoverPos)1317 std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1318                                                      const Position &hoverPos) {
1319   auto fileIt = impl->files.find(uri.file());
1320   if (fileIt != impl->files.end())
1321     return fileIt->second->findHover(uri, hoverPos);
1322   return std::nullopt;
1323 }
1324 
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)1325 void lsp::MLIRServer::findDocumentSymbols(
1326     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327   auto fileIt = impl->files.find(uri.file());
1328   if (fileIt != impl->files.end())
1329     fileIt->second->findDocumentSymbols(symbols);
1330 }
1331 
1332 lsp::CompletionList
getCodeCompletion(const URIForFile & uri,const Position & completePos)1333 lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1334                                    const Position &completePos) {
1335   auto fileIt = impl->files.find(uri.file());
1336   if (fileIt != impl->files.end())
1337     return fileIt->second->getCodeCompletion(uri, completePos);
1338   return CompletionList();
1339 }
1340 
getCodeActions(const URIForFile & uri,const Range & pos,const CodeActionContext & context,std::vector<CodeAction> & actions)1341 void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1342                                      const CodeActionContext &context,
1343                                      std::vector<CodeAction> &actions) {
1344   auto fileIt = impl->files.find(uri.file());
1345   if (fileIt != impl->files.end())
1346     fileIt->second->getCodeActions(uri, pos, context, actions);
1347 }
1348 
1349 llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertFromBytecode(const URIForFile & uri)1350 lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
1351   MLIRContext tempContext(impl->registry);
1352   tempContext.allowUnregisteredDialects();
1353 
1354   // Collect any errors during parsing.
1355   std::string errorMsg;
1356   ScopedDiagnosticHandler diagHandler(
1357       &tempContext,
1358       [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1359 
1360   // Handling for external resources, which we want to propagate up to the user.
1361   FallbackAsmResourceMap fallbackResourceMap;
1362 
1363   // Setup the parser config.
1364   ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1365                             &fallbackResourceMap);
1366 
1367   // Try to parse the given source file.
1368   Block parsedBlock;
1369   if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370     return llvm::make_error<lsp::LSPError>(
1371         "failed to parse bytecode source file: " + errorMsg,
1372         lsp::ErrorCode::RequestFailed);
1373   }
1374 
1375   // TODO: We currently expect a single top-level operation, but this could
1376   // conceptually be relaxed.
1377   if (!llvm::hasSingleElement(parsedBlock)) {
1378     return llvm::make_error<lsp::LSPError>(
1379         "expected bytecode to contain a single top-level operation",
1380         lsp::ErrorCode::RequestFailed);
1381   }
1382 
1383   // Print the module to a buffer.
1384   lsp::MLIRConvertBytecodeResult result;
1385   {
1386     // Extract the top-level op so that aliases get printed.
1387     // FIXME: We should be able to enable aliases without having to do this!
1388     OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389     topOp->remove();
1390 
1391     AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392                    /*locationMap=*/nullptr, &fallbackResourceMap);
1393 
1394     llvm::raw_string_ostream os(result.output);
1395     topOp->print(os, state);
1396   }
1397   return std::move(result);
1398 }
1399 
1400 llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode(const URIForFile & uri)1401 lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
1402   auto fileIt = impl->files.find(uri.file());
1403   if (fileIt == impl->files.end()) {
1404     return llvm::make_error<lsp::LSPError>(
1405         "language server does not contain an entry for this source file",
1406         lsp::ErrorCode::RequestFailed);
1407   }
1408   return fileIt->second->convertToBytecode();
1409 }
1410