xref: /llvm-project/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp (revision e411c88b7223d87520af819fdc012c9dbb46e575)
1751c14fcSRiver Riddle //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2751c14fcSRiver Riddle //
3751c14fcSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4751c14fcSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5751c14fcSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6751c14fcSRiver Riddle //
7751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
8751c14fcSRiver Riddle 
9751c14fcSRiver Riddle #include "MLIRServer.h"
10f7b8a70eSRiver Riddle #include "Protocol.h"
11c60b897dSRiver Riddle #include "mlir/AsmParser/AsmParser.h"
12c60b897dSRiver Riddle #include "mlir/AsmParser/AsmParserState.h"
13c60b897dSRiver Riddle #include "mlir/AsmParser/CodeComplete.h"
14f7b8a70eSRiver Riddle #include "mlir/Bytecode/BytecodeWriter.h"
15751c14fcSRiver Riddle #include "mlir/IR/Operation.h"
1634a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h"
17f7b8a70eSRiver Riddle #include "mlir/Parser/Parser.h"
18516ccce7SIngo Müller #include "mlir/Support/ToolUtilities.h"
19305d7185SRiver Riddle #include "mlir/Tools/lsp-server-support/Logging.h"
20305d7185SRiver Riddle #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
21b0abd489SElliot Goodrich #include "llvm/ADT/StringExtras.h"
22f7b8a70eSRiver Riddle #include "llvm/Support/Base64.h"
23751c14fcSRiver Riddle #include "llvm/Support/SourceMgr.h"
24a1fe1f5fSKazu Hirata #include <optional>
25751c14fcSRiver Riddle 
26751c14fcSRiver Riddle using namespace mlir;
27751c14fcSRiver Riddle 
28305d7185SRiver Riddle /// Returns the range of a lexical token given a SMLoc corresponding to the
29305d7185SRiver Riddle /// start of an token location. The range is computed heuristically, and
30305d7185SRiver Riddle /// supports identifier-like tokens, strings, etc.
convertTokenLocToRange(SMLoc loc)31305d7185SRiver Riddle static SMRange convertTokenLocToRange(SMLoc loc) {
32305d7185SRiver Riddle   return lsp::convertTokenLocToRange(loc, "$-.");
33305d7185SRiver Riddle }
34305d7185SRiver Riddle 
35751c14fcSRiver Riddle /// Returns a language server location from the given MLIR file location.
36f7b8a70eSRiver Riddle /// `uriScheme` is the scheme to use when building new uris.
getLocationFromLoc(StringRef uriScheme,FileLineColLoc loc)370a81ace0SKazu Hirata static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38f7b8a70eSRiver Riddle                                                        FileLineColLoc loc) {
39751c14fcSRiver Riddle   llvm::Expected<lsp::URIForFile> sourceURI =
40f7b8a70eSRiver Riddle       lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
41751c14fcSRiver Riddle   if (!sourceURI) {
42751c14fcSRiver Riddle     lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43751c14fcSRiver Riddle                        loc.getFilename(),
44751c14fcSRiver Riddle                        llvm::toString(sourceURI.takeError()));
451a36588eSKazu Hirata     return std::nullopt;
46751c14fcSRiver Riddle   }
47751c14fcSRiver Riddle 
48751c14fcSRiver Riddle   lsp::Position position;
49751c14fcSRiver Riddle   position.line = loc.getLine() - 1;
5032bf1f1dSRiver Riddle   position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51b3911cdfSRiver Riddle   return lsp::Location{*sourceURI, lsp::Range(position)};
52b3911cdfSRiver Riddle }
53b3911cdfSRiver Riddle 
54192d9dd7SKazu Hirata /// Returns a language server location from the given MLIR location, or
55192d9dd7SKazu Hirata /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56192d9dd7SKazu Hirata /// when building new uris. `uri` is an optional additional filter that, when
57192d9dd7SKazu Hirata /// present, is used to filter sub locations that do not share the same uri.
580a81ace0SKazu Hirata static std::optional<lsp::Location>
getLocationFromLoc(llvm::SourceMgr & sourceMgr,Location loc,StringRef uriScheme,const lsp::URIForFile * uri=nullptr)59644f722bSJacques Pienaar getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60f7b8a70eSRiver Riddle                    StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
610a81ace0SKazu Hirata   std::optional<lsp::Location> location;
62b3911cdfSRiver Riddle   loc->walk([&](Location nestedLoc) {
635550c821STres Popp     FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64b3911cdfSRiver Riddle     if (!fileLoc)
65b3911cdfSRiver Riddle       return WalkResult::advance();
66b3911cdfSRiver Riddle 
670a81ace0SKazu Hirata     std::optional<lsp::Location> sourceLoc =
680a81ace0SKazu Hirata         getLocationFromLoc(uriScheme, fileLoc);
69b3911cdfSRiver Riddle     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70b3911cdfSRiver Riddle       location = *sourceLoc;
716842ec42SRiver Riddle       SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72644f722bSJacques Pienaar           sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
73644f722bSJacques Pienaar 
74644f722bSJacques Pienaar       // Use range of potential identifier starting at location, else length 1
75644f722bSJacques Pienaar       // range.
76644f722bSJacques Pienaar       location->range.end.character += 1;
77305d7185SRiver Riddle       if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78644f722bSJacques Pienaar         auto lineCol = sourceMgr.getLineAndColumn(range->End);
793f70b4e0SJacques Pienaar         location->range.end.character =
803f70b4e0SJacques Pienaar             std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81644f722bSJacques Pienaar       }
82b3911cdfSRiver Riddle       return WalkResult::interrupt();
83b3911cdfSRiver Riddle     }
84b3911cdfSRiver Riddle     return WalkResult::advance();
85b3911cdfSRiver Riddle   });
86b3911cdfSRiver Riddle   return location;
87751c14fcSRiver Riddle }
88751c14fcSRiver Riddle 
89751c14fcSRiver Riddle /// Collect all of the locations from the given MLIR location that are not
90751c14fcSRiver Riddle /// contained within the given URI.
collectLocationsFromLoc(Location loc,std::vector<lsp::Location> & locations,const lsp::URIForFile & uri)91751c14fcSRiver Riddle static void collectLocationsFromLoc(Location loc,
92751c14fcSRiver Riddle                                     std::vector<lsp::Location> &locations,
93751c14fcSRiver Riddle                                     const lsp::URIForFile &uri) {
94751c14fcSRiver Riddle   SetVector<Location> visitedLocs;
95751c14fcSRiver Riddle   loc->walk([&](Location nestedLoc) {
965550c821STres Popp     FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97751c14fcSRiver Riddle     if (!fileLoc || !visitedLocs.insert(nestedLoc))
98751c14fcSRiver Riddle       return WalkResult::advance();
99751c14fcSRiver Riddle 
1000a81ace0SKazu Hirata     std::optional<lsp::Location> sourceLoc =
101f7b8a70eSRiver Riddle         getLocationFromLoc(uri.scheme(), fileLoc);
102751c14fcSRiver Riddle     if (sourceLoc && sourceLoc->uri != uri)
103751c14fcSRiver Riddle       locations.push_back(*sourceLoc);
104751c14fcSRiver Riddle     return WalkResult::advance();
105751c14fcSRiver Riddle   });
106751c14fcSRiver Riddle }
107751c14fcSRiver Riddle 
108751c14fcSRiver Riddle /// Returns true if the given range contains the given source location. Note
109751c14fcSRiver Riddle /// that this has slightly different behavior than SMRange because it is
110751c14fcSRiver Riddle /// inclusive of the end location.
contains(SMRange range,SMLoc loc)1116842ec42SRiver Riddle static bool contains(SMRange range, SMLoc loc) {
112751c14fcSRiver Riddle   return range.Start.getPointer() <= loc.getPointer() &&
113751c14fcSRiver Riddle          loc.getPointer() <= range.End.getPointer();
114751c14fcSRiver Riddle }
115751c14fcSRiver Riddle 
116751c14fcSRiver Riddle /// Returns true if the given location is contained by the definition or one of
1175c84195bSRiver Riddle /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
1185c84195bSRiver Riddle /// the range within `def` that the provided `loc` overlapped with.
isDefOrUse(const AsmParserState::SMDefinition & def,SMLoc loc,SMRange * overlappedRange=nullptr)1196842ec42SRiver Riddle static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
1206842ec42SRiver Riddle                        SMRange *overlappedRange = nullptr) {
1215c84195bSRiver Riddle   // Check the main definition.
1225c84195bSRiver Riddle   if (contains(def.loc, loc)) {
1235c84195bSRiver Riddle     if (overlappedRange)
1245c84195bSRiver Riddle       *overlappedRange = def.loc;
1255c84195bSRiver Riddle     return true;
1265c84195bSRiver Riddle   }
1275c84195bSRiver Riddle 
1285c84195bSRiver Riddle   // Check the uses.
1297bc52733SRiver Riddle   const auto *useIt = llvm::find_if(
1307bc52733SRiver Riddle       def.uses, [&](const SMRange &range) { return contains(range, loc); });
1315c84195bSRiver Riddle   if (useIt != def.uses.end()) {
1325c84195bSRiver Riddle     if (overlappedRange)
1335c84195bSRiver Riddle       *overlappedRange = *useIt;
1345c84195bSRiver Riddle     return true;
1355c84195bSRiver Riddle   }
1365c84195bSRiver Riddle   return false;
1375c84195bSRiver Riddle }
1385c84195bSRiver Riddle 
1395c84195bSRiver Riddle /// Given a location pointing to a result, return the result number it refers
140192d9dd7SKazu Hirata /// to or std::nullopt if it refers to all of the results.
getResultNumberFromLoc(SMLoc loc)1410a81ace0SKazu Hirata static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
1425c84195bSRiver Riddle   // Skip all of the identifier characters.
1435c84195bSRiver Riddle   auto isIdentifierChar = [](char c) {
1445c84195bSRiver Riddle     return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
1455c84195bSRiver Riddle            c == '-';
146751c14fcSRiver Riddle   };
1475c84195bSRiver Riddle   const char *curPtr = loc.getPointer();
1485c84195bSRiver Riddle   while (isIdentifierChar(*curPtr))
1495c84195bSRiver Riddle     ++curPtr;
1505c84195bSRiver Riddle 
1515c84195bSRiver Riddle   // Check to see if this location indexes into the result group, via `#`. If it
1525c84195bSRiver Riddle   // doesn't, we can't extract a sub result number.
1535c84195bSRiver Riddle   if (*curPtr != '#')
1541a36588eSKazu Hirata     return std::nullopt;
1555c84195bSRiver Riddle 
1565c84195bSRiver Riddle   // Compute the sub result number from the remaining portion of the string.
1575c84195bSRiver Riddle   const char *numberStart = ++curPtr;
1585c84195bSRiver Riddle   while (llvm::isDigit(*curPtr))
1595c84195bSRiver Riddle     ++curPtr;
1605c84195bSRiver Riddle   StringRef numberStr(numberStart, curPtr - numberStart);
1615c84195bSRiver Riddle   unsigned resultNumber = 0;
1620a81ace0SKazu Hirata   return numberStr.consumeInteger(10, resultNumber) ? std::optional<unsigned>()
1635c84195bSRiver Riddle                                                     : resultNumber;
1645c84195bSRiver Riddle }
1655c84195bSRiver Riddle 
1665c84195bSRiver Riddle /// Given a source location range, return the text covered by the given range.
16770c73d1bSKazu Hirata /// If the range is invalid, returns std::nullopt.
getTextFromRange(SMRange range)1680a81ace0SKazu Hirata static std::optional<StringRef> getTextFromRange(SMRange range) {
1695c84195bSRiver Riddle   if (!range.isValid())
1701a36588eSKazu Hirata     return std::nullopt;
1715c84195bSRiver Riddle   const char *startPtr = range.Start.getPointer();
1725c84195bSRiver Riddle   return StringRef(startPtr, range.End.getPointer() - startPtr);
1735c84195bSRiver Riddle }
1745c84195bSRiver Riddle 
1755c84195bSRiver Riddle /// Given a block, return its position in its parent region.
getBlockNumber(Block * block)1765c84195bSRiver Riddle static unsigned getBlockNumber(Block *block) {
1775c84195bSRiver Riddle   return std::distance(block->getParent()->begin(), block->getIterator());
1785c84195bSRiver Riddle }
1795c84195bSRiver Riddle 
1805c84195bSRiver Riddle /// Given a block and source location, print the source name of the block to the
1815c84195bSRiver Riddle /// given output stream.
printDefBlockName(raw_ostream & os,Block * block,SMRange loc={})1827bc52733SRiver Riddle static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
1835c84195bSRiver Riddle   // Try to extract a name from the source location.
1840a81ace0SKazu Hirata   std::optional<StringRef> text = getTextFromRange(loc);
18588d319a2SKazu Hirata   if (text && text->starts_with("^")) {
1865c84195bSRiver Riddle     os << *text;
1875c84195bSRiver Riddle     return;
1885c84195bSRiver Riddle   }
1895c84195bSRiver Riddle 
1905c84195bSRiver Riddle   // Otherwise, we don't have a name so print the block number.
1915c84195bSRiver Riddle   os << "<Block #" << getBlockNumber(block) << ">";
1925c84195bSRiver Riddle }
printDefBlockName(raw_ostream & os,const AsmParserState::BlockDefinition & def)1935c84195bSRiver Riddle static void printDefBlockName(raw_ostream &os,
1945c84195bSRiver Riddle                               const AsmParserState::BlockDefinition &def) {
1955c84195bSRiver Riddle   printDefBlockName(os, def.block, def.definition.loc);
196751c14fcSRiver Riddle }
197751c14fcSRiver Riddle 
198b3911cdfSRiver Riddle /// Convert the given MLIR diagnostic to the LSP form.
getLspDiagnoticFromDiag(llvm::SourceMgr & sourceMgr,Diagnostic & diag,const lsp::URIForFile & uri)199644f722bSJacques Pienaar static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200644f722bSJacques Pienaar                                                Diagnostic &diag,
201b3911cdfSRiver Riddle                                                const lsp::URIForFile &uri) {
202b3911cdfSRiver Riddle   lsp::Diagnostic lspDiag;
203b3911cdfSRiver Riddle   lspDiag.source = "mlir";
204b3911cdfSRiver Riddle 
205b3911cdfSRiver Riddle   // Note: Right now all of the diagnostics are treated as parser issues, but
206b3911cdfSRiver Riddle   // some are parser and some are verifier.
207b3911cdfSRiver Riddle   lspDiag.category = "Parse Error";
208b3911cdfSRiver Riddle 
209b3911cdfSRiver Riddle   // Try to grab a file location for this diagnostic.
210b3911cdfSRiver Riddle   // TODO: For simplicity, we just grab the first one. It may be likely that we
211b3911cdfSRiver Riddle   // will need a more interesting heuristic here.'
212f7b8a70eSRiver Riddle   StringRef uriScheme = uri.scheme();
2130a81ace0SKazu Hirata   std::optional<lsp::Location> lspLocation =
214f7b8a70eSRiver Riddle       getLocationFromLoc(sourceMgr, diag.getLocation(), uriScheme, &uri);
215b3911cdfSRiver Riddle   if (lspLocation)
216b3911cdfSRiver Riddle     lspDiag.range = lspLocation->range;
217b3911cdfSRiver Riddle 
218b3911cdfSRiver Riddle   // Convert the severity for the diagnostic.
219b3911cdfSRiver Riddle   switch (diag.getSeverity()) {
220b3911cdfSRiver Riddle   case DiagnosticSeverity::Note:
221b3911cdfSRiver Riddle     llvm_unreachable("expected notes to be handled separately");
222b3911cdfSRiver Riddle   case DiagnosticSeverity::Warning:
223b3911cdfSRiver Riddle     lspDiag.severity = lsp::DiagnosticSeverity::Warning;
224b3911cdfSRiver Riddle     break;
225b3911cdfSRiver Riddle   case DiagnosticSeverity::Error:
226b3911cdfSRiver Riddle     lspDiag.severity = lsp::DiagnosticSeverity::Error;
227b3911cdfSRiver Riddle     break;
228b3911cdfSRiver Riddle   case DiagnosticSeverity::Remark:
229b3911cdfSRiver Riddle     lspDiag.severity = lsp::DiagnosticSeverity::Information;
230b3911cdfSRiver Riddle     break;
231b3911cdfSRiver Riddle   }
232b3911cdfSRiver Riddle   lspDiag.message = diag.str();
233b3911cdfSRiver Riddle 
234b3911cdfSRiver Riddle   // Attach any notes to the main diagnostic as related information.
235b3911cdfSRiver Riddle   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
236b3911cdfSRiver Riddle   for (Diagnostic &note : diag.getNotes()) {
237b3911cdfSRiver Riddle     lsp::Location noteLoc;
2380a81ace0SKazu Hirata     if (std::optional<lsp::Location> loc =
239f7b8a70eSRiver Riddle             getLocationFromLoc(sourceMgr, note.getLocation(), uriScheme))
240b3911cdfSRiver Riddle       noteLoc = *loc;
241b3911cdfSRiver Riddle     else
242b3911cdfSRiver Riddle       noteLoc.uri = uri;
243b3911cdfSRiver Riddle     relatedDiags.emplace_back(noteLoc, note.str());
244b3911cdfSRiver Riddle   }
245b3911cdfSRiver Riddle   if (!relatedDiags.empty())
246b3911cdfSRiver Riddle     lspDiag.relatedInformation = std::move(relatedDiags);
247b3911cdfSRiver Riddle 
248b3911cdfSRiver Riddle   return lspDiag;
249b3911cdfSRiver Riddle }
250b3911cdfSRiver Riddle 
251751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
252751c14fcSRiver Riddle // MLIRDocument
253751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
254751c14fcSRiver Riddle 
255751c14fcSRiver Riddle namespace {
256751c14fcSRiver Riddle /// This class represents all of the information pertaining to a specific MLIR
257751c14fcSRiver Riddle /// document.
258751c14fcSRiver Riddle struct MLIRDocument {
2590bd297fcSRiver Riddle   MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
2600bd297fcSRiver Riddle                StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
2618cbbc5d0SRiver Riddle   MLIRDocument(const MLIRDocument &) = delete;
2628cbbc5d0SRiver Riddle   MLIRDocument &operator=(const MLIRDocument &) = delete;
263751c14fcSRiver Riddle 
2645c84195bSRiver Riddle   //===--------------------------------------------------------------------===//
2655c84195bSRiver Riddle   // Definitions and References
2665c84195bSRiver Riddle   //===--------------------------------------------------------------------===//
2675c84195bSRiver Riddle 
268751c14fcSRiver Riddle   void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269751c14fcSRiver Riddle                       std::vector<lsp::Location> &locations);
270751c14fcSRiver Riddle   void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271751c14fcSRiver Riddle                         std::vector<lsp::Location> &references);
272751c14fcSRiver Riddle 
2735c84195bSRiver Riddle   //===--------------------------------------------------------------------===//
2745c84195bSRiver Riddle   // Hover
2755c84195bSRiver Riddle   //===--------------------------------------------------------------------===//
2765c84195bSRiver Riddle 
2771da3a795SFangrui Song   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
2785c84195bSRiver Riddle                                       const lsp::Position &hoverPos);
2791da3a795SFangrui Song   std::optional<lsp::Hover>
2806842ec42SRiver Riddle   buildHoverForOperation(SMRange hoverRange,
2814c3adea7SRiver Riddle                          const AsmParserState::OperationDefinition &op);
2827bc52733SRiver Riddle   lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
2837bc52733SRiver Riddle                                           unsigned resultStart,
2847bc52733SRiver Riddle                                           unsigned resultEnd, SMLoc posLoc);
2856842ec42SRiver Riddle   lsp::Hover buildHoverForBlock(SMRange hoverRange,
2865c84195bSRiver Riddle                                 const AsmParserState::BlockDefinition &block);
2875c84195bSRiver Riddle   lsp::Hover
2886842ec42SRiver Riddle   buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
2895c84195bSRiver Riddle                              const AsmParserState::BlockDefinition &block);
2905c84195bSRiver Riddle 
2913bcc63e3SMogball   lsp::Hover buildHoverForAttributeAlias(
2923bcc63e3SMogball       SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
2933bcc63e3SMogball   lsp::Hover
2943bcc63e3SMogball   buildHoverForTypeAlias(SMRange hoverRange,
2953bcc63e3SMogball                          const AsmParserState::TypeAliasDefinition &type);
2963bcc63e3SMogball 
297ff81a2c9SRiver Riddle   //===--------------------------------------------------------------------===//
298ff81a2c9SRiver Riddle   // Document Symbols
299ff81a2c9SRiver Riddle   //===--------------------------------------------------------------------===//
300ff81a2c9SRiver Riddle 
301ff81a2c9SRiver Riddle   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302ff81a2c9SRiver Riddle   void findDocumentSymbols(Operation *op,
303ff81a2c9SRiver Riddle                            std::vector<lsp::DocumentSymbol> &symbols);
304ff81a2c9SRiver Riddle 
305ff81a2c9SRiver Riddle   //===--------------------------------------------------------------------===//
306ed2fb173SRiver Riddle   // Code Completion
307ed2fb173SRiver Riddle   //===--------------------------------------------------------------------===//
308ed2fb173SRiver Riddle 
309ed2fb173SRiver Riddle   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310ed2fb173SRiver Riddle                                         const lsp::Position &completePos,
311ed2fb173SRiver Riddle                                         const DialectRegistry &registry);
312ed2fb173SRiver Riddle 
313ed2fb173SRiver Riddle   //===--------------------------------------------------------------------===//
314ed344c88SRiver Riddle   // Code Action
315ed344c88SRiver Riddle   //===--------------------------------------------------------------------===//
316ed344c88SRiver Riddle 
317ed344c88SRiver Riddle   void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318ed344c88SRiver Riddle                                   lsp::Position &pos, StringRef severity,
319ed344c88SRiver Riddle                                   StringRef message,
320ed344c88SRiver Riddle                                   std::vector<lsp::TextEdit> &edits);
321ed344c88SRiver Riddle 
322ed344c88SRiver Riddle   //===--------------------------------------------------------------------===//
323f7b8a70eSRiver Riddle   // Bytecode
324f7b8a70eSRiver Riddle   //===--------------------------------------------------------------------===//
325f7b8a70eSRiver Riddle 
326f7b8a70eSRiver Riddle   llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
327f7b8a70eSRiver Riddle 
328f7b8a70eSRiver Riddle   //===--------------------------------------------------------------------===//
329ff81a2c9SRiver Riddle   // Fields
330ff81a2c9SRiver Riddle   //===--------------------------------------------------------------------===//
331ff81a2c9SRiver Riddle 
332751c14fcSRiver Riddle   /// The high level parser state used to find definitions and references within
333751c14fcSRiver Riddle   /// the source file.
334751c14fcSRiver Riddle   AsmParserState asmState;
335751c14fcSRiver Riddle 
336751c14fcSRiver Riddle   /// The container for the IR parsed from the input file.
337751c14fcSRiver Riddle   Block parsedIR;
338751c14fcSRiver Riddle 
33934300ee3SRiver Riddle   /// A collection of external resources, which we want to propagate up to the
34034300ee3SRiver Riddle   /// user.
34134300ee3SRiver Riddle   FallbackAsmResourceMap fallbackResourceMap;
34234300ee3SRiver Riddle 
343751c14fcSRiver Riddle   /// The source manager containing the contents of the input file.
344751c14fcSRiver Riddle   llvm::SourceMgr sourceMgr;
345751c14fcSRiver Riddle };
346751c14fcSRiver Riddle } // namespace
347751c14fcSRiver Riddle 
MLIRDocument(MLIRContext & context,const lsp::URIForFile & uri,StringRef contents,std::vector<lsp::Diagnostic> & diagnostics)3480bd297fcSRiver Riddle MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
3490bd297fcSRiver Riddle                            StringRef contents,
3500bd297fcSRiver Riddle                            std::vector<lsp::Diagnostic> &diagnostics) {
351751c14fcSRiver Riddle   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352644f722bSJacques Pienaar     diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353751c14fcSRiver Riddle   });
354751c14fcSRiver Riddle 
355751c14fcSRiver Riddle   // Try to parsed the given IR string.
356751c14fcSRiver Riddle   auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
357751c14fcSRiver Riddle   if (!memBuffer) {
358751c14fcSRiver Riddle     lsp::Logger::error("Failed to create memory buffer for file", uri.file());
359751c14fcSRiver Riddle     return;
360751c14fcSRiver Riddle   }
361751c14fcSRiver Riddle 
3621ae60e04SRiver Riddle   ParserConfig config(&context, /*verifyAfterParse=*/true,
3631ae60e04SRiver Riddle                       &fallbackResourceMap);
3646842ec42SRiver Riddle   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
36534300ee3SRiver Riddle   if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
366b3911cdfSRiver Riddle     // If parsing failed, clear out any of the current state.
367b3911cdfSRiver Riddle     parsedIR.clear();
368b3911cdfSRiver Riddle     asmState = AsmParserState();
36934300ee3SRiver Riddle     fallbackResourceMap = FallbackAsmResourceMap();
370751c14fcSRiver Riddle     return;
371751c14fcSRiver Riddle   }
372b3911cdfSRiver Riddle }
373751c14fcSRiver Riddle 
3745c84195bSRiver Riddle //===----------------------------------------------------------------------===//
3755c84195bSRiver Riddle // MLIRDocument: Definitions and References
3765c84195bSRiver Riddle //===----------------------------------------------------------------------===//
3775c84195bSRiver Riddle 
getLocationsOf(const lsp::URIForFile & uri,const lsp::Position & defPos,std::vector<lsp::Location> & locations)378751c14fcSRiver Riddle void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
379751c14fcSRiver Riddle                                   const lsp::Position &defPos,
380751c14fcSRiver Riddle                                   std::vector<lsp::Location> &locations) {
38152b34df9SRiver Riddle   SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
382751c14fcSRiver Riddle 
383751c14fcSRiver Riddle   // Functor used to check if an SM definition contains the position.
384751c14fcSRiver Riddle   auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
385751c14fcSRiver Riddle     if (!isDefOrUse(def, posLoc))
386751c14fcSRiver Riddle       return false;
38752b34df9SRiver Riddle     locations.emplace_back(uri, sourceMgr, def.loc);
388751c14fcSRiver Riddle     return true;
389751c14fcSRiver Riddle   };
390751c14fcSRiver Riddle 
391751c14fcSRiver Riddle   // Check all definitions related to operations.
392751c14fcSRiver Riddle   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393751c14fcSRiver Riddle     if (contains(op.loc, posLoc))
394751c14fcSRiver Riddle       return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
395751c14fcSRiver Riddle     for (const auto &result : op.resultGroups)
39647f76bb0SMogball       if (containsPosition(result.definition))
397751c14fcSRiver Riddle         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
398d6af89beSRiver Riddle     for (const auto &symUse : op.symbolUses) {
399d6af89beSRiver Riddle       if (contains(symUse, posLoc)) {
40052b34df9SRiver Riddle         locations.emplace_back(uri, sourceMgr, op.loc);
401d6af89beSRiver Riddle         return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
402d6af89beSRiver Riddle       }
403d6af89beSRiver Riddle     }
404751c14fcSRiver Riddle   }
405751c14fcSRiver Riddle 
406751c14fcSRiver Riddle   // Check all definitions related to blocks.
407751c14fcSRiver Riddle   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
408751c14fcSRiver Riddle     if (containsPosition(block.definition))
409751c14fcSRiver Riddle       return;
410751c14fcSRiver Riddle     for (const AsmParserState::SMDefinition &arg : block.arguments)
411751c14fcSRiver Riddle       if (containsPosition(arg))
412751c14fcSRiver Riddle         return;
413751c14fcSRiver Riddle   }
4143bcc63e3SMogball 
4153bcc63e3SMogball   // Check all alias definitions.
4163bcc63e3SMogball   for (const AsmParserState::AttributeAliasDefinition &attr :
4173bcc63e3SMogball        asmState.getAttributeAliasDefs()) {
4183bcc63e3SMogball     if (containsPosition(attr.definition))
4193bcc63e3SMogball       return;
4203bcc63e3SMogball   }
4213bcc63e3SMogball   for (const AsmParserState::TypeAliasDefinition &type :
4223bcc63e3SMogball        asmState.getTypeAliasDefs()) {
4233bcc63e3SMogball     if (containsPosition(type.definition))
4243bcc63e3SMogball       return;
4253bcc63e3SMogball   }
426751c14fcSRiver Riddle }
427751c14fcSRiver Riddle 
findReferencesOf(const lsp::URIForFile & uri,const lsp::Position & pos,std::vector<lsp::Location> & references)428751c14fcSRiver Riddle void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
429751c14fcSRiver Riddle                                     const lsp::Position &pos,
430751c14fcSRiver Riddle                                     std::vector<lsp::Location> &references) {
431751c14fcSRiver Riddle   // Functor used to append all of the definitions/uses of the given SM
432751c14fcSRiver Riddle   // definition to the reference list.
433751c14fcSRiver Riddle   auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
43452b34df9SRiver Riddle     references.emplace_back(uri, sourceMgr, def.loc);
4356842ec42SRiver Riddle     for (const SMRange &use : def.uses)
43652b34df9SRiver Riddle       references.emplace_back(uri, sourceMgr, use);
437751c14fcSRiver Riddle   };
438751c14fcSRiver Riddle 
43952b34df9SRiver Riddle   SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
440751c14fcSRiver Riddle 
441751c14fcSRiver Riddle   // Check all definitions related to operations.
442751c14fcSRiver Riddle   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
443751c14fcSRiver Riddle     if (contains(op.loc, posLoc)) {
444751c14fcSRiver Riddle       for (const auto &result : op.resultGroups)
44547f76bb0SMogball         appendSMDef(result.definition);
446d6af89beSRiver Riddle       for (const auto &symUse : op.symbolUses)
447d6af89beSRiver Riddle         if (contains(symUse, posLoc))
44852b34df9SRiver Riddle           references.emplace_back(uri, sourceMgr, symUse);
449751c14fcSRiver Riddle       return;
450751c14fcSRiver Riddle     }
451751c14fcSRiver Riddle     for (const auto &result : op.resultGroups)
45247f76bb0SMogball       if (isDefOrUse(result.definition, posLoc))
45347f76bb0SMogball         return appendSMDef(result.definition);
454d6af89beSRiver Riddle     for (const auto &symUse : op.symbolUses) {
455d6af89beSRiver Riddle       if (!contains(symUse, posLoc))
456d6af89beSRiver Riddle         continue;
457d6af89beSRiver Riddle       for (const auto &symUse : op.symbolUses)
45852b34df9SRiver Riddle         references.emplace_back(uri, sourceMgr, symUse);
459d6af89beSRiver Riddle       return;
460d6af89beSRiver Riddle     }
461751c14fcSRiver Riddle   }
462751c14fcSRiver Riddle 
463751c14fcSRiver Riddle   // Check all definitions related to blocks.
464751c14fcSRiver Riddle   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
465751c14fcSRiver Riddle     if (isDefOrUse(block.definition, posLoc))
466751c14fcSRiver Riddle       return appendSMDef(block.definition);
467751c14fcSRiver Riddle 
468751c14fcSRiver Riddle     for (const AsmParserState::SMDefinition &arg : block.arguments)
469751c14fcSRiver Riddle       if (isDefOrUse(arg, posLoc))
470751c14fcSRiver Riddle         return appendSMDef(arg);
471751c14fcSRiver Riddle   }
4723bcc63e3SMogball 
4733bcc63e3SMogball   // Check all alias definitions.
4743bcc63e3SMogball   for (const AsmParserState::AttributeAliasDefinition &attr :
4753bcc63e3SMogball        asmState.getAttributeAliasDefs()) {
4763bcc63e3SMogball     if (isDefOrUse(attr.definition, posLoc))
4773bcc63e3SMogball       return appendSMDef(attr.definition);
4783bcc63e3SMogball   }
4793bcc63e3SMogball   for (const AsmParserState::TypeAliasDefinition &type :
4803bcc63e3SMogball        asmState.getTypeAliasDefs()) {
4813bcc63e3SMogball     if (isDefOrUse(type.definition, posLoc))
4823bcc63e3SMogball       return appendSMDef(type.definition);
4833bcc63e3SMogball   }
484751c14fcSRiver Riddle }
485751c14fcSRiver Riddle 
486751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
4875c84195bSRiver Riddle // MLIRDocument: Hover
4885c84195bSRiver Riddle //===----------------------------------------------------------------------===//
4895c84195bSRiver Riddle 
4901da3a795SFangrui Song std::optional<lsp::Hover>
findHover(const lsp::URIForFile & uri,const lsp::Position & hoverPos)4911da3a795SFangrui Song MLIRDocument::findHover(const lsp::URIForFile &uri,
4925c84195bSRiver Riddle                         const lsp::Position &hoverPos) {
49352b34df9SRiver Riddle   SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
4946842ec42SRiver Riddle   SMRange hoverRange;
4955c84195bSRiver Riddle 
4965c84195bSRiver Riddle   // Check for Hovers on operations and results.
4975c84195bSRiver Riddle   for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
4985c84195bSRiver Riddle     // Check if the position points at this operation.
4995c84195bSRiver Riddle     if (contains(op.loc, posLoc))
5004c3adea7SRiver Riddle       return buildHoverForOperation(op.loc, op);
5014c3adea7SRiver Riddle 
5024c3adea7SRiver Riddle     // Check if the position points at the symbol name.
5034c3adea7SRiver Riddle     for (auto &use : op.symbolUses)
5044c3adea7SRiver Riddle       if (contains(use, posLoc))
5054c3adea7SRiver Riddle         return buildHoverForOperation(use, op);
5065c84195bSRiver Riddle 
5075c84195bSRiver Riddle     // Check if the position points at a result group.
5085c84195bSRiver Riddle     for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
5095c84195bSRiver Riddle       const auto &result = op.resultGroups[i];
51047f76bb0SMogball       if (!isDefOrUse(result.definition, posLoc, &hoverRange))
5115c84195bSRiver Riddle         continue;
5125c84195bSRiver Riddle 
5135c84195bSRiver Riddle       // Get the range of results covered by the over position.
51447f76bb0SMogball       unsigned resultStart = result.startIndex;
51547f76bb0SMogball       unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
51647f76bb0SMogball                                         : op.resultGroups[i + 1].startIndex;
5175c84195bSRiver Riddle       return buildHoverForOperationResult(hoverRange, op.op, resultStart,
5185c84195bSRiver Riddle                                           resultEnd, posLoc);
5195c84195bSRiver Riddle     }
5205c84195bSRiver Riddle   }
5215c84195bSRiver Riddle 
5225c84195bSRiver Riddle   // Check to see if the hover is over a block argument.
5235c84195bSRiver Riddle   for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
5245c84195bSRiver Riddle     if (isDefOrUse(block.definition, posLoc, &hoverRange))
5255c84195bSRiver Riddle       return buildHoverForBlock(hoverRange, block);
5265c84195bSRiver Riddle 
5275c84195bSRiver Riddle     for (const auto &arg : llvm::enumerate(block.arguments)) {
5285c84195bSRiver Riddle       if (!isDefOrUse(arg.value(), posLoc, &hoverRange))
5295c84195bSRiver Riddle         continue;
5305c84195bSRiver Riddle 
5315c84195bSRiver Riddle       return buildHoverForBlockArgument(
5325c84195bSRiver Riddle           hoverRange, block.block->getArgument(arg.index()), block);
5335c84195bSRiver Riddle     }
5345c84195bSRiver Riddle   }
5353bcc63e3SMogball 
5363bcc63e3SMogball   // Check to see if the hover is over an alias.
5373bcc63e3SMogball   for (const AsmParserState::AttributeAliasDefinition &attr :
5383bcc63e3SMogball        asmState.getAttributeAliasDefs()) {
5393bcc63e3SMogball     if (isDefOrUse(attr.definition, posLoc, &hoverRange))
5403bcc63e3SMogball       return buildHoverForAttributeAlias(hoverRange, attr);
5413bcc63e3SMogball   }
5423bcc63e3SMogball   for (const AsmParserState::TypeAliasDefinition &type :
5433bcc63e3SMogball        asmState.getTypeAliasDefs()) {
5443bcc63e3SMogball     if (isDefOrUse(type.definition, posLoc, &hoverRange))
5453bcc63e3SMogball       return buildHoverForTypeAlias(hoverRange, type);
5463bcc63e3SMogball   }
5473bcc63e3SMogball 
5481a36588eSKazu Hirata   return std::nullopt;
5495c84195bSRiver Riddle }
5505c84195bSRiver Riddle 
buildHoverForOperation(SMRange hoverRange,const AsmParserState::OperationDefinition & op)5511da3a795SFangrui Song std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
5526842ec42SRiver Riddle     SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
55352b34df9SRiver Riddle   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
5545c84195bSRiver Riddle   llvm::raw_string_ostream os(hover.contents.value);
5555c84195bSRiver Riddle 
556f492c359SRiver Riddle   // Add the operation name to the hover.
557f492c359SRiver Riddle   os << "\"" << op.op->getName() << "\"";
558f492c359SRiver Riddle   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
559f492c359SRiver Riddle     os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
560f492c359SRiver Riddle   os << "\n\n";
561f492c359SRiver Riddle 
562f492c359SRiver Riddle   os << "Generic Form:\n\n```mlir\n";
563f492c359SRiver Riddle 
5645736a8a2SMehdi Amini   op.op->print(os, OpPrintingFlags()
5655736a8a2SMehdi Amini                        .printGenericOpForm()
5665736a8a2SMehdi Amini                        .elideLargeElementsAttrs()
5675736a8a2SMehdi Amini                        .skipRegions());
5685c84195bSRiver Riddle   os << "\n```\n";
5695c84195bSRiver Riddle 
5705c84195bSRiver Riddle   return hover;
5715c84195bSRiver Riddle }
5725c84195bSRiver Riddle 
buildHoverForOperationResult(SMRange hoverRange,Operation * op,unsigned resultStart,unsigned resultEnd,SMLoc posLoc)5736842ec42SRiver Riddle lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
5745c84195bSRiver Riddle                                                       Operation *op,
5755c84195bSRiver Riddle                                                       unsigned resultStart,
5765c84195bSRiver Riddle                                                       unsigned resultEnd,
5776842ec42SRiver Riddle                                                       SMLoc posLoc) {
57852b34df9SRiver Riddle   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
5795c84195bSRiver Riddle   llvm::raw_string_ostream os(hover.contents.value);
5805c84195bSRiver Riddle 
5815c84195bSRiver Riddle   // Add the parent operation name to the hover.
5825c84195bSRiver Riddle   os << "Operation: \"" << op->getName() << "\"\n\n";
5835c84195bSRiver Riddle 
5845c84195bSRiver Riddle   // Check to see if the location points to a specific result within the
5855c84195bSRiver Riddle   // group.
5860a81ace0SKazu Hirata   if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(posLoc)) {
5875c84195bSRiver Riddle     if ((resultStart + *resultNumber) < resultEnd) {
5885c84195bSRiver Riddle       resultStart += *resultNumber;
5895c84195bSRiver Riddle       resultEnd = resultStart + 1;
5905c84195bSRiver Riddle     }
5915c84195bSRiver Riddle   }
5925c84195bSRiver Riddle 
5935c84195bSRiver Riddle   // Add the range of results and their types to the hover info.
5945c84195bSRiver Riddle   if ((resultStart + 1) == resultEnd) {
5955c84195bSRiver Riddle     os << "Result #" << resultStart << "\n\n"
5965c84195bSRiver Riddle        << "Type: `" << op->getResult(resultStart).getType() << "`\n\n";
5975c84195bSRiver Riddle   } else {
5985c84195bSRiver Riddle     os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
5995c84195bSRiver Riddle        << "Types: ";
6005c84195bSRiver Riddle     llvm::interleaveComma(
6015c84195bSRiver Riddle         op->getResults().slice(resultStart, resultEnd), os,
6025c84195bSRiver Riddle         [&](Value result) { os << "`" << result.getType() << "`"; });
6035c84195bSRiver Riddle   }
6045c84195bSRiver Riddle 
6055c84195bSRiver Riddle   return hover;
6065c84195bSRiver Riddle }
6075c84195bSRiver Riddle 
6085c84195bSRiver Riddle lsp::Hover
buildHoverForBlock(SMRange hoverRange,const AsmParserState::BlockDefinition & block)6096842ec42SRiver Riddle MLIRDocument::buildHoverForBlock(SMRange hoverRange,
6105c84195bSRiver Riddle                                  const AsmParserState::BlockDefinition &block) {
61152b34df9SRiver Riddle   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
6125c84195bSRiver Riddle   llvm::raw_string_ostream os(hover.contents.value);
6135c84195bSRiver Riddle 
6145c84195bSRiver Riddle   // Print the given block to the hover output stream.
6155c84195bSRiver Riddle   auto printBlockToHover = [&](Block *newBlock) {
6165c84195bSRiver Riddle     if (const auto *def = asmState.getBlockDef(newBlock))
6175c84195bSRiver Riddle       printDefBlockName(os, *def);
6185c84195bSRiver Riddle     else
6195c84195bSRiver Riddle       printDefBlockName(os, newBlock);
6205c84195bSRiver Riddle   };
6215c84195bSRiver Riddle 
6225c84195bSRiver Riddle   // Display the parent operation, block number, predecessors, and successors.
6235c84195bSRiver Riddle   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
6245c84195bSRiver Riddle      << "Block #" << getBlockNumber(block.block) << "\n\n";
6255c84195bSRiver Riddle   if (!block.block->hasNoPredecessors()) {
6265c84195bSRiver Riddle     os << "Predecessors: ";
6275c84195bSRiver Riddle     llvm::interleaveComma(block.block->getPredecessors(), os,
6285c84195bSRiver Riddle                           printBlockToHover);
6295c84195bSRiver Riddle     os << "\n\n";
6305c84195bSRiver Riddle   }
6315c84195bSRiver Riddle   if (!block.block->hasNoSuccessors()) {
6325c84195bSRiver Riddle     os << "Successors: ";
6335c84195bSRiver Riddle     llvm::interleaveComma(block.block->getSuccessors(), os, printBlockToHover);
6345c84195bSRiver Riddle     os << "\n\n";
6355c84195bSRiver Riddle   }
6365c84195bSRiver Riddle 
6375c84195bSRiver Riddle   return hover;
6385c84195bSRiver Riddle }
6395c84195bSRiver Riddle 
buildHoverForBlockArgument(SMRange hoverRange,BlockArgument arg,const AsmParserState::BlockDefinition & block)6405c84195bSRiver Riddle lsp::Hover MLIRDocument::buildHoverForBlockArgument(
6416842ec42SRiver Riddle     SMRange hoverRange, BlockArgument arg,
6425c84195bSRiver Riddle     const AsmParserState::BlockDefinition &block) {
64352b34df9SRiver Riddle   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
6445c84195bSRiver Riddle   llvm::raw_string_ostream os(hover.contents.value);
6455c84195bSRiver Riddle 
6465c84195bSRiver Riddle   // Display the parent operation, block, the argument number, and the type.
6475c84195bSRiver Riddle   os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
6485c84195bSRiver Riddle      << "Block: ";
6495c84195bSRiver Riddle   printDefBlockName(os, block);
6505c84195bSRiver Riddle   os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
6515c84195bSRiver Riddle      << "Type: `" << arg.getType() << "`\n\n";
6525c84195bSRiver Riddle 
6535c84195bSRiver Riddle   return hover;
6545c84195bSRiver Riddle }
6555c84195bSRiver Riddle 
buildHoverForAttributeAlias(SMRange hoverRange,const AsmParserState::AttributeAliasDefinition & attr)6563bcc63e3SMogball lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
6573bcc63e3SMogball     SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
6583bcc63e3SMogball   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
6593bcc63e3SMogball   llvm::raw_string_ostream os(hover.contents.value);
6603bcc63e3SMogball 
6613bcc63e3SMogball   os << "Attribute Alias: \"" << attr.name << "\n\n";
6623bcc63e3SMogball   os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
6633bcc63e3SMogball 
6643bcc63e3SMogball   return hover;
6653bcc63e3SMogball }
6663bcc63e3SMogball 
buildHoverForTypeAlias(SMRange hoverRange,const AsmParserState::TypeAliasDefinition & type)6673bcc63e3SMogball lsp::Hover MLIRDocument::buildHoverForTypeAlias(
6683bcc63e3SMogball     SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
6693bcc63e3SMogball   lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
6703bcc63e3SMogball   llvm::raw_string_ostream os(hover.contents.value);
6713bcc63e3SMogball 
6723bcc63e3SMogball   os << "Type Alias: \"" << type.name << "\n\n";
6733bcc63e3SMogball   os << "Value: ```mlir\n" << type.value << "\n```\n\n";
6743bcc63e3SMogball 
6753bcc63e3SMogball   return hover;
6763bcc63e3SMogball }
6773bcc63e3SMogball 
6785c84195bSRiver Riddle //===----------------------------------------------------------------------===//
679ff81a2c9SRiver Riddle // MLIRDocument: Document Symbols
680ff81a2c9SRiver Riddle //===----------------------------------------------------------------------===//
681ff81a2c9SRiver Riddle 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)682ff81a2c9SRiver Riddle void MLIRDocument::findDocumentSymbols(
683ff81a2c9SRiver Riddle     std::vector<lsp::DocumentSymbol> &symbols) {
684ff81a2c9SRiver Riddle   for (Operation &op : parsedIR)
685ff81a2c9SRiver Riddle     findDocumentSymbols(&op, symbols);
686ff81a2c9SRiver Riddle }
687ff81a2c9SRiver Riddle 
findDocumentSymbols(Operation * op,std::vector<lsp::DocumentSymbol> & symbols)688ff81a2c9SRiver Riddle void MLIRDocument::findDocumentSymbols(
689ff81a2c9SRiver Riddle     Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
690ff81a2c9SRiver Riddle   std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
691ff81a2c9SRiver Riddle 
692ff81a2c9SRiver Riddle   // Check for the source information of this operation.
693ff81a2c9SRiver Riddle   if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
694ff81a2c9SRiver Riddle     // If this operation defines a symbol, record it.
695ff81a2c9SRiver Riddle     if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
696ff81a2c9SRiver Riddle       symbols.emplace_back(symbol.getName(),
6977ceffae1SRiver Riddle                            isa<FunctionOpInterface>(op)
698ff81a2c9SRiver Riddle                                ? lsp::SymbolKind::Function
699ff81a2c9SRiver Riddle                                : lsp::SymbolKind::Class,
70052b34df9SRiver Riddle                            lsp::Range(sourceMgr, def->scopeLoc),
70152b34df9SRiver Riddle                            lsp::Range(sourceMgr, def->loc));
702ff81a2c9SRiver Riddle       childSymbols = &symbols.back().children;
703ff81a2c9SRiver Riddle 
704ff81a2c9SRiver Riddle     } else if (op->hasTrait<OpTrait::SymbolTable>()) {
705ff81a2c9SRiver Riddle       // Otherwise, if this is a symbol table push an anonymous document symbol.
706ff81a2c9SRiver Riddle       symbols.emplace_back("<" + op->getName().getStringRef() + ">",
707ff81a2c9SRiver Riddle                            lsp::SymbolKind::Namespace,
70852b34df9SRiver Riddle                            lsp::Range(sourceMgr, def->scopeLoc),
70952b34df9SRiver Riddle                            lsp::Range(sourceMgr, def->loc));
710ff81a2c9SRiver Riddle       childSymbols = &symbols.back().children;
711ff81a2c9SRiver Riddle     }
712ff81a2c9SRiver Riddle   }
713ff81a2c9SRiver Riddle 
714ff81a2c9SRiver Riddle   // Recurse into the regions of this operation.
715ff81a2c9SRiver Riddle   if (!op->getNumRegions())
716ff81a2c9SRiver Riddle     return;
717ff81a2c9SRiver Riddle   for (Region &region : op->getRegions())
718ff81a2c9SRiver Riddle     for (Operation &childOp : region.getOps())
719ff81a2c9SRiver Riddle       findDocumentSymbols(&childOp, *childSymbols);
720ff81a2c9SRiver Riddle }
721ff81a2c9SRiver Riddle 
722ff81a2c9SRiver Riddle //===----------------------------------------------------------------------===//
723ed2fb173SRiver Riddle // MLIRDocument: Code Completion
724ed2fb173SRiver Riddle //===----------------------------------------------------------------------===//
725ed2fb173SRiver Riddle 
726ed2fb173SRiver Riddle namespace {
727ed2fb173SRiver Riddle class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
728ed2fb173SRiver Riddle public:
LSPCodeCompleteContext(SMLoc completeLoc,lsp::CompletionList & completionList,MLIRContext * ctx)729ed2fb173SRiver Riddle   LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
730ed2fb173SRiver Riddle                          MLIRContext *ctx)
731ed2fb173SRiver Riddle       : AsmParserCodeCompleteContext(completeLoc),
732ed2fb173SRiver Riddle         completionList(completionList), ctx(ctx) {}
733ed2fb173SRiver Riddle 
734fe4f512bSRiver Riddle   /// Signal code completion for a dialect name, with an optional prefix.
completeDialectName(StringRef prefix)735fe4f512bSRiver Riddle   void completeDialectName(StringRef prefix) final {
736ed2fb173SRiver Riddle     for (StringRef dialect : ctx->getAvailableDialects()) {
737fe4f512bSRiver Riddle       lsp::CompletionItem item(prefix + dialect,
738fe4f512bSRiver Riddle                                lsp::CompletionItemKind::Module,
739fe4f512bSRiver Riddle                                /*sortText=*/"3");
740ed2fb173SRiver Riddle       item.detail = "dialect";
741ed2fb173SRiver Riddle       completionList.items.emplace_back(item);
742ed2fb173SRiver Riddle     }
743ed2fb173SRiver Riddle   }
744fe4f512bSRiver Riddle   using AsmParserCodeCompleteContext::completeDialectName;
745ed2fb173SRiver Riddle 
746ed2fb173SRiver Riddle   /// Signal code completion for an operation name within the given dialect.
completeOperationName(StringRef dialectName)747ed2fb173SRiver Riddle   void completeOperationName(StringRef dialectName) final {
748ed2fb173SRiver Riddle     Dialect *dialect = ctx->getOrLoadDialect(dialectName);
749ed2fb173SRiver Riddle     if (!dialect)
750ed2fb173SRiver Riddle       return;
751ed2fb173SRiver Riddle 
752ed2fb173SRiver Riddle     for (const auto &op : ctx->getRegisteredOperations()) {
753ed2fb173SRiver Riddle       if (&op.getDialect() != dialect)
754ed2fb173SRiver Riddle         continue;
755ed2fb173SRiver Riddle 
756ed2fb173SRiver Riddle       lsp::CompletionItem item(
757ed2fb173SRiver Riddle           op.getStringRef().drop_front(dialectName.size() + 1),
758fe4f512bSRiver Riddle           lsp::CompletionItemKind::Field,
759fe4f512bSRiver Riddle           /*sortText=*/"1");
760ed2fb173SRiver Riddle       item.detail = "operation";
761ed2fb173SRiver Riddle       completionList.items.emplace_back(item);
762ed2fb173SRiver Riddle     }
763ed2fb173SRiver Riddle   }
764ed2fb173SRiver Riddle 
765ed2fb173SRiver Riddle   /// Append the given SSA value as a code completion result for SSA value
766ed2fb173SRiver Riddle   /// completions.
appendSSAValueCompletion(StringRef name,std::string typeData)767ed2fb173SRiver Riddle   void appendSSAValueCompletion(StringRef name, std::string typeData) final {
768ed2fb173SRiver Riddle     // Check if we need to insert the `%` or not.
769ed2fb173SRiver Riddle     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
770ed2fb173SRiver Riddle 
771ed2fb173SRiver Riddle     lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
772ed2fb173SRiver Riddle     if (stripPrefix)
773ed2fb173SRiver Riddle       item.insertText = name.drop_front(1).str();
774ed2fb173SRiver Riddle     item.detail = std::move(typeData);
775ed2fb173SRiver Riddle     completionList.items.emplace_back(item);
776ed2fb173SRiver Riddle   }
777ed2fb173SRiver Riddle 
778ed2fb173SRiver Riddle   /// Append the given block as a code completion result for block name
779ed2fb173SRiver Riddle   /// completions.
appendBlockCompletion(StringRef name)780ed2fb173SRiver Riddle   void appendBlockCompletion(StringRef name) final {
781ed2fb173SRiver Riddle     // Check if we need to insert the `^` or not.
782ed2fb173SRiver Riddle     bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
783ed2fb173SRiver Riddle 
784ed2fb173SRiver Riddle     lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
785ed2fb173SRiver Riddle     if (stripPrefix)
786ed2fb173SRiver Riddle       item.insertText = name.drop_front(1).str();
787ed2fb173SRiver Riddle     completionList.items.emplace_back(item);
788ed2fb173SRiver Riddle   }
789ed2fb173SRiver Riddle 
7902e41ea32SRiver Riddle   /// Signal a completion for the given expected token.
completeExpectedTokens(ArrayRef<StringRef> tokens,bool optional)7912e41ea32SRiver Riddle   void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
7922e41ea32SRiver Riddle     for (StringRef token : tokens) {
793fe4f512bSRiver Riddle       lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
794fe4f512bSRiver Riddle                                /*sortText=*/"0");
7952e41ea32SRiver Riddle       item.detail = optional ? "optional" : "";
7962e41ea32SRiver Riddle       completionList.items.emplace_back(item);
7972e41ea32SRiver Riddle     }
7982e41ea32SRiver Riddle   }
7992e41ea32SRiver Riddle 
800fe4f512bSRiver Riddle   /// Signal a completion for an attribute.
completeAttribute(const llvm::StringMap<Attribute> & aliases)801fe4f512bSRiver Riddle   void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
802995ab929SRiver Riddle     appendSimpleCompletions({"affine_set", "affine_map", "dense",
80340abd7eaSRiver Riddle                              "dense_resource", "false", "loc", "sparse", "true",
80440abd7eaSRiver Riddle                              "unit"},
805fe4f512bSRiver Riddle                             lsp::CompletionItemKind::Field,
806fe4f512bSRiver Riddle                             /*sortText=*/"1");
807fe4f512bSRiver Riddle 
808fe4f512bSRiver Riddle     completeDialectName("#");
809fe4f512bSRiver Riddle     completeAliases(aliases, "#");
810fe4f512bSRiver Riddle   }
completeDialectAttributeOrAlias(const llvm::StringMap<Attribute> & aliases)811fe4f512bSRiver Riddle   void completeDialectAttributeOrAlias(
812fe4f512bSRiver Riddle       const llvm::StringMap<Attribute> &aliases) override {
813fe4f512bSRiver Riddle     completeDialectName();
814fe4f512bSRiver Riddle     completeAliases(aliases);
815fe4f512bSRiver Riddle   }
816fe4f512bSRiver Riddle 
817fe4f512bSRiver Riddle   /// Signal a completion for a type.
completeType(const llvm::StringMap<Type> & aliases)818fe4f512bSRiver Riddle   void completeType(const llvm::StringMap<Type> &aliases) override {
81934b3f066SRiver Riddle     // Handle the various builtin types.
820fe4f512bSRiver Riddle     appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
821fe4f512bSRiver Riddle                              "bf16", "f16", "f32", "f64", "f80", "f128",
822fe4f512bSRiver Riddle                              "index", "none"},
823fe4f512bSRiver Riddle                             lsp::CompletionItemKind::Field,
824fe4f512bSRiver Riddle                             /*sortText=*/"1");
825fe4f512bSRiver Riddle 
82634b3f066SRiver Riddle     // Handle the builtin integer types.
82734b3f066SRiver Riddle     for (StringRef type : {"i", "si", "ui"}) {
82834b3f066SRiver Riddle       lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
82934b3f066SRiver Riddle                                /*sortText=*/"1");
83034b3f066SRiver Riddle       item.insertText = type.str();
83134b3f066SRiver Riddle       completionList.items.emplace_back(item);
83234b3f066SRiver Riddle     }
83334b3f066SRiver Riddle 
83434b3f066SRiver Riddle     // Insert completions for dialect types and aliases.
835fe4f512bSRiver Riddle     completeDialectName("!");
836fe4f512bSRiver Riddle     completeAliases(aliases, "!");
837fe4f512bSRiver Riddle   }
838fe4f512bSRiver Riddle   void
completeDialectTypeOrAlias(const llvm::StringMap<Type> & aliases)839fe4f512bSRiver Riddle   completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
840fe4f512bSRiver Riddle     completeDialectName();
841fe4f512bSRiver Riddle     completeAliases(aliases);
842fe4f512bSRiver Riddle   }
843fe4f512bSRiver Riddle 
844fe4f512bSRiver Riddle   /// Add completion results for the given set of aliases.
845fe4f512bSRiver Riddle   template <typename T>
completeAliases(const llvm::StringMap<T> & aliases,StringRef prefix="")846fe4f512bSRiver Riddle   void completeAliases(const llvm::StringMap<T> &aliases,
847fe4f512bSRiver Riddle                        StringRef prefix = "") {
848fe4f512bSRiver Riddle     for (const auto &alias : aliases) {
849fe4f512bSRiver Riddle       lsp::CompletionItem item(prefix + alias.getKey(),
850fe4f512bSRiver Riddle                                lsp::CompletionItemKind::Field,
851fe4f512bSRiver Riddle                                /*sortText=*/"2");
852fe4f512bSRiver Riddle       llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
853fe4f512bSRiver Riddle       completionList.items.emplace_back(item);
854fe4f512bSRiver Riddle     }
855fe4f512bSRiver Riddle   }
856fe4f512bSRiver Riddle 
857fe4f512bSRiver Riddle   /// Add a set of simple completions that all have the same kind.
appendSimpleCompletions(ArrayRef<StringRef> completions,lsp::CompletionItemKind kind,StringRef sortText="")858fe4f512bSRiver Riddle   void appendSimpleCompletions(ArrayRef<StringRef> completions,
859fe4f512bSRiver Riddle                                lsp::CompletionItemKind kind,
860fe4f512bSRiver Riddle                                StringRef sortText = "") {
861fe4f512bSRiver Riddle     for (StringRef completion : completions)
862fe4f512bSRiver Riddle       completionList.items.emplace_back(completion, kind, sortText);
863fe4f512bSRiver Riddle   }
864fe4f512bSRiver Riddle 
865ed2fb173SRiver Riddle private:
866ed2fb173SRiver Riddle   lsp::CompletionList &completionList;
867ed2fb173SRiver Riddle   MLIRContext *ctx;
868ed2fb173SRiver Riddle };
869ed2fb173SRiver Riddle } // namespace
870ed2fb173SRiver Riddle 
871ed2fb173SRiver Riddle lsp::CompletionList
getCodeCompletion(const lsp::URIForFile & uri,const lsp::Position & completePos,const DialectRegistry & registry)872ed2fb173SRiver Riddle MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
873ed2fb173SRiver Riddle                                 const lsp::Position &completePos,
874ed2fb173SRiver Riddle                                 const DialectRegistry &registry) {
875ed2fb173SRiver Riddle   SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
876ed2fb173SRiver Riddle   if (!posLoc.isValid())
877ed2fb173SRiver Riddle     return lsp::CompletionList();
878ed2fb173SRiver Riddle 
879ed2fb173SRiver Riddle   // To perform code completion, we run another parse of the module with the
880ed2fb173SRiver Riddle   // code completion context provided.
881ed2fb173SRiver Riddle   MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
882ed2fb173SRiver Riddle   tmpContext.allowUnregisteredDialects();
883ed2fb173SRiver Riddle   lsp::CompletionList completionList;
884ed2fb173SRiver Riddle   LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
885ed2fb173SRiver Riddle                                             &tmpContext);
886ed2fb173SRiver Riddle 
887ed2fb173SRiver Riddle   Block tmpIR;
888ed2fb173SRiver Riddle   AsmParserState tmpState;
889c60b897dSRiver Riddle   (void)parseAsmSourceFile(sourceMgr, &tmpIR, &tmpContext, &tmpState,
890ed2fb173SRiver Riddle                            &lspCompleteContext);
891ed2fb173SRiver Riddle   return completionList;
892ed2fb173SRiver Riddle }
893ed2fb173SRiver Riddle 
894ed2fb173SRiver Riddle //===----------------------------------------------------------------------===//
895ed344c88SRiver Riddle // MLIRDocument: Code Action
896ed344c88SRiver Riddle //===----------------------------------------------------------------------===//
897ed344c88SRiver Riddle 
getCodeActionForDiagnostic(const lsp::URIForFile & uri,lsp::Position & pos,StringRef severity,StringRef message,std::vector<lsp::TextEdit> & edits)898ed344c88SRiver Riddle void MLIRDocument::getCodeActionForDiagnostic(
899ed344c88SRiver Riddle     const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
900ed344c88SRiver Riddle     StringRef message, std::vector<lsp::TextEdit> &edits) {
901ed344c88SRiver Riddle   // Ignore diagnostics that print the current operation. These are always
902ed344c88SRiver Riddle   // enabled for the language server, but not generally during normal
903ed344c88SRiver Riddle   // parsing/verification.
90488d319a2SKazu Hirata   if (message.starts_with("see current operation: "))
905ed344c88SRiver Riddle     return;
906ed344c88SRiver Riddle 
907ed344c88SRiver Riddle   // Get the start of the line containing the diagnostic.
908ed344c88SRiver Riddle   const auto &buffer = sourceMgr.getBufferInfo(sourceMgr.getMainFileID());
909ed344c88SRiver Riddle   const char *lineStart = buffer.getPointerForLineNumber(pos.line + 1);
910ed344c88SRiver Riddle   if (!lineStart)
911ed344c88SRiver Riddle     return;
912ed344c88SRiver Riddle   StringRef line(lineStart, pos.character);
913ed344c88SRiver Riddle 
914ed344c88SRiver Riddle   // Add a text edit for adding an expected-* diagnostic check for this
915ed344c88SRiver Riddle   // diagnostic.
916ed344c88SRiver Riddle   lsp::TextEdit edit;
917ed344c88SRiver Riddle   edit.range = lsp::Range(lsp::Position(pos.line, 0));
918ed344c88SRiver Riddle 
919ed344c88SRiver Riddle   // Use the indent of the current line for the expected-* diagnostic.
920*e411c88bSKazu Hirata   size_t indent = line.find_first_not_of(' ');
921ed344c88SRiver Riddle   if (indent == StringRef::npos)
922ed344c88SRiver Riddle     indent = line.size();
923ed344c88SRiver Riddle 
924ed344c88SRiver Riddle   edit.newText.append(indent, ' ');
925ed344c88SRiver Riddle   llvm::raw_string_ostream(edit.newText)
926ed344c88SRiver Riddle       << "// expected-" << severity << " @below {{" << message << "}}\n";
927ed344c88SRiver Riddle   edits.emplace_back(std::move(edit));
928ed344c88SRiver Riddle }
929ed344c88SRiver Riddle 
930ed344c88SRiver Riddle //===----------------------------------------------------------------------===//
931f7b8a70eSRiver Riddle // MLIRDocument: Bytecode
932f7b8a70eSRiver Riddle //===----------------------------------------------------------------------===//
933f7b8a70eSRiver Riddle 
934f7b8a70eSRiver Riddle llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode()935f7b8a70eSRiver Riddle MLIRDocument::convertToBytecode() {
936f7b8a70eSRiver Riddle   // TODO: We currently require a single top-level operation, but this could
937f7b8a70eSRiver Riddle   // conceptually be relaxed.
938f7b8a70eSRiver Riddle   if (!llvm::hasSingleElement(parsedIR)) {
939f7b8a70eSRiver Riddle     if (parsedIR.empty()) {
940f7b8a70eSRiver Riddle       return llvm::make_error<lsp::LSPError>(
941f7b8a70eSRiver Riddle           "expected a single and valid top-level operation, please ensure "
942f7b8a70eSRiver Riddle           "there are no errors",
943f7b8a70eSRiver Riddle           lsp::ErrorCode::RequestFailed);
944f7b8a70eSRiver Riddle     }
945f7b8a70eSRiver Riddle     return llvm::make_error<lsp::LSPError>(
946f7b8a70eSRiver Riddle         "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
947f7b8a70eSRiver Riddle   }
948f7b8a70eSRiver Riddle 
949f7b8a70eSRiver Riddle   lsp::MLIRConvertBytecodeResult result;
950f7b8a70eSRiver Riddle   {
95134300ee3SRiver Riddle     BytecodeWriterConfig writerConfig(fallbackResourceMap);
95234300ee3SRiver Riddle 
953f7b8a70eSRiver Riddle     std::string rawBytecodeBuffer;
954f7b8a70eSRiver Riddle     llvm::raw_string_ostream os(rawBytecodeBuffer);
9555c90e1ffSJacques Pienaar     // No desired bytecode version set, so no need to check for error.
9565c90e1ffSJacques Pienaar     (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
957f7b8a70eSRiver Riddle     result.output = llvm::encodeBase64(rawBytecodeBuffer);
958f7b8a70eSRiver Riddle   }
959f7b8a70eSRiver Riddle   return result;
960f7b8a70eSRiver Riddle }
961f7b8a70eSRiver Riddle 
962f7b8a70eSRiver Riddle //===----------------------------------------------------------------------===//
9638cbbc5d0SRiver Riddle // MLIRTextFileChunk
9648cbbc5d0SRiver Riddle //===----------------------------------------------------------------------===//
9658cbbc5d0SRiver Riddle 
9668cbbc5d0SRiver Riddle namespace {
9678cbbc5d0SRiver Riddle /// This class represents a single chunk of an MLIR text file.
9688cbbc5d0SRiver Riddle struct MLIRTextFileChunk {
MLIRTextFileChunk__anon05235aca0c11::MLIRTextFileChunk9690bd297fcSRiver Riddle   MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
9700bd297fcSRiver Riddle                     const lsp::URIForFile &uri, StringRef contents,
9718cbbc5d0SRiver Riddle                     std::vector<lsp::Diagnostic> &diagnostics)
9720bd297fcSRiver Riddle       : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
9738cbbc5d0SRiver Riddle 
9748cbbc5d0SRiver Riddle   /// Adjust the line number of the given range to anchor at the beginning of
9758cbbc5d0SRiver Riddle   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon05235aca0c11::MLIRTextFileChunk9768cbbc5d0SRiver Riddle   void adjustLocForChunkOffset(lsp::Range &range) {
9778cbbc5d0SRiver Riddle     adjustLocForChunkOffset(range.start);
9788cbbc5d0SRiver Riddle     adjustLocForChunkOffset(range.end);
9798cbbc5d0SRiver Riddle   }
9808cbbc5d0SRiver Riddle   /// Adjust the line number of the given position to anchor at the beginning of
9818cbbc5d0SRiver Riddle   /// the file, instead of the beginning of this chunk.
adjustLocForChunkOffset__anon05235aca0c11::MLIRTextFileChunk9828cbbc5d0SRiver Riddle   void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
9838cbbc5d0SRiver Riddle 
9848cbbc5d0SRiver Riddle   /// The line offset of this chunk from the beginning of the file.
9858cbbc5d0SRiver Riddle   uint64_t lineOffset;
9868cbbc5d0SRiver Riddle   /// The document referred to by this chunk.
9878cbbc5d0SRiver Riddle   MLIRDocument document;
9888cbbc5d0SRiver Riddle };
9898cbbc5d0SRiver Riddle } // namespace
9908cbbc5d0SRiver Riddle 
9918cbbc5d0SRiver Riddle //===----------------------------------------------------------------------===//
9928cbbc5d0SRiver Riddle // MLIRTextFile
9938cbbc5d0SRiver Riddle //===----------------------------------------------------------------------===//
9948cbbc5d0SRiver Riddle 
9958cbbc5d0SRiver Riddle namespace {
9968cbbc5d0SRiver Riddle /// This class represents a text file containing one or more MLIR documents.
9978cbbc5d0SRiver Riddle class MLIRTextFile {
9988cbbc5d0SRiver Riddle public:
9998cbbc5d0SRiver Riddle   MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
10008cbbc5d0SRiver Riddle                int64_t version, DialectRegistry &registry,
10018cbbc5d0SRiver Riddle                std::vector<lsp::Diagnostic> &diagnostics);
10028cbbc5d0SRiver Riddle 
10038cbbc5d0SRiver Riddle   /// Return the current version of this text file.
getVersion() const10048cbbc5d0SRiver Riddle   int64_t getVersion() const { return version; }
10058cbbc5d0SRiver Riddle 
10068cbbc5d0SRiver Riddle   //===--------------------------------------------------------------------===//
10078cbbc5d0SRiver Riddle   // LSP Queries
10088cbbc5d0SRiver Riddle   //===--------------------------------------------------------------------===//
10098cbbc5d0SRiver Riddle 
10108cbbc5d0SRiver Riddle   void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
10118cbbc5d0SRiver Riddle                       std::vector<lsp::Location> &locations);
10128cbbc5d0SRiver Riddle   void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
10138cbbc5d0SRiver Riddle                         std::vector<lsp::Location> &references);
10141da3a795SFangrui Song   std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
10158cbbc5d0SRiver Riddle                                       lsp::Position hoverPos);
1016ff81a2c9SRiver Riddle   void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017ed2fb173SRiver Riddle   lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018ed2fb173SRiver Riddle                                         lsp::Position completePos);
1019ed344c88SRiver Riddle   void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020ed344c88SRiver Riddle                       const lsp::CodeActionContext &context,
1021ed344c88SRiver Riddle                       std::vector<lsp::CodeAction> &actions);
1022f7b8a70eSRiver Riddle   llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
10238cbbc5d0SRiver Riddle 
10248cbbc5d0SRiver Riddle private:
10258cbbc5d0SRiver Riddle   /// Find the MLIR document that contains the given position, and update the
10268cbbc5d0SRiver Riddle   /// position to be anchored at the start of the found chunk instead of the
10278cbbc5d0SRiver Riddle   /// beginning of the file.
10288cbbc5d0SRiver Riddle   MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
10298cbbc5d0SRiver Riddle 
10300bd297fcSRiver Riddle   /// The context used to hold the state contained by the parsed document.
10310bd297fcSRiver Riddle   MLIRContext context;
10320bd297fcSRiver Riddle 
10338cbbc5d0SRiver Riddle   /// The full string contents of the file.
10348cbbc5d0SRiver Riddle   std::string contents;
10358cbbc5d0SRiver Riddle 
10368cbbc5d0SRiver Riddle   /// The version of this file.
10378cbbc5d0SRiver Riddle   int64_t version;
10388cbbc5d0SRiver Riddle 
1039ff81a2c9SRiver Riddle   /// The number of lines in the file.
1040671e30a1SMehdi Amini   int64_t totalNumLines = 0;
1041ff81a2c9SRiver Riddle 
10428cbbc5d0SRiver Riddle   /// The chunks of this file. The order of these chunks is the order in which
10438cbbc5d0SRiver Riddle   /// they appear in the text file.
10448cbbc5d0SRiver Riddle   std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
10458cbbc5d0SRiver Riddle };
10468cbbc5d0SRiver Riddle } // namespace
10478cbbc5d0SRiver Riddle 
MLIRTextFile(const lsp::URIForFile & uri,StringRef fileContents,int64_t version,DialectRegistry & registry,std::vector<lsp::Diagnostic> & diagnostics)10488cbbc5d0SRiver Riddle MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
10498cbbc5d0SRiver Riddle                            int64_t version, DialectRegistry &registry,
10508cbbc5d0SRiver Riddle                            std::vector<lsp::Diagnostic> &diagnostics)
10510bd297fcSRiver Riddle     : context(registry, MLIRContext::Threading::DISABLED),
1052671e30a1SMehdi Amini       contents(fileContents.str()), version(version) {
10530bd297fcSRiver Riddle   context.allowUnregisteredDialects();
10540bd297fcSRiver Riddle 
10558cbbc5d0SRiver Riddle   // Split the file into separate MLIR documents.
10568cbbc5d0SRiver Riddle   SmallVector<StringRef, 8> subContents;
1057516ccce7SIngo Müller   StringRef(contents).split(subContents, kDefaultSplitMarker);
10588cbbc5d0SRiver Riddle   chunks.emplace_back(std::make_unique<MLIRTextFileChunk>(
10590bd297fcSRiver Riddle       context, /*lineOffset=*/0, uri, subContents.front(), diagnostics));
10608cbbc5d0SRiver Riddle 
10618cbbc5d0SRiver Riddle   uint64_t lineOffset = subContents.front().count('\n');
10628cbbc5d0SRiver Riddle   for (StringRef docContents : llvm::drop_begin(subContents)) {
10638cbbc5d0SRiver Riddle     unsigned currentNumDiags = diagnostics.size();
10640bd297fcSRiver Riddle     auto chunk = std::make_unique<MLIRTextFileChunk>(context, lineOffset, uri,
10650bd297fcSRiver Riddle                                                      docContents, diagnostics);
10668cbbc5d0SRiver Riddle     lineOffset += docContents.count('\n');
10678cbbc5d0SRiver Riddle 
10688cbbc5d0SRiver Riddle     // Adjust locations used in diagnostics to account for the offset from the
10698cbbc5d0SRiver Riddle     // beginning of the file.
10708cbbc5d0SRiver Riddle     for (lsp::Diagnostic &diag :
10718cbbc5d0SRiver Riddle          llvm::drop_begin(diagnostics, currentNumDiags)) {
10728cbbc5d0SRiver Riddle       chunk->adjustLocForChunkOffset(diag.range);
10738cbbc5d0SRiver Riddle 
10748cbbc5d0SRiver Riddle       if (!diag.relatedInformation)
10758cbbc5d0SRiver Riddle         continue;
10768cbbc5d0SRiver Riddle       for (auto &it : *diag.relatedInformation)
10778cbbc5d0SRiver Riddle         if (it.location.uri == uri)
10788cbbc5d0SRiver Riddle           chunk->adjustLocForChunkOffset(it.location.range);
10798cbbc5d0SRiver Riddle     }
10808cbbc5d0SRiver Riddle     chunks.emplace_back(std::move(chunk));
10818cbbc5d0SRiver Riddle   }
1082ff81a2c9SRiver Riddle   totalNumLines = lineOffset;
10838cbbc5d0SRiver Riddle }
10848cbbc5d0SRiver Riddle 
getLocationsOf(const lsp::URIForFile & uri,lsp::Position defPos,std::vector<lsp::Location> & locations)10858cbbc5d0SRiver Riddle void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
10868cbbc5d0SRiver Riddle                                   lsp::Position defPos,
10878cbbc5d0SRiver Riddle                                   std::vector<lsp::Location> &locations) {
10888cbbc5d0SRiver Riddle   MLIRTextFileChunk &chunk = getChunkFor(defPos);
10898cbbc5d0SRiver Riddle   chunk.document.getLocationsOf(uri, defPos, locations);
10908cbbc5d0SRiver Riddle 
10918cbbc5d0SRiver Riddle   // Adjust any locations within this file for the offset of this chunk.
10928cbbc5d0SRiver Riddle   if (chunk.lineOffset == 0)
10938cbbc5d0SRiver Riddle     return;
10948cbbc5d0SRiver Riddle   for (lsp::Location &loc : locations)
10958cbbc5d0SRiver Riddle     if (loc.uri == uri)
10968cbbc5d0SRiver Riddle       chunk.adjustLocForChunkOffset(loc.range);
10978cbbc5d0SRiver Riddle }
10988cbbc5d0SRiver Riddle 
findReferencesOf(const lsp::URIForFile & uri,lsp::Position pos,std::vector<lsp::Location> & references)10998cbbc5d0SRiver Riddle void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
11008cbbc5d0SRiver Riddle                                     lsp::Position pos,
11018cbbc5d0SRiver Riddle                                     std::vector<lsp::Location> &references) {
11028cbbc5d0SRiver Riddle   MLIRTextFileChunk &chunk = getChunkFor(pos);
11038cbbc5d0SRiver Riddle   chunk.document.findReferencesOf(uri, pos, references);
11048cbbc5d0SRiver Riddle 
11058cbbc5d0SRiver Riddle   // Adjust any locations within this file for the offset of this chunk.
11068cbbc5d0SRiver Riddle   if (chunk.lineOffset == 0)
11078cbbc5d0SRiver Riddle     return;
11088cbbc5d0SRiver Riddle   for (lsp::Location &loc : references)
11098cbbc5d0SRiver Riddle     if (loc.uri == uri)
11108cbbc5d0SRiver Riddle       chunk.adjustLocForChunkOffset(loc.range);
11118cbbc5d0SRiver Riddle }
11128cbbc5d0SRiver Riddle 
findHover(const lsp::URIForFile & uri,lsp::Position hoverPos)11131da3a795SFangrui Song std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
11148cbbc5d0SRiver Riddle                                                   lsp::Position hoverPos) {
11158cbbc5d0SRiver Riddle   MLIRTextFileChunk &chunk = getChunkFor(hoverPos);
11161da3a795SFangrui Song   std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
11178cbbc5d0SRiver Riddle 
11188cbbc5d0SRiver Riddle   // Adjust any locations within this file for the offset of this chunk.
11198cbbc5d0SRiver Riddle   if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
11208cbbc5d0SRiver Riddle     chunk.adjustLocForChunkOffset(*hoverInfo->range);
11218cbbc5d0SRiver Riddle   return hoverInfo;
11228cbbc5d0SRiver Riddle }
11238cbbc5d0SRiver Riddle 
findDocumentSymbols(std::vector<lsp::DocumentSymbol> & symbols)1124ff81a2c9SRiver Riddle void MLIRTextFile::findDocumentSymbols(
1125ff81a2c9SRiver Riddle     std::vector<lsp::DocumentSymbol> &symbols) {
1126ff81a2c9SRiver Riddle   if (chunks.size() == 1)
1127ff81a2c9SRiver Riddle     return chunks.front()->document.findDocumentSymbols(symbols);
1128ff81a2c9SRiver Riddle 
1129ff81a2c9SRiver Riddle   // If there are multiple chunks in this file, we create top-level symbols for
1130ff81a2c9SRiver Riddle   // each chunk.
1131ff81a2c9SRiver Riddle   for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132ff81a2c9SRiver Riddle     MLIRTextFileChunk &chunk = *chunks[i];
1133ff81a2c9SRiver Riddle     lsp::Position startPos(chunk.lineOffset);
1134ff81a2c9SRiver Riddle     lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135ff81a2c9SRiver Riddle                                       : chunks[i + 1]->lineOffset);
1136ff81a2c9SRiver Riddle     lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137ff81a2c9SRiver Riddle                                lsp::SymbolKind::Namespace,
1138ff81a2c9SRiver Riddle                                /*range=*/lsp::Range(startPos, endPos),
1139ff81a2c9SRiver Riddle                                /*selectionRange=*/lsp::Range(startPos));
1140ff81a2c9SRiver Riddle     chunk.document.findDocumentSymbols(symbol.children);
1141ff81a2c9SRiver Riddle 
1142ff81a2c9SRiver Riddle     // Fixup the locations of document symbols within this chunk.
1143ff81a2c9SRiver Riddle     if (i != 0) {
1144ff81a2c9SRiver Riddle       SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145ff81a2c9SRiver Riddle       for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146ff81a2c9SRiver Riddle         symbolsToFix.push_back(&childSymbol);
1147ff81a2c9SRiver Riddle 
1148ff81a2c9SRiver Riddle       while (!symbolsToFix.empty()) {
1149ff81a2c9SRiver Riddle         lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150ff81a2c9SRiver Riddle         chunk.adjustLocForChunkOffset(symbol->range);
1151ff81a2c9SRiver Riddle         chunk.adjustLocForChunkOffset(symbol->selectionRange);
1152ff81a2c9SRiver Riddle 
1153ff81a2c9SRiver Riddle         for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154ff81a2c9SRiver Riddle           symbolsToFix.push_back(&childSymbol);
1155ff81a2c9SRiver Riddle       }
1156ff81a2c9SRiver Riddle     }
1157ff81a2c9SRiver Riddle 
1158ff81a2c9SRiver Riddle     // Push the symbol for this chunk.
1159ff81a2c9SRiver Riddle     symbols.emplace_back(std::move(symbol));
1160ff81a2c9SRiver Riddle   }
1161ff81a2c9SRiver Riddle }
1162ff81a2c9SRiver Riddle 
/// Compute completion results at the file-relative `completePos`, rebasing
/// every edit attached to the returned items onto file-relative lines.
lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
                                                    lsp::Position completePos) {
  MLIRTextFileChunk &chunk = getChunkFor(completePos);
  lsp::CompletionList completions = chunk.document.getCodeCompletion(
      uri, completePos, context.getDialectRegistry());

  // Edits carried by completion items are chunk-relative.
  for (lsp::CompletionItem &item : completions.items) {
    for (lsp::TextEdit &extraEdit : item.additionalTextEdits)
      chunk.adjustLocForChunkOffset(extraEdit.range);
    if (item.textEdit)
      chunk.adjustLocForChunkOffset(item.textEdit->range);
  }
  return completions;
}
1178ed2fb173SRiver Riddle 
getCodeActions(const lsp::URIForFile & uri,const lsp::Range & pos,const lsp::CodeActionContext & context,std::vector<lsp::CodeAction> & actions)1179ed344c88SRiver Riddle void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180ed344c88SRiver Riddle                                   const lsp::Range &pos,
1181ed344c88SRiver Riddle                                   const lsp::CodeActionContext &context,
1182ed344c88SRiver Riddle                                   std::vector<lsp::CodeAction> &actions) {
1183ed344c88SRiver Riddle   // Create actions for any diagnostics in this file.
1184ed344c88SRiver Riddle   for (auto &diag : context.diagnostics) {
1185ed344c88SRiver Riddle     if (diag.source != "mlir")
1186ed344c88SRiver Riddle       continue;
1187ed344c88SRiver Riddle     lsp::Position diagPos = diag.range.start;
1188ed344c88SRiver Riddle     MLIRTextFileChunk &chunk = getChunkFor(diagPos);
1189ed344c88SRiver Riddle 
1190ed344c88SRiver Riddle     // Add a new code action that inserts a "expected" diagnostic check.
1191ed344c88SRiver Riddle     lsp::CodeAction action;
1192ed344c88SRiver Riddle     action.title = "Add expected-* diagnostic checks";
1193ed344c88SRiver Riddle     action.kind = lsp::CodeAction::kQuickFix.str();
1194ed344c88SRiver Riddle 
1195ed344c88SRiver Riddle     StringRef severity;
1196ed344c88SRiver Riddle     switch (diag.severity) {
1197ed344c88SRiver Riddle     case lsp::DiagnosticSeverity::Error:
1198ed344c88SRiver Riddle       severity = "error";
1199ed344c88SRiver Riddle       break;
1200ed344c88SRiver Riddle     case lsp::DiagnosticSeverity::Warning:
1201ed344c88SRiver Riddle       severity = "warning";
1202ed344c88SRiver Riddle       break;
1203ed344c88SRiver Riddle     default:
1204ed344c88SRiver Riddle       continue;
1205ed344c88SRiver Riddle     }
1206ed344c88SRiver Riddle 
1207ed344c88SRiver Riddle     // Get edits for the diagnostic.
1208ed344c88SRiver Riddle     std::vector<lsp::TextEdit> edits;
1209ed344c88SRiver Riddle     chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
1210ed344c88SRiver Riddle                                               diag.message, edits);
1211ed344c88SRiver Riddle 
1212ed344c88SRiver Riddle     // Walk the related diagnostics, this is how we encode notes.
1213ed344c88SRiver Riddle     if (diag.relatedInformation) {
1214ed344c88SRiver Riddle       for (auto &noteDiag : *diag.relatedInformation) {
1215ed344c88SRiver Riddle         if (noteDiag.location.uri != uri)
1216ed344c88SRiver Riddle           continue;
1217ed344c88SRiver Riddle         diagPos = noteDiag.location.range.start;
1218ed344c88SRiver Riddle         diagPos.line -= chunk.lineOffset;
1219ed344c88SRiver Riddle         chunk.document.getCodeActionForDiagnostic(uri, diagPos, "note",
1220ed344c88SRiver Riddle                                                   noteDiag.message, edits);
1221ed344c88SRiver Riddle       }
1222ed344c88SRiver Riddle     }
1223ed344c88SRiver Riddle     // Fixup the locations for any edits.
1224ed344c88SRiver Riddle     for (lsp::TextEdit &edit : edits)
1225ed344c88SRiver Riddle       chunk.adjustLocForChunkOffset(edit.range);
1226ed344c88SRiver Riddle 
1227ed344c88SRiver Riddle     action.edit.emplace();
1228ed344c88SRiver Riddle     action.edit->changes[uri.uri().str()] = std::move(edits);
1229ed344c88SRiver Riddle     action.diagnostics = {diag};
1230ed344c88SRiver Riddle 
1231ed344c88SRiver Riddle     actions.emplace_back(std::move(action));
1232ed344c88SRiver Riddle   }
1233ed344c88SRiver Riddle }
1234ed344c88SRiver Riddle 
1235f7b8a70eSRiver Riddle llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode()1236f7b8a70eSRiver Riddle MLIRTextFile::convertToBytecode() {
1237f7b8a70eSRiver Riddle   // Bail out if there is more than one chunk, bytecode wants a single module.
1238f7b8a70eSRiver Riddle   if (chunks.size() != 1) {
1239f7b8a70eSRiver Riddle     return llvm::make_error<lsp::LSPError>(
1240f7b8a70eSRiver Riddle         "unexpected split file, please remove all `// -----`",
1241f7b8a70eSRiver Riddle         lsp::ErrorCode::RequestFailed);
1242f7b8a70eSRiver Riddle   }
1243f7b8a70eSRiver Riddle   return chunks.front()->document.convertToBytecode();
1244f7b8a70eSRiver Riddle }
1245f7b8a70eSRiver Riddle 
getChunkFor(lsp::Position & pos)12468cbbc5d0SRiver Riddle MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
12478cbbc5d0SRiver Riddle   if (chunks.size() == 1)
12488cbbc5d0SRiver Riddle     return *chunks.front();
12498cbbc5d0SRiver Riddle 
12508cbbc5d0SRiver Riddle   // Search for the first chunk with a greater line offset, the previous chunk
12518cbbc5d0SRiver Riddle   // is the one that contains `pos`.
12528cbbc5d0SRiver Riddle   auto it = llvm::upper_bound(
12538cbbc5d0SRiver Riddle       chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
12548cbbc5d0SRiver Riddle         return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
12558cbbc5d0SRiver Riddle       });
12568cbbc5d0SRiver Riddle   MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
12578cbbc5d0SRiver Riddle   pos.line -= chunk.lineOffset;
12588cbbc5d0SRiver Riddle   return chunk;
12598cbbc5d0SRiver Riddle }
12608cbbc5d0SRiver Riddle 
12618cbbc5d0SRiver Riddle //===----------------------------------------------------------------------===//
1262751c14fcSRiver Riddle // MLIRServer::Impl
1263751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
1264751c14fcSRiver Riddle 
1265751c14fcSRiver Riddle struct lsp::MLIRServer::Impl {
Impllsp::MLIRServer::Impl1266751c14fcSRiver Riddle   Impl(DialectRegistry &registry) : registry(registry) {}
1267751c14fcSRiver Riddle 
1268751c14fcSRiver Riddle   /// The registry containing dialects that can be recognized in parsed .mlir
1269751c14fcSRiver Riddle   /// files.
1270751c14fcSRiver Riddle   DialectRegistry &registry;
1271751c14fcSRiver Riddle 
12728cbbc5d0SRiver Riddle   /// The files held by the server, mapped by their URI file name.
12738cbbc5d0SRiver Riddle   llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1274751c14fcSRiver Riddle };
1275751c14fcSRiver Riddle 
1276751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
1277751c14fcSRiver Riddle // MLIRServer
1278751c14fcSRiver Riddle //===----------------------------------------------------------------------===//
1279751c14fcSRiver Riddle 
MLIRServer(DialectRegistry & registry)1280751c14fcSRiver Riddle lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1281751c14fcSRiver Riddle     : impl(std::make_unique<Impl>(registry)) {}
1282e5639b3fSMehdi Amini lsp::MLIRServer::~MLIRServer() = default;
1283751c14fcSRiver Riddle 
addOrUpdateDocument(const URIForFile & uri,StringRef contents,int64_t version,std::vector<Diagnostic> & diagnostics)1284b3911cdfSRiver Riddle void lsp::MLIRServer::addOrUpdateDocument(
1285f9ea3ebeSRiver Riddle     const URIForFile &uri, StringRef contents, int64_t version,
1286b3911cdfSRiver Riddle     std::vector<Diagnostic> &diagnostics) {
12878cbbc5d0SRiver Riddle   impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288f9ea3ebeSRiver Riddle       uri, contents, version, impl->registry, diagnostics);
1289751c14fcSRiver Riddle }
1290751c14fcSRiver Riddle 
removeDocument(const URIForFile & uri)12910a81ace0SKazu Hirata std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
12928cbbc5d0SRiver Riddle   auto it = impl->files.find(uri.file());
12938cbbc5d0SRiver Riddle   if (it == impl->files.end())
12941a36588eSKazu Hirata     return std::nullopt;
1295f9ea3ebeSRiver Riddle 
12968cbbc5d0SRiver Riddle   int64_t version = it->second->getVersion();
12978cbbc5d0SRiver Riddle   impl->files.erase(it);
1298f9ea3ebeSRiver Riddle   return version;
1299751c14fcSRiver Riddle }
1300751c14fcSRiver Riddle 
getLocationsOf(const URIForFile & uri,const Position & defPos,std::vector<Location> & locations)1301751c14fcSRiver Riddle void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
1302751c14fcSRiver Riddle                                      const Position &defPos,
1303751c14fcSRiver Riddle                                      std::vector<Location> &locations) {
13048cbbc5d0SRiver Riddle   auto fileIt = impl->files.find(uri.file());
13058cbbc5d0SRiver Riddle   if (fileIt != impl->files.end())
1306751c14fcSRiver Riddle     fileIt->second->getLocationsOf(uri, defPos, locations);
1307751c14fcSRiver Riddle }
1308751c14fcSRiver Riddle 
findReferencesOf(const URIForFile & uri,const Position & pos,std::vector<Location> & references)1309751c14fcSRiver Riddle void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1310751c14fcSRiver Riddle                                        const Position &pos,
1311751c14fcSRiver Riddle                                        std::vector<Location> &references) {
13128cbbc5d0SRiver Riddle   auto fileIt = impl->files.find(uri.file());
13138cbbc5d0SRiver Riddle   if (fileIt != impl->files.end())
1314751c14fcSRiver Riddle     fileIt->second->findReferencesOf(uri, pos, references);
1315751c14fcSRiver Riddle }
13165c84195bSRiver Riddle 
findHover(const URIForFile & uri,const Position & hoverPos)13171da3a795SFangrui Song std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
13185c84195bSRiver Riddle                                                      const Position &hoverPos) {
13198cbbc5d0SRiver Riddle   auto fileIt = impl->files.find(uri.file());
13208cbbc5d0SRiver Riddle   if (fileIt != impl->files.end())
13215c84195bSRiver Riddle     return fileIt->second->findHover(uri, hoverPos);
13221a36588eSKazu Hirata   return std::nullopt;
13235c84195bSRiver Riddle }
1324ff81a2c9SRiver Riddle 
findDocumentSymbols(const URIForFile & uri,std::vector<DocumentSymbol> & symbols)1325ff81a2c9SRiver Riddle void lsp::MLIRServer::findDocumentSymbols(
1326ff81a2c9SRiver Riddle     const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327ff81a2c9SRiver Riddle   auto fileIt = impl->files.find(uri.file());
1328ff81a2c9SRiver Riddle   if (fileIt != impl->files.end())
1329ff81a2c9SRiver Riddle     fileIt->second->findDocumentSymbols(symbols);
1330ff81a2c9SRiver Riddle }
1331ed2fb173SRiver Riddle 
1332ed2fb173SRiver Riddle lsp::CompletionList
getCodeCompletion(const URIForFile & uri,const Position & completePos)1333ed2fb173SRiver Riddle lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1334ed2fb173SRiver Riddle                                    const Position &completePos) {
1335ed2fb173SRiver Riddle   auto fileIt = impl->files.find(uri.file());
1336ed2fb173SRiver Riddle   if (fileIt != impl->files.end())
1337ed2fb173SRiver Riddle     return fileIt->second->getCodeCompletion(uri, completePos);
1338ed2fb173SRiver Riddle   return CompletionList();
1339ed2fb173SRiver Riddle }
1340ed344c88SRiver Riddle 
getCodeActions(const URIForFile & uri,const Range & pos,const CodeActionContext & context,std::vector<CodeAction> & actions)1341ed344c88SRiver Riddle void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1342ed344c88SRiver Riddle                                      const CodeActionContext &context,
1343ed344c88SRiver Riddle                                      std::vector<CodeAction> &actions) {
1344ed344c88SRiver Riddle   auto fileIt = impl->files.find(uri.file());
1345ed344c88SRiver Riddle   if (fileIt != impl->files.end())
1346ed344c88SRiver Riddle     fileIt->second->getCodeActions(uri, pos, context, actions);
1347ed344c88SRiver Riddle }
1348f7b8a70eSRiver Riddle 
1349f7b8a70eSRiver Riddle llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertFromBytecode(const URIForFile & uri)1350f7b8a70eSRiver Riddle lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
1351f7b8a70eSRiver Riddle   MLIRContext tempContext(impl->registry);
1352f7b8a70eSRiver Riddle   tempContext.allowUnregisteredDialects();
1353f7b8a70eSRiver Riddle 
1354f7b8a70eSRiver Riddle   // Collect any errors during parsing.
1355f7b8a70eSRiver Riddle   std::string errorMsg;
1356f7b8a70eSRiver Riddle   ScopedDiagnosticHandler diagHandler(
1357f7b8a70eSRiver Riddle       &tempContext,
1358f7b8a70eSRiver Riddle       [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1359f7b8a70eSRiver Riddle 
136034300ee3SRiver Riddle   // Handling for external resources, which we want to propagate up to the user.
136134300ee3SRiver Riddle   FallbackAsmResourceMap fallbackResourceMap;
136234300ee3SRiver Riddle 
136334300ee3SRiver Riddle   // Setup the parser config.
13641ae60e04SRiver Riddle   ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
13651ae60e04SRiver Riddle                             &fallbackResourceMap);
136634300ee3SRiver Riddle 
1367f7b8a70eSRiver Riddle   // Try to parse the given source file.
1368f7b8a70eSRiver Riddle   Block parsedBlock;
136934300ee3SRiver Riddle   if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
1370f7b8a70eSRiver Riddle     return llvm::make_error<lsp::LSPError>(
1371f7b8a70eSRiver Riddle         "failed to parse bytecode source file: " + errorMsg,
1372f7b8a70eSRiver Riddle         lsp::ErrorCode::RequestFailed);
1373f7b8a70eSRiver Riddle   }
1374f7b8a70eSRiver Riddle 
1375f7b8a70eSRiver Riddle   // TODO: We currently expect a single top-level operation, but this could
1376f7b8a70eSRiver Riddle   // conceptually be relaxed.
1377f7b8a70eSRiver Riddle   if (!llvm::hasSingleElement(parsedBlock)) {
1378f7b8a70eSRiver Riddle     return llvm::make_error<lsp::LSPError>(
1379f7b8a70eSRiver Riddle         "expected bytecode to contain a single top-level operation",
1380f7b8a70eSRiver Riddle         lsp::ErrorCode::RequestFailed);
1381f7b8a70eSRiver Riddle   }
1382f7b8a70eSRiver Riddle 
1383f7b8a70eSRiver Riddle   // Print the module to a buffer.
1384f7b8a70eSRiver Riddle   lsp::MLIRConvertBytecodeResult result;
1385f7b8a70eSRiver Riddle   {
1386f7b8a70eSRiver Riddle     // Extract the top-level op so that aliases get printed.
1387f7b8a70eSRiver Riddle     // FIXME: We should be able to enable aliases without having to do this!
1388f7b8a70eSRiver Riddle     OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389e84589c9SMehdi Amini     topOp->remove();
1390f7b8a70eSRiver Riddle 
139134300ee3SRiver Riddle     AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
139234300ee3SRiver Riddle                    /*locationMap=*/nullptr, &fallbackResourceMap);
139334300ee3SRiver Riddle 
1394f7b8a70eSRiver Riddle     llvm::raw_string_ostream os(result.output);
1395e84589c9SMehdi Amini     topOp->print(os, state);
1396f7b8a70eSRiver Riddle   }
1397f7b8a70eSRiver Riddle   return std::move(result);
1398f7b8a70eSRiver Riddle }
1399f7b8a70eSRiver Riddle 
1400f7b8a70eSRiver Riddle llvm::Expected<lsp::MLIRConvertBytecodeResult>
convertToBytecode(const URIForFile & uri)1401f7b8a70eSRiver Riddle lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
1402f7b8a70eSRiver Riddle   auto fileIt = impl->files.find(uri.file());
1403f7b8a70eSRiver Riddle   if (fileIt == impl->files.end()) {
1404f7b8a70eSRiver Riddle     return llvm::make_error<lsp::LSPError>(
1405f7b8a70eSRiver Riddle         "language server does not contain an entry for this source file",
1406f7b8a70eSRiver Riddle         lsp::ErrorCode::RequestFailed);
1407f7b8a70eSRiver Riddle   }
1408f7b8a70eSRiver Riddle   return fileIt->second->convertToBytecode();
1409f7b8a70eSRiver Riddle }
1410