xref: /llvm-project/mlir/lib/IR/Diagnostics.cpp (revision 8815c505be90edf0168e931d77f2b68e393031d3)
1 //===- Diagnostics.cpp - MLIR Diagnostics ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Diagnostics.h"
10 #include "mlir/IR/Attributes.h"
11 #include "mlir/IR/Location.h"
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Mutex.h"
20 #include "llvm/Support/PrettyStackTrace.h"
21 #include "llvm/Support/Regex.h"
22 #include "llvm/Support/Signals.h"
23 #include "llvm/Support/SourceMgr.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include <optional>
26 
27 using namespace mlir;
28 using namespace mlir::detail;
29 
30 //===----------------------------------------------------------------------===//
31 // DiagnosticArgument
32 //===----------------------------------------------------------------------===//
33 
34 /// Construct from an Attribute.
35 DiagnosticArgument::DiagnosticArgument(Attribute attr)
36     : kind(DiagnosticArgumentKind::Attribute),
37       opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}
38 
39 /// Construct from a Type.
40 DiagnosticArgument::DiagnosticArgument(Type val)
41     : kind(DiagnosticArgumentKind::Type),
42       opaqueVal(reinterpret_cast<intptr_t>(val.getAsOpaquePointer())) {}
43 
44 /// Returns this argument as an Attribute.
45 Attribute DiagnosticArgument::getAsAttribute() const {
46   assert(getKind() == DiagnosticArgumentKind::Attribute);
47   return Attribute::getFromOpaquePointer(
48       reinterpret_cast<const void *>(opaqueVal));
49 }
50 
51 /// Returns this argument as a Type.
52 Type DiagnosticArgument::getAsType() const {
53   assert(getKind() == DiagnosticArgumentKind::Type);
54   return Type::getFromOpaquePointer(reinterpret_cast<const void *>(opaqueVal));
55 }
56 
57 /// Outputs this argument to a stream.
58 void DiagnosticArgument::print(raw_ostream &os) const {
59   switch (kind) {
60   case DiagnosticArgumentKind::Attribute:
61     os << getAsAttribute();
62     break;
63   case DiagnosticArgumentKind::Double:
64     os << getAsDouble();
65     break;
66   case DiagnosticArgumentKind::Integer:
67     os << getAsInteger();
68     break;
69   case DiagnosticArgumentKind::String:
70     os << getAsString();
71     break;
72   case DiagnosticArgumentKind::Type:
73     os << '\'' << getAsType() << '\'';
74     break;
75   case DiagnosticArgumentKind::Unsigned:
76     os << getAsUnsigned();
77     break;
78   }
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // Diagnostic
83 //===----------------------------------------------------------------------===//
84 
85 /// Convert a Twine to a StringRef. Memory used for generating the StringRef is
86 /// stored in 'strings'.
87 static StringRef twineToStrRef(const Twine &val,
88                                std::vector<std::unique_ptr<char[]>> &strings) {
89   // Allocate memory to hold this string.
90   SmallString<64> data;
91   auto strRef = val.toStringRef(data);
92   if (strRef.empty())
93     return strRef;
94 
95   strings.push_back(std::unique_ptr<char[]>(new char[strRef.size()]));
96   memcpy(&strings.back()[0], strRef.data(), strRef.size());
97   // Return a reference to the new string.
98   return StringRef(&strings.back()[0], strRef.size());
99 }
100 
101 /// Stream in a Twine argument.
102 Diagnostic &Diagnostic::operator<<(char val) { return *this << Twine(val); }
103 Diagnostic &Diagnostic::operator<<(const Twine &val) {
104   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
105   return *this;
106 }
107 Diagnostic &Diagnostic::operator<<(Twine &&val) {
108   arguments.push_back(DiagnosticArgument(twineToStrRef(val, strings)));
109   return *this;
110 }
111 
112 Diagnostic &Diagnostic::operator<<(StringAttr val) {
113   arguments.push_back(DiagnosticArgument(val));
114   return *this;
115 }
116 
117 /// Stream in an OperationName.
118 Diagnostic &Diagnostic::operator<<(OperationName val) {
119   // An OperationName is stored in the context, so we don't need to worry about
120   // the lifetime of its data.
121   arguments.push_back(DiagnosticArgument(val.getStringRef()));
122   return *this;
123 }
124 
125 /// Adjusts operation printing flags used in diagnostics for the given severity
126 /// level.
127 static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
128                                            DiagnosticSeverity severity) {
129   flags.useLocalScope();
130   flags.elideLargeElementsAttrs();
131   if (severity == DiagnosticSeverity::Error)
132     flags.printGenericOpForm();
133   return flags;
134 }
135 
136 /// Stream in an Operation.
137 Diagnostic &Diagnostic::operator<<(Operation &op) {
138   return appendOp(op, OpPrintingFlags());
139 }
140 
141 Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
142   std::string str;
143   llvm::raw_string_ostream os(str);
144   op.print(os, adjustPrintingFlags(flags, severity));
145   // Print on a new line for better readability if the op will be printed on
146   // multiple lines.
147   if (str.find('\n') != std::string::npos)
148     *this << '\n';
149   return *this << str;
150 }
151 
152 /// Stream in a Value.
153 Diagnostic &Diagnostic::operator<<(Value val) {
154   std::string str;
155   llvm::raw_string_ostream os(str);
156   val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
157   return *this << str;
158 }
159 
160 /// Outputs this diagnostic to a stream.
161 void Diagnostic::print(raw_ostream &os) const {
162   for (auto &arg : getArguments())
163     arg.print(os);
164 }
165 
166 /// Convert the diagnostic to a string.
167 std::string Diagnostic::str() const {
168   std::string str;
169   llvm::raw_string_ostream os(str);
170   print(os);
171   return str;
172 }
173 
174 /// Attaches a note to this diagnostic. A new location may be optionally
175 /// provided, if not, then the location defaults to the one specified for this
176 /// diagnostic. Notes may not be attached to other notes.
177 Diagnostic &Diagnostic::attachNote(std::optional<Location> noteLoc) {
178   // We don't allow attaching notes to notes.
179   assert(severity != DiagnosticSeverity::Note &&
180          "cannot attach a note to a note");
181 
182   // If a location wasn't provided then reuse our location.
183   if (!noteLoc)
184     noteLoc = loc;
185 
186   /// Append and return a new note.
187   notes.push_back(
188       std::make_unique<Diagnostic>(*noteLoc, DiagnosticSeverity::Note));
189   return *notes.back();
190 }
191 
192 /// Allow a diagnostic to be converted to 'failure'.
193 Diagnostic::operator LogicalResult() const { return failure(); }
194 
195 //===----------------------------------------------------------------------===//
196 // InFlightDiagnostic
197 //===----------------------------------------------------------------------===//
198 
199 /// Allow an inflight diagnostic to be converted to 'failure', otherwise
200 /// 'success' if this is an empty diagnostic.
201 InFlightDiagnostic::operator LogicalResult() const {
202   return failure(isActive());
203 }
204 
205 /// Reports the diagnostic to the engine.
206 void InFlightDiagnostic::report() {
207   // If this diagnostic is still inflight and it hasn't been abandoned, then
208   // report it.
209   if (isInFlight()) {
210     owner->emit(std::move(*impl));
211     owner = nullptr;
212   }
213   impl.reset();
214 }
215 
216 /// Abandons this diagnostic.
217 void InFlightDiagnostic::abandon() { owner = nullptr; }
218 
219 //===----------------------------------------------------------------------===//
220 // DiagnosticEngineImpl
221 //===----------------------------------------------------------------------===//
222 
namespace mlir {
namespace detail {
/// Private state of a DiagnosticEngine: the registered handlers, the counter
/// used to assign handler ids, and the mutex that guards both.
struct DiagnosticEngineImpl {
  /// Emit a diagnostic using the registered issue handle if present, or with
  /// the default behavior if not.
  void emit(Diagnostic &&diag);

  /// A mutex to ensure that diagnostics emission is thread-safe. It is also
  /// taken when handlers are registered or erased.
  llvm::sys::SmartMutex<true> mutex;

  /// These are the handlers used to report diagnostics, keyed by their id.
  /// A MapVector keeps registration order, which emit() walks in reverse so
  /// that the most recently registered handler is tried first.
  llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
                       2>
      handlers;

  /// This is a unique identifier counter for diagnostic handlers in the
  /// context. This id starts at 1 to allow for 0 to be used as a sentinel.
  DiagnosticEngine::HandlerID uniqueHandlerId = 1;
};
} // namespace detail
} // namespace mlir
244 
245 /// Emit a diagnostic using the registered issue handle if present, or with
246 /// the default behavior if not.
247 void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
248   llvm::sys::SmartScopedLock<true> lock(mutex);
249 
250   // Try to process the given diagnostic on one of the registered handlers.
251   // Handlers are walked in reverse order, so that the most recent handler is
252   // processed first.
253   for (auto &handlerIt : llvm::reverse(handlers))
254     if (succeeded(handlerIt.second(diag)))
255       return;
256 
257   // Otherwise, if this is an error we emit it to stderr.
258   if (diag.getSeverity() != DiagnosticSeverity::Error)
259     return;
260 
261   auto &os = llvm::errs();
262   if (!llvm::isa<UnknownLoc>(diag.getLocation()))
263     os << diag.getLocation() << ": ";
264   os << "error: ";
265 
266   // The default behavior for errors is to emit them to stderr.
267   os << diag << '\n';
268   os.flush();
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // DiagnosticEngine
273 //===----------------------------------------------------------------------===//
274 
275 DiagnosticEngine::DiagnosticEngine() : impl(new DiagnosticEngineImpl()) {}
276 DiagnosticEngine::~DiagnosticEngine() = default;
277 
278 /// Register a new handler for diagnostics to the engine. This function returns
279 /// a unique identifier for the registered handler, which can be used to
280 /// unregister this handler at a later time.
281 auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
282   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
283   auto uniqueID = impl->uniqueHandlerId++;
284   impl->handlers.insert({uniqueID, std::move(handler)});
285   return uniqueID;
286 }
287 
288 /// Erase the registered diagnostic handler with the given identifier.
289 void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
290   llvm::sys::SmartScopedLock<true> lock(impl->mutex);
291   impl->handlers.erase(handlerID);
292 }
293 
294 /// Emit a diagnostic using the registered issue handler if present, or with
295 /// the default behavior if not.
296 void DiagnosticEngine::emit(Diagnostic &&diag) {
297   assert(diag.getSeverity() != DiagnosticSeverity::Note &&
298          "notes should not be emitted directly");
299   impl->emit(std::move(diag));
300 }
301 
302 /// Helper function used to emit a diagnostic with an optionally empty twine
303 /// message. If the message is empty, then it is not inserted into the
304 /// diagnostic.
305 static InFlightDiagnostic
306 emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
307   MLIRContext *ctx = location->getContext();
308   auto &diagEngine = ctx->getDiagEngine();
309   auto diag = diagEngine.emit(location, severity);
310   if (!message.isTriviallyEmpty())
311     diag << message;
312 
313   // Add the stack trace as a note if necessary.
314   if (ctx->shouldPrintStackTraceOnDiagnostic()) {
315     std::string bt;
316     {
317       llvm::raw_string_ostream stream(bt);
318       llvm::sys::PrintStackTrace(stream);
319     }
320     if (!bt.empty())
321       diag.attachNote() << "diagnostic emitted with trace:\n" << bt;
322   }
323 
324   return diag;
325 }
326 
327 /// Emit an error message using this location.
328 InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); }
329 InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
330   return emitDiag(loc, DiagnosticSeverity::Error, message);
331 }
332 
333 /// Emit a warning message using this location.
334 InFlightDiagnostic mlir::emitWarning(Location loc) {
335   return emitWarning(loc, {});
336 }
337 InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
338   return emitDiag(loc, DiagnosticSeverity::Warning, message);
339 }
340 
341 /// Emit a remark message using this location.
342 InFlightDiagnostic mlir::emitRemark(Location loc) {
343   return emitRemark(loc, {});
344 }
345 InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
346   return emitDiag(loc, DiagnosticSeverity::Remark, message);
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // ScopedDiagnosticHandler
351 //===----------------------------------------------------------------------===//
352 
353 ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
354   if (handlerID)
355     ctx->getDiagEngine().eraseHandler(handlerID);
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // SourceMgrDiagnosticHandler
360 //===----------------------------------------------------------------------===//
361 namespace mlir {
362 namespace detail {
363 struct SourceMgrDiagnosticHandlerImpl {
364   /// Return the SrcManager buffer id for the specified file, or zero if none
365   /// can be found.
366   unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
367                                        StringRef filename) {
368     // Check for an existing mapping to the buffer id for this file.
369     auto bufferIt = filenameToBufId.find(filename);
370     if (bufferIt != filenameToBufId.end())
371       return bufferIt->second;
372 
373     // Look for a buffer in the manager that has this filename.
374     for (unsigned i = 1, e = mgr.getNumBuffers() + 1; i != e; ++i) {
375       auto *buf = mgr.getMemoryBuffer(i);
376       if (buf->getBufferIdentifier() == filename)
377         return filenameToBufId[filename] = i;
378     }
379 
380     // Otherwise, try to load the source file.
381     std::string ignored;
382     unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
383     filenameToBufId[filename] = id;
384     return id;
385   }
386 
387   /// Mapping between file name and buffer ID's.
388   llvm::StringMap<unsigned> filenameToBufId;
389 };
390 } // namespace detail
391 } // namespace mlir
392 
393 /// Return a processable CallSiteLoc from the given location.
394 static std::optional<CallSiteLoc> getCallSiteLoc(Location loc) {
395   if (dyn_cast<NameLoc>(loc))
396     return getCallSiteLoc(cast<NameLoc>(loc).getChildLoc());
397   if (auto callLoc = dyn_cast<CallSiteLoc>(loc))
398     return callLoc;
399   if (dyn_cast<FusedLoc>(loc)) {
400     for (auto subLoc : cast<FusedLoc>(loc).getLocations()) {
401       if (auto callLoc = getCallSiteLoc(subLoc)) {
402         return callLoc;
403       }
404     }
405     return std::nullopt;
406   }
407   return std::nullopt;
408 }
409 
410 /// Given a diagnostic kind, returns the LLVM DiagKind.
411 static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
412   switch (kind) {
413   case DiagnosticSeverity::Note:
414     return llvm::SourceMgr::DK_Note;
415   case DiagnosticSeverity::Warning:
416     return llvm::SourceMgr::DK_Warning;
417   case DiagnosticSeverity::Error:
418     return llvm::SourceMgr::DK_Error;
419   case DiagnosticSeverity::Remark:
420     return llvm::SourceMgr::DK_Remark;
421   }
422   llvm_unreachable("Unknown DiagnosticSeverity");
423 }
424 
425 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
426     llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
427     ShouldShowLocFn &&shouldShowLocFn)
428     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
429       shouldShowLocFn(std::move(shouldShowLocFn)),
430       impl(new SourceMgrDiagnosticHandlerImpl()) {
431   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
432 }
433 
434 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
435     llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
436     : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
437                                  std::move(shouldShowLocFn)) {}
438 
439 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
440 
441 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
442                                                 DiagnosticSeverity kind,
443                                                 bool displaySourceLine) {
444   // Extract a file location from this loc.
445   auto fileLoc = loc->findInstanceOf<FileLineColLoc>();
446 
447   // If one doesn't exist, then print the raw message without a source location.
448   if (!fileLoc) {
449     std::string str;
450     llvm::raw_string_ostream strOS(str);
451     if (!llvm::isa<UnknownLoc>(loc))
452       strOS << loc << ": ";
453     strOS << message;
454     return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), str);
455   }
456 
457   // Otherwise if we are displaying the source line, try to convert the file
458   // location to an SMLoc.
459   if (displaySourceLine) {
460     auto smloc = convertLocToSMLoc(fileLoc);
461     if (smloc.isValid())
462       return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
463   }
464 
465   // If the conversion was unsuccessful, create a diagnostic with the file
466   // information. We manually combine the line and column to avoid asserts in
467   // the constructor of SMDiagnostic that takes a location.
468   std::string locStr;
469   llvm::raw_string_ostream locOS(locStr);
470   locOS << fileLoc.getFilename().getValue() << ":" << fileLoc.getLine() << ":"
471         << fileLoc.getColumn();
472   llvm::SMDiagnostic diag(locStr, getDiagKind(kind), message.str());
473   diag.print(nullptr, os);
474 }
475 
/// Emit the given diagnostic with the held source manager.
///
/// The main message is printed at the first showable location (as selected
/// by findLocToShow). If the diagnostic's location contains a call site, up
/// to `callStackLimit` callers are additionally printed as "called from"
/// notes. Attached notes are printed last.
void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
  // Pairs of (location to print, context label such as "called from").
  SmallVector<std::pair<Location, StringRef>> locationStack;
  // Push `loc` only if findLocToShow yields a location worth showing.
  auto addLocToStack = [&](Location loc, StringRef locContext) {
    if (std::optional<Location> showableLoc = findLocToShow(loc))
      locationStack.emplace_back(*showableLoc, locContext);
  };

  // Add locations to display for this diagnostic.
  Location loc = diag.getLocation();
  addLocToStack(loc, /*locContext=*/{});

  // If the diagnostic location was a call site location, add the call stack as
  // well.
  if (auto callLoc = getCallSiteLoc(loc)) {
    // Print the call stack while valid, or until the limit is reached.
    loc = callLoc->getCaller();
    for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
      addLocToStack(loc, "called from");
      if ((callLoc = getCallSiteLoc(loc)))
        loc = callLoc->getCaller();
      else
        break;
    }
  }

  // If the location stack is empty, use the initial location.
  if (locationStack.empty()) {
    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());

    // Otherwise, use the location stack.
  } else {
    // The first entry carries the message; the rest are "called from" notes.
    emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
    for (auto &it : llvm::drop_begin(locationStack))
      emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
  }

  // Emit each of the notes. Only display the source code if the location is
  // different from the previous location.
  // NOTE(review): on entry to this loop `loc` is whatever the call-stack walk
  // left it as (the last caller visited, which is one past the last printed
  // entry when the depth limit is hit), and the comparison uses raw locations
  // rather than the ones chosen by findLocToShow. Confirm this matches the
  // intended notion of "previous location".
  for (auto &note : diag.getNotes()) {
    emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
                   /*displaySourceLine=*/loc != note.getLocation());
    loc = note.getLocation();
  }
}
521 
522 void SourceMgrDiagnosticHandler::setCallStackLimit(unsigned limit) {
523   callStackLimit = limit;
524 }
525 
526 /// Get a memory buffer for the given file, or nullptr if one is not found.
527 const llvm::MemoryBuffer *
528 SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
529   if (unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename))
530     return mgr.getMemoryBuffer(id);
531   return nullptr;
532 }
533 
534 std::optional<Location>
535 SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
536   if (!shouldShowLocFn)
537     return loc;
538   if (!shouldShowLocFn(loc))
539     return std::nullopt;
540 
541   // Recurse into the child locations of some of location types.
542   return TypeSwitch<LocationAttr, std::optional<Location>>(loc)
543       .Case([&](CallSiteLoc callLoc) -> std::optional<Location> {
544         // We recurse into the callee of a call site, as the caller will be
545         // emitted in a different note on the main diagnostic.
546         return findLocToShow(callLoc.getCallee());
547       })
548       .Case([&](FileLineColLoc) -> std::optional<Location> { return loc; })
549       .Case([&](FusedLoc fusedLoc) -> std::optional<Location> {
550         // Fused location is unique in that we try to find a sub-location to
551         // show, rather than the top-level location itself.
552         for (Location childLoc : fusedLoc.getLocations())
553           if (std::optional<Location> showableLoc = findLocToShow(childLoc))
554             return showableLoc;
555         return std::nullopt;
556       })
557       .Case([&](NameLoc nameLoc) -> std::optional<Location> {
558         return findLocToShow(nameLoc.getChildLoc());
559       })
560       .Case([&](OpaqueLoc opaqueLoc) -> std::optional<Location> {
561         // OpaqueLoc always falls back to a different source location.
562         return findLocToShow(opaqueLoc.getFallbackLocation());
563       })
564       .Case([](UnknownLoc) -> std::optional<Location> {
565         // Prefer not to show unknown locations.
566         return std::nullopt;
567       });
568 }
569 
570 /// Get a memory buffer for the given file, or the main file of the source
571 /// manager if one doesn't exist. This always returns non-null.
572 SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
573   // The column and line may be zero to represent unknown column and/or unknown
574   /// line/column information.
575   if (loc.getLine() == 0 || loc.getColumn() == 0)
576     return SMLoc();
577 
578   unsigned bufferId = impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
579   if (!bufferId)
580     return SMLoc();
581   return mgr.FindLocForLineAndColumn(bufferId, loc.getLine(), loc.getColumn());
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // SourceMgrDiagnosticVerifierHandler
586 //===----------------------------------------------------------------------===//
587 
588 namespace mlir {
589 namespace detail {
590 /// This class represents an expected output diagnostic.
591 struct ExpectedDiag {
592   ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
593                StringRef substring)
594       : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}
595 
596   /// Emit an error at the location referenced by this diagnostic.
597   LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
598                           const Twine &msg) {
599     SMRange range(fileLoc, SMLoc::getFromPointer(fileLoc.getPointer() +
600                                                  substring.size()));
601     mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, range);
602     return failure();
603   }
604 
605   /// Returns true if this diagnostic matches the given string.
606   bool match(StringRef str) const {
607     // If this isn't a regex diagnostic, we simply check if the string was
608     // contained.
609     if (substringRegex)
610       return substringRegex->match(str);
611     return str.contains(substring);
612   }
613 
614   /// Compute the regex matcher for this diagnostic, using the provided stream
615   /// and manager to emit diagnostics as necessary.
616   LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
617     std::string regexStr;
618     llvm::raw_string_ostream regexOS(regexStr);
619     StringRef strToProcess = substring;
620     while (!strToProcess.empty()) {
621       // Find the next regex block.
622       size_t regexIt = strToProcess.find("{{");
623       if (regexIt == StringRef::npos) {
624         regexOS << llvm::Regex::escape(strToProcess);
625         break;
626       }
627       regexOS << llvm::Regex::escape(strToProcess.take_front(regexIt));
628       strToProcess = strToProcess.drop_front(regexIt + 2);
629 
630       // Find the end of the regex block.
631       size_t regexEndIt = strToProcess.find("}}");
632       if (regexEndIt == StringRef::npos)
633         return emitError(os, mgr, "found start of regex with no end '}}'");
634       StringRef regexStr = strToProcess.take_front(regexEndIt);
635 
636       // Validate that the regex is actually valid.
637       std::string regexError;
638       if (!llvm::Regex(regexStr).isValid(regexError))
639         return emitError(os, mgr, "invalid regex: " + regexError);
640 
641       regexOS << '(' << regexStr << ')';
642       strToProcess = strToProcess.drop_front(regexEndIt + 2);
643     }
644     substringRegex = llvm::Regex(regexStr);
645     return success();
646   }
647 
648   /// The severity of the diagnosic expected.
649   DiagnosticSeverity kind;
650   /// The line number the expected diagnostic should be on.
651   unsigned lineNo;
652   /// The location of the expected diagnostic within the input file.
653   SMLoc fileLoc;
654   /// A flag indicating if the expected diagnostic has been matched yet.
655   bool matched = false;
656   /// The substring that is expected to be within the diagnostic.
657   StringRef substring;
658   /// An optional regex matcher, if the expected diagnostic sub-string was a
659   /// regex string.
660   std::optional<llvm::Regex> substringRegex;
661 };
662 
663 struct SourceMgrDiagnosticVerifierHandlerImpl {
664   SourceMgrDiagnosticVerifierHandlerImpl() : status(success()) {}
665 
666   /// Returns the expected diagnostics for the given source file.
667   std::optional<MutableArrayRef<ExpectedDiag>>
668   getExpectedDiags(StringRef bufName);
669 
670   /// Computes the expected diagnostics for the given source buffer.
671   MutableArrayRef<ExpectedDiag>
672   computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
673                        const llvm::MemoryBuffer *buf);
674 
675   /// The current status of the verifier.
676   LogicalResult status;
677 
678   /// A list of expected diagnostics for each buffer of the source manager.
679   llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;
680 
681   /// Regex to match the expected diagnostics format.
682   llvm::Regex expected =
683       llvm::Regex("expected-(error|note|remark|warning)(-re)? "
684                   "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
685 };
686 } // namespace detail
687 } // namespace mlir
688 
689 /// Given a diagnostic kind, return a human readable string for it.
690 static StringRef getDiagKindStr(DiagnosticSeverity kind) {
691   switch (kind) {
692   case DiagnosticSeverity::Note:
693     return "note";
694   case DiagnosticSeverity::Warning:
695     return "warning";
696   case DiagnosticSeverity::Error:
697     return "error";
698   case DiagnosticSeverity::Remark:
699     return "remark";
700   }
701   llvm_unreachable("Unknown DiagnosticSeverity");
702 }
703 
704 std::optional<MutableArrayRef<ExpectedDiag>>
705 SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
706   auto expectedDiags = expectedDiagsPerFile.find(bufName);
707   if (expectedDiags != expectedDiagsPerFile.end())
708     return MutableArrayRef<ExpectedDiag>(expectedDiags->second);
709   return std::nullopt;
710 }
711 
712 MutableArrayRef<ExpectedDiag>
713 SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
714     raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
715   // If the buffer is invalid, return an empty list.
716   if (!buf)
717     return std::nullopt;
718   auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];
719 
720   // The number of the last line that did not correlate to a designator.
721   unsigned lastNonDesignatorLine = 0;
722 
723   // The indices of designators that apply to the next non designator line.
724   SmallVector<unsigned, 1> designatorsForNextLine;
725 
726   // Scan the file for expected-* designators.
727   SmallVector<StringRef, 100> lines;
728   buf->getBuffer().split(lines, '\n');
729   for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
730     SmallVector<StringRef, 4> matches;
731     if (!expected.match(lines[lineNo].rtrim(), &matches)) {
732       // Check for designators that apply to this line.
733       if (!designatorsForNextLine.empty()) {
734         for (unsigned diagIndex : designatorsForNextLine)
735           expectedDiags[diagIndex].lineNo = lineNo + 1;
736         designatorsForNextLine.clear();
737       }
738       lastNonDesignatorLine = lineNo;
739       continue;
740     }
741 
742     // Point to the start of expected-*.
743     SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());
744 
745     DiagnosticSeverity kind;
746     if (matches[1] == "error")
747       kind = DiagnosticSeverity::Error;
748     else if (matches[1] == "warning")
749       kind = DiagnosticSeverity::Warning;
750     else if (matches[1] == "remark")
751       kind = DiagnosticSeverity::Remark;
752     else {
753       assert(matches[1] == "note");
754       kind = DiagnosticSeverity::Note;
755     }
756     ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);
757 
758     // Check to see if this is a regex match, i.e. it includes the `-re`.
759     if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
760       status = failure();
761       continue;
762     }
763 
764     StringRef offsetMatch = matches[3];
765     if (!offsetMatch.empty()) {
766       offsetMatch = offsetMatch.drop_front(1);
767 
768       // Get the integer value without the @ and +/- prefix.
769       if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
770         int offset;
771         offsetMatch.drop_front().getAsInteger(0, offset);
772 
773         if (offsetMatch.front() == '+')
774           record.lineNo += offset;
775         else
776           record.lineNo -= offset;
777       } else if (offsetMatch.consume_front("above")) {
778         // If the designator applies 'above' we add it to the last non
779         // designator line.
780         record.lineNo = lastNonDesignatorLine + 1;
781       } else {
782         // Otherwise, this is a 'below' designator and applies to the next
783         // non-designator line.
784         assert(offsetMatch.consume_front("below"));
785         designatorsForNextLine.push_back(expectedDiags.size());
786 
787         // Set the line number to the last in the case that this designator ends
788         // up dangling.
789         record.lineNo = e;
790       }
791     }
792     expectedDiags.emplace_back(std::move(record));
793   }
794   return expectedDiags;
795 }
796 
797 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
798     llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
799     : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
800       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
801   // Compute the expected diagnostics for each of the current files in the
802   // source manager.
803   for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
804     (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
805 
806   // Register a handler to verify the diagnostics.
807   setHandler([&](Diagnostic &diag) {
808     // Process the main diagnostics.
809     process(diag);
810 
811     // Process each of the notes.
812     for (auto &note : diag.getNotes())
813       process(note);
814   });
815 }
816 
817 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
818     llvm::SourceMgr &srcMgr, MLIRContext *ctx)
819     : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
820 
821 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
822   // Ensure that all expected diagnostics were handled.
823   (void)verify();
824 }
825 
826 /// Returns the status of the verifier and verifies that all expected
827 /// diagnostics were emitted. This return success if all diagnostics were
828 /// verified correctly, failure otherwise.
829 LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
830   // Verify that all expected errors were seen.
831   for (auto &expectedDiagsPair : impl->expectedDiagsPerFile) {
832     for (auto &err : expectedDiagsPair.second) {
833       if (err.matched)
834         continue;
835       impl->status =
836           err.emitError(os, mgr,
837                         "expected " + getDiagKindStr(err.kind) + " \"" +
838                             err.substring + "\" was not produced");
839     }
840   }
841   impl->expectedDiagsPerFile.clear();
842   return impl->status;
843 }
844 
845 /// Process a single diagnostic.
846 void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
847   auto kind = diag.getSeverity();
848 
849   // Process a FileLineColLoc.
850   if (auto fileLoc = diag.getLocation()->findInstanceOf<FileLineColLoc>())
851     return process(fileLoc, diag.str(), kind);
852 
853   emitDiagnostic(diag.getLocation(),
854                  "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
855                  DiagnosticSeverity::Error);
856   impl->status = failure();
857 }
858 
859 /// Process a FileLineColLoc diagnostic.
860 void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
861                                                  StringRef msg,
862                                                  DiagnosticSeverity kind) {
863   // Get the expected diagnostics for this file.
864   auto diags = impl->getExpectedDiags(loc.getFilename());
865   if (!diags) {
866     diags = impl->computeExpectedDiags(os, mgr,
867                                        getBufferForFile(loc.getFilename()));
868   }
869 
870   // Search for a matching expected diagnostic.
871   // If we find something that is close then emit a more specific error.
872   ExpectedDiag *nearMiss = nullptr;
873 
874   // If this was an expected error, remember that we saw it and return.
875   unsigned line = loc.getLine();
876   for (auto &e : *diags) {
877     if (line == e.lineNo && e.match(msg)) {
878       if (e.kind == kind) {
879         e.matched = true;
880         return;
881       }
882 
883       // If this only differs based on the diagnostic kind, then consider it
884       // to be a near miss.
885       nearMiss = &e;
886     }
887   }
888 
889   // Otherwise, emit an error for the near miss.
890   if (nearMiss)
891     mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
892                      "'" + getDiagKindStr(kind) +
893                          "' diagnostic emitted when expecting a '" +
894                          getDiagKindStr(nearMiss->kind) + "'");
895   else
896     emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
897                    DiagnosticSeverity::Error);
898   impl->status = failure();
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // ParallelDiagnosticHandler
903 //===----------------------------------------------------------------------===//
904 
905 namespace mlir {
906 namespace detail {
907 struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
908   struct ThreadDiagnostic {
909     ThreadDiagnostic(size_t id, Diagnostic diag)
910         : id(id), diag(std::move(diag)) {}
911     bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
912 
913     /// The id for this diagnostic, this is used for ordering.
914     /// Note: This id corresponds to the ordered position of the current element
915     ///       being processed by a given thread.
916     size_t id;
917 
918     /// The diagnostic.
919     Diagnostic diag;
920   };
921 
922   ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
923     handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
924       uint64_t tid = llvm::get_threadid();
925       llvm::sys::SmartScopedLock<true> lock(mutex);
926 
927       // If this thread is not tracked, then return failure to let another
928       // handler process this diagnostic.
929       if (!threadToOrderID.count(tid))
930         return failure();
931 
932       // Append a new diagnostic.
933       diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
934       return success();
935     });
936   }
937 
938   ~ParallelDiagnosticHandlerImpl() override {
939     // Erase this handler from the context.
940     context->getDiagEngine().eraseHandler(handlerID);
941 
942     // Early exit if there are no diagnostics, this is the common case.
943     if (diagnostics.empty())
944       return;
945 
946     // Emit the diagnostics back to the context.
947     emitDiagnostics([&](Diagnostic &diag) {
948       return context->getDiagEngine().emit(std::move(diag));
949     });
950   }
951 
952   /// Utility method to emit any held diagnostics.
953   void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
954     // Stable sort all of the diagnostics that were emitted. This creates a
955     // deterministic ordering for the diagnostics based upon which order id they
956     // were emitted for.
957     std::stable_sort(diagnostics.begin(), diagnostics.end());
958 
959     // Emit each diagnostic to the context again.
960     for (ThreadDiagnostic &diag : diagnostics)
961       emitFn(diag.diag);
962   }
963 
964   /// Set the order id for the current thread.
965   void setOrderIDForThread(size_t orderID) {
966     uint64_t tid = llvm::get_threadid();
967     llvm::sys::SmartScopedLock<true> lock(mutex);
968     threadToOrderID[tid] = orderID;
969   }
970 
971   /// Remove the order id for the current thread.
972   void eraseOrderIDForThread() {
973     uint64_t tid = llvm::get_threadid();
974     llvm::sys::SmartScopedLock<true> lock(mutex);
975     threadToOrderID.erase(tid);
976   }
977 
978   /// Dump the current diagnostics that were inflight.
979   void print(raw_ostream &os) const override {
980     // Early exit if there are no diagnostics, this is the common case.
981     if (diagnostics.empty())
982       return;
983 
984     os << "In-Flight Diagnostics:\n";
985     emitDiagnostics([&](const Diagnostic &diag) {
986       os.indent(4);
987 
988       // Print each diagnostic with the format:
989       //   "<location>: <kind>: <msg>"
990       if (!llvm::isa<UnknownLoc>(diag.getLocation()))
991         os << diag.getLocation() << ": ";
992       switch (diag.getSeverity()) {
993       case DiagnosticSeverity::Error:
994         os << "error: ";
995         break;
996       case DiagnosticSeverity::Warning:
997         os << "warning: ";
998         break;
999       case DiagnosticSeverity::Note:
1000         os << "note: ";
1001         break;
1002       case DiagnosticSeverity::Remark:
1003         os << "remark: ";
1004         break;
1005       }
1006       os << diag << '\n';
1007     });
1008   }
1009 
1010   /// A smart mutex to lock access to the internal state.
1011   llvm::sys::SmartMutex<true> mutex;
1012 
1013   /// A mapping between the thread id and the current order id.
1014   DenseMap<uint64_t, size_t> threadToOrderID;
1015 
1016   /// An unordered list of diagnostics that were emitted.
1017   mutable std::vector<ThreadDiagnostic> diagnostics;
1018 
1019   /// The unique id for the parallel handler.
1020   DiagnosticEngine::HandlerID handlerID = 0;
1021 
1022   /// The context to emit the diagnostics to.
1023   MLIRContext *context;
1024 };
1025 } // namespace detail
1026 } // namespace mlir
1027 
1028 ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
1029     : impl(new ParallelDiagnosticHandlerImpl(ctx)) {}
1030 ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;
1031 
1032 /// Set the order id for the current thread.
1033 void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
1034   impl->setOrderIDForThread(orderID);
1035 }
1036 
1037 /// Remove the order id for the current thread. This removes the thread from
1038 /// diagnostics tracking.
1039 void ParallelDiagnosticHandler::eraseOrderIDForThread() {
1040   impl->eraseOrderIDForThread();
1041 }
1042