xref: /llvm-project/mlir/lib/Pass/PassCrashRecovery.cpp (revision 884221eddb9d395830704fac79fd04008e02e368)
1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Dialect.h"
12 #include "mlir/IR/SymbolTable.h"
13 #include "mlir/IR/Verifier.h"
14 #include "mlir/Parser/Parser.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/CrashRecoveryContext.h"
22 #include "llvm/Support/ManagedStatic.h"
23 #include "llvm/Support/Mutex.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/Support/Threading.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 
28 using namespace mlir;
29 using namespace mlir::detail;
30 
31 //===----------------------------------------------------------------------===//
32 // RecoveryReproducerContext
33 //===----------------------------------------------------------------------===//
34 
35 namespace mlir {
36 namespace detail {
37 /// This class contains all of the context for generating a recovery reproducer.
38 /// Each recovery context is registered globally to allow for generating
39 /// reproducers when a signal is raised, such as a segfault.
40 struct RecoveryReproducerContext {
41   RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
42                             ReproducerStreamFactory &streamFactory,
43                             bool verifyPasses);
44   ~RecoveryReproducerContext();
45 
46   /// Generate a reproducer with the current context.
47   void generate(std::string &description);
48 
49   /// Disable this reproducer context. This prevents the context from generating
50   /// a reproducer in the result of a crash.
51   void disable();
52 
53   /// Enable a previously disabled reproducer context.
54   void enable();
55 
56 private:
57   /// This function is invoked in the event of a crash.
58   static void crashHandler(void *);
59 
60   /// Register a signal handler to run in the event of a crash.
61   static void registerSignalHandler();
62 
63   /// The textual description of the currently executing pipeline.
64   std::string pipelineElements;
65 
66   /// The MLIR operation representing the IR before the crash.
67   Operation *preCrashOperation;
68 
69   /// The factory for the reproducer output stream to use when generating the
70   /// reproducer.
71   ReproducerStreamFactory &streamFactory;
72 
73   /// Various pass manager and context flags.
74   bool disableThreads;
75   bool verifyPasses;
76 
77   /// The current set of active reproducer contexts. This is used in the event
78   /// of a crash. This is not thread_local as the pass manager may produce any
79   /// number of child threads. This uses a set to allow for multiple MLIR pass
80   /// managers to be running at the same time.
81   static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
82   static llvm::ManagedStatic<
83       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
84       reproducerSet;
85 };
86 } // namespace detail
87 } // namespace mlir
88 
89 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
90     RecoveryReproducerContext::reproducerMutex;
91 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
92     RecoveryReproducerContext::reproducerSet;
93 
94 RecoveryReproducerContext::RecoveryReproducerContext(
95     std::string passPipelineStr, Operation *op,
96     ReproducerStreamFactory &streamFactory, bool verifyPasses)
97     : pipelineElements(std::move(passPipelineStr)),
98       preCrashOperation(op->clone()), streamFactory(streamFactory),
99       disableThreads(!op->getContext()->isMultithreadingEnabled()),
100       verifyPasses(verifyPasses) {
101   enable();
102 }
103 
104 RecoveryReproducerContext::~RecoveryReproducerContext() {
105   // Erase the cloned preCrash IR that we cached.
106   preCrashOperation->erase();
107   disable();
108 }
109 
110 static void appendReproducer(std::string &description, Operation *op,
111                              const ReproducerStreamFactory &factory,
112                              const std::string &pipelineElements,
113                              bool disableThreads, bool verifyPasses) {
114   llvm::raw_string_ostream descOS(description);
115 
116   // Try to create a new output stream for this crash reproducer.
117   std::string error;
118   std::unique_ptr<ReproducerStream> stream = factory(error);
119   if (!stream) {
120     descOS << "failed to create output stream: " << error;
121     return;
122   }
123   descOS << "reproducer generated at `" << stream->description() << "`";
124 
125   std::string pipeline =
126       (op->getName().getStringRef() + "(" + pipelineElements + ")").str();
127   AsmState state(op);
128   state.attachResourcePrinter(
129       "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
130         builder.buildString("pipeline", pipeline);
131         builder.buildBool("disable_threading", disableThreads);
132         builder.buildBool("verify_each", verifyPasses);
133       });
134 
135   // Output the .mlir module.
136   op->print(stream->os(), state);
137 }
138 
139 void RecoveryReproducerContext::generate(std::string &description) {
140   appendReproducer(description, preCrashOperation, streamFactory,
141                    pipelineElements, disableThreads, verifyPasses);
142 }
143 
144 void RecoveryReproducerContext::disable() {
145   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
146   reproducerSet->remove(this);
147   if (reproducerSet->empty())
148     llvm::CrashRecoveryContext::Disable();
149 }
150 
151 void RecoveryReproducerContext::enable() {
152   llvm::sys::SmartScopedLock<true> lock(*reproducerMutex);
153   if (reproducerSet->empty())
154     llvm::CrashRecoveryContext::Enable();
155   registerSignalHandler();
156   reproducerSet->insert(this);
157 }
158 
159 void RecoveryReproducerContext::crashHandler(void *) {
160   // Walk the current stack of contexts and generate a reproducer for each one.
161   // We can't know for certain which one was the cause, so we need to generate
162   // a reproducer for all of them.
163   for (RecoveryReproducerContext *context : *reproducerSet) {
164     std::string description;
165     context->generate(description);
166 
167     // Emit an error using information only available within the context.
168     emitError(context->preCrashOperation->getLoc())
169         << "A signal was caught while processing the MLIR module:"
170         << description << "; marking pass as failed";
171   }
172 }
173 
174 void RecoveryReproducerContext::registerSignalHandler() {
175   // Ensure that the handler is only registered once.
176   static bool registered =
177       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
178   (void)registered;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // PassCrashReproducerGenerator
183 //===----------------------------------------------------------------------===//
184 
185 struct PassCrashReproducerGenerator::Impl {
186   Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
187       : streamFactory(streamFactory), localReproducer(localReproducer) {}
188 
189   /// The factory to use when generating a crash reproducer.
190   ReproducerStreamFactory streamFactory;
191 
192   /// Flag indicating if reproducer generation should be localized to the
193   /// failing pass.
194   bool localReproducer = false;
195 
196   /// A record of all of the currently active reproducer contexts.
197   SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts;
198 
199   /// The set of all currently running passes. Note: This is not populated when
200   /// `localReproducer` is true, as each pass will get its own recovery context.
201   SetVector<std::pair<Pass *, Operation *>> runningPasses;
202 
203   /// Various pass manager flags that get emitted when generating a reproducer.
204   bool pmFlagVerifyPasses = false;
205 };
206 
207 PassCrashReproducerGenerator::PassCrashReproducerGenerator(
208     ReproducerStreamFactory &streamFactory, bool localReproducer)
209     : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
210 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
211 
212 void PassCrashReproducerGenerator::initialize(
213     iterator_range<PassManager::pass_iterator> passes, Operation *op,
214     bool pmFlagVerifyPasses) {
215   assert((!impl->localReproducer ||
216           !op->getContext()->isMultithreadingEnabled()) &&
217          "expected multi-threading to be disabled when generating a local "
218          "reproducer");
219 
220   llvm::CrashRecoveryContext::Enable();
221   impl->pmFlagVerifyPasses = pmFlagVerifyPasses;
222 
223   // If we aren't generating a local reproducer, prepare a reproducer for the
224   // given top-level operation.
225   if (!impl->localReproducer)
226     prepareReproducerFor(passes, op);
227 }
228 
229 static void
230 formatPassOpReproducerMessage(Diagnostic &os,
231                               std::pair<Pass *, Operation *> passOpPair) {
232   os << "`" << passOpPair.first->getName() << "` on "
233      << "'" << passOpPair.second->getName() << "' operation";
234   if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second))
235     os << ": @" << symbol.getName();
236 }
237 
238 void PassCrashReproducerGenerator::finalize(Operation *rootOp,
239                                             LogicalResult executionResult) {
240   // Don't generate a reproducer if we have no active contexts.
241   if (impl->activeContexts.empty())
242     return;
243 
244   // If the pass manager execution succeeded, we don't generate any reproducers.
245   if (succeeded(executionResult))
246     return impl->activeContexts.clear();
247 
248   InFlightDiagnostic diag = emitError(rootOp->getLoc())
249                             << "Failures have been detected while "
250                                "processing an MLIR pass pipeline";
251 
252   // If we are generating a global reproducer, we include all of the running
253   // passes in the error message for the only active context.
254   if (!impl->localReproducer) {
255     assert(impl->activeContexts.size() == 1 && "expected one active context");
256 
257     // Generate the reproducer.
258     std::string description;
259     impl->activeContexts.front()->generate(description);
260 
261     // Emit an error to the user.
262     Diagnostic &note = diag.attachNote() << "Pipeline failed while executing [";
263     llvm::interleaveComma(impl->runningPasses, note,
264                           [&](const std::pair<Pass *, Operation *> &value) {
265                             formatPassOpReproducerMessage(note, value);
266                           });
267     note << "]: " << description;
268     impl->runningPasses.clear();
269     impl->activeContexts.clear();
270     return;
271   }
272 
273   // If we were generating a local reproducer, we generate a reproducer for the
274   // most recently executing pass using the matching entry from  `runningPasses`
275   // to generate a localized diagnostic message.
276   assert(impl->activeContexts.size() == impl->runningPasses.size() &&
277          "expected running passes to match active contexts");
278 
279   // Generate the reproducer.
280   RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back();
281   std::string description;
282   reproducerContext.generate(description);
283 
284   // Emit an error to the user.
285   Diagnostic &note = diag.attachNote() << "Pipeline failed while executing ";
286   formatPassOpReproducerMessage(note, impl->runningPasses.back());
287   note << ": " << description;
288 
289   impl->activeContexts.clear();
290   impl->runningPasses.clear();
291 }
292 
293 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass,
294                                                         Operation *op) {
295   // If not tracking local reproducers, we simply remember that this pass is
296   // running.
297   impl->runningPasses.insert(std::make_pair(pass, op));
298   if (!impl->localReproducer)
299     return;
300 
301   // Disable the current pass recovery context, if there is one. This may happen
302   // in the case of dynamic pass pipelines.
303   if (!impl->activeContexts.empty())
304     impl->activeContexts.back()->disable();
305 
306   // Collect all of the parent scopes of this operation.
307   SmallVector<OperationName> scopes;
308   while (Operation *parentOp = op->getParentOp()) {
309     scopes.push_back(op->getName());
310     op = parentOp;
311   }
312 
313   // Emit a pass pipeline string for the current pass running on the current
314   // operation type.
315   std::string passStr;
316   llvm::raw_string_ostream passOS(passStr);
317   for (OperationName scope : llvm::reverse(scopes))
318     passOS << scope << "(";
319   pass->printAsTextualPipeline(passOS);
320   for (unsigned i = 0, e = scopes.size(); i < e; ++i)
321     passOS << ")";
322 
323   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
324       passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
325 }
326 void PassCrashReproducerGenerator::prepareReproducerFor(
327     iterator_range<PassManager::pass_iterator> passes, Operation *op) {
328   std::string passStr;
329   llvm::raw_string_ostream passOS(passStr);
330   llvm::interleaveComma(
331       passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); });
332 
333   impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>(
334       passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses));
335 }
336 
337 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass,
338                                                            Operation *op) {
339   // We only pop the active context if we are tracking local reproducers.
340   impl->runningPasses.remove(std::make_pair(pass, op));
341   if (impl->localReproducer) {
342     impl->activeContexts.pop_back();
343 
344     // Re-enable the previous pass recovery context, if there was one. This may
345     // happen in the case of dynamic pass pipelines.
346     if (!impl->activeContexts.empty())
347       impl->activeContexts.back()->enable();
348   }
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // CrashReproducerInstrumentation
353 //===----------------------------------------------------------------------===//
354 
355 namespace {
356 struct CrashReproducerInstrumentation : public PassInstrumentation {
357   CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator)
358       : generator(generator) {}
359   ~CrashReproducerInstrumentation() override = default;
360 
361   void runBeforePass(Pass *pass, Operation *op) override {
362     if (!isa<OpToOpPassAdaptor>(pass))
363       generator.prepareReproducerFor(pass, op);
364   }
365 
366   void runAfterPass(Pass *pass, Operation *op) override {
367     if (!isa<OpToOpPassAdaptor>(pass))
368       generator.removeLastReproducerFor(pass, op);
369   }
370 
371   void runAfterPassFailed(Pass *pass, Operation *op) override {
372     // Only generate one reproducer per crash reproducer instrumentation.
373     if (alreadyFailed)
374       return;
375 
376     alreadyFailed = true;
377     generator.finalize(op, /*executionResult=*/failure());
378   }
379 
380 private:
381   /// The generator used to create crash reproducers.
382   PassCrashReproducerGenerator &generator;
383   bool alreadyFailed = false;
384 };
385 } // namespace
386 
387 //===----------------------------------------------------------------------===//
388 // FileReproducerStream
389 //===----------------------------------------------------------------------===//
390 
391 namespace {
392 /// This class represents a default instance of mlir::ReproducerStream
393 /// that is backed by a file.
394 struct FileReproducerStream : public mlir::ReproducerStream {
395   FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
396       : outputFile(std::move(outputFile)) {}
397   ~FileReproducerStream() override { outputFile->keep(); }
398 
399   /// Returns a description of the reproducer stream.
400   StringRef description() override { return outputFile->getFilename(); }
401 
402   /// Returns the stream on which to output the reproducer.
403   raw_ostream &os() override { return outputFile->os(); }
404 
405 private:
406   /// ToolOutputFile corresponding to opened `filename`.
407   std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr;
408 };
409 } // namespace
410 
411 //===----------------------------------------------------------------------===//
412 // PassManager
413 //===----------------------------------------------------------------------===//
414 
415 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
416                                                 AnalysisManager am) {
417   crashReproGenerator->initialize(getPasses(), op, verifyPasses);
418 
419   // Safely invoke the passes within a recovery context.
420   LogicalResult passManagerResult = failure();
421   llvm::CrashRecoveryContext recoveryContext;
422   recoveryContext.RunSafelyOnThread(
423       [&] { passManagerResult = runPasses(op, am); });
424   crashReproGenerator->finalize(op, passManagerResult);
425   return passManagerResult;
426 }
427 
428 static ReproducerStreamFactory
429 makeReproducerStreamFactory(StringRef outputFile) {
430   // Capture the filename by value in case outputFile is out of scope when
431   // invoked.
432   std::string filename = outputFile.str();
433   return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
434     std::unique_ptr<llvm::ToolOutputFile> outputFile =
435         mlir::openOutputFile(filename, &error);
436     if (!outputFile) {
437       error = "Failed to create reproducer stream: " + error;
438       return nullptr;
439     }
440     return std::make_unique<FileReproducerStream>(std::move(outputFile));
441   };
442 }
443 
444 void printAsTextualPipeline(
445     raw_ostream &os, StringRef anchorName,
446     const llvm::iterator_range<OpPassManager::pass_iterator> &passes);
447 
448 std::string mlir::makeReproducer(
449     StringRef anchorName,
450     const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
451     Operation *op, StringRef outputFile, bool disableThreads,
452     bool verifyPasses) {
453 
454   std::string description;
455   std::string pipelineStr;
456   llvm::raw_string_ostream passOS(pipelineStr);
457   ::printAsTextualPipeline(passOS, anchorName, passes);
458   appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
459                    pipelineStr, disableThreads, verifyPasses);
460   return description;
461 }
462 
463 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
464                                                   bool genLocalReproducer) {
465   enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile),
466                                   genLocalReproducer);
467 }
468 
469 void PassManager::enableCrashReproducerGeneration(
470     ReproducerStreamFactory factory, bool genLocalReproducer) {
471   assert(!crashReproGenerator &&
472          "crash reproducer has already been initialized");
473   if (genLocalReproducer && getContext()->isMultithreadingEnabled())
474     llvm::report_fatal_error(
475         "Local crash reproduction can't be setup on a "
476         "pass-manager without disabling multi-threading first.");
477 
478   crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>(
479       factory, genLocalReproducer);
480   addInstrumentation(
481       std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator));
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // Asm Resource
486 //===----------------------------------------------------------------------===//
487 
488 void PassReproducerOptions::attachResourceParser(ParserConfig &config) {
489   auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult {
490     if (entry.getKey() == "pipeline") {
491       FailureOr<std::string> value = entry.parseAsString();
492       if (succeeded(value))
493         this->pipeline = std::move(*value);
494       return value;
495     }
496     if (entry.getKey() == "disable_threading") {
497       FailureOr<bool> value = entry.parseAsBool();
498       if (succeeded(value))
499         this->disableThreading = *value;
500       return value;
501     }
502     if (entry.getKey() == "verify_each") {
503       FailureOr<bool> value = entry.parseAsBool();
504       if (succeeded(value))
505         this->verifyEach = *value;
506       return value;
507     }
508     return entry.emitError() << "unknown 'mlir_reproducer' resource key '"
509                              << entry.getKey() << "'";
510   };
511   config.attachResourceParser("mlir_reproducer", parseFn);
512 }
513 
514 LogicalResult PassReproducerOptions::apply(PassManager &pm) const {
515   if (pipeline.has_value()) {
516     FailureOr<OpPassManager> reproPm = parsePassPipeline(*pipeline);
517     if (failed(reproPm))
518       return failure();
519     static_cast<OpPassManager &>(pm) = std::move(*reproPm);
520   }
521 
522   if (disableThreading.has_value())
523     pm.getContext()->disableMultithreading(*disableThreading);
524 
525   if (verifyEach.has_value())
526     pm.enableVerifier(*verifyEach);
527 
528   return success();
529 }
530