1 //===- PassCrashRecovery.cpp - Pass Crash Recovery Implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/Dialect.h" 12 #include "mlir/IR/SymbolTable.h" 13 #include "mlir/IR/Verifier.h" 14 #include "mlir/Parser/Parser.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Support/FileUtilities.h" 17 #include "llvm/ADT/STLExtras.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/Support/CommandLine.h" 21 #include "llvm/Support/CrashRecoveryContext.h" 22 #include "llvm/Support/ManagedStatic.h" 23 #include "llvm/Support/Mutex.h" 24 #include "llvm/Support/Signals.h" 25 #include "llvm/Support/Threading.h" 26 #include "llvm/Support/ToolOutputFile.h" 27 28 using namespace mlir; 29 using namespace mlir::detail; 30 31 //===----------------------------------------------------------------------===// 32 // RecoveryReproducerContext 33 //===----------------------------------------------------------------------===// 34 35 namespace mlir { 36 namespace detail { 37 /// This class contains all of the context for generating a recovery reproducer. 38 /// Each recovery context is registered globally to allow for generating 39 /// reproducers when a signal is raised, such as a segfault. 40 struct RecoveryReproducerContext { 41 RecoveryReproducerContext(std::string passPipelineStr, Operation *op, 42 ReproducerStreamFactory &streamFactory, 43 bool verifyPasses); 44 ~RecoveryReproducerContext(); 45 46 /// Generate a reproducer with the current context. 47 void generate(std::string &description); 48 49 /// Disable this reproducer context. This prevents the context from generating 50 /// a reproducer in the result of a crash. 51 void disable(); 52 53 /// Enable a previously disabled reproducer context. 54 void enable(); 55 56 private: 57 /// This function is invoked in the event of a crash. 58 static void crashHandler(void *); 59 60 /// Register a signal handler to run in the event of a crash. 61 static void registerSignalHandler(); 62 63 /// The textual description of the currently executing pipeline. 64 std::string pipelineElements; 65 66 /// The MLIR operation representing the IR before the crash. 67 Operation *preCrashOperation; 68 69 /// The factory for the reproducer output stream to use when generating the 70 /// reproducer. 71 ReproducerStreamFactory &streamFactory; 72 73 /// Various pass manager and context flags. 74 bool disableThreads; 75 bool verifyPasses; 76 77 /// The current set of active reproducer contexts. This is used in the event 78 /// of a crash. This is not thread_local as the pass manager may produce any 79 /// number of child threads. This uses a set to allow for multiple MLIR pass 80 /// managers to be running at the same time. 81 static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex; 82 static llvm::ManagedStatic< 83 llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 84 reproducerSet; 85 }; 86 } // namespace detail 87 } // namespace mlir 88 89 llvm::ManagedStatic<llvm::sys::SmartMutex<true>> 90 RecoveryReproducerContext::reproducerMutex; 91 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>> 92 RecoveryReproducerContext::reproducerSet; 93 94 RecoveryReproducerContext::RecoveryReproducerContext( 95 std::string passPipelineStr, Operation *op, 96 ReproducerStreamFactory &streamFactory, bool verifyPasses) 97 : pipelineElements(std::move(passPipelineStr)), 98 preCrashOperation(op->clone()), streamFactory(streamFactory), 99 disableThreads(!op->getContext()->isMultithreadingEnabled()), 100 verifyPasses(verifyPasses) { 101 enable(); 102 } 103 104 RecoveryReproducerContext::~RecoveryReproducerContext() { 105 // Erase the cloned preCrash IR that we cached. 106 preCrashOperation->erase(); 107 disable(); 108 } 109 110 static void appendReproducer(std::string &description, Operation *op, 111 const ReproducerStreamFactory &factory, 112 const std::string &pipelineElements, 113 bool disableThreads, bool verifyPasses) { 114 llvm::raw_string_ostream descOS(description); 115 116 // Try to create a new output stream for this crash reproducer. 117 std::string error; 118 std::unique_ptr<ReproducerStream> stream = factory(error); 119 if (!stream) { 120 descOS << "failed to create output stream: " << error; 121 return; 122 } 123 descOS << "reproducer generated at `" << stream->description() << "`"; 124 125 std::string pipeline = 126 (op->getName().getStringRef() + "(" + pipelineElements + ")").str(); 127 AsmState state(op); 128 state.attachResourcePrinter( 129 "mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) { 130 builder.buildString("pipeline", pipeline); 131 builder.buildBool("disable_threading", disableThreads); 132 builder.buildBool("verify_each", verifyPasses); 133 }); 134 135 // Output the .mlir module. 136 op->print(stream->os(), state); 137 } 138 139 void RecoveryReproducerContext::generate(std::string &description) { 140 appendReproducer(description, preCrashOperation, streamFactory, 141 pipelineElements, disableThreads, verifyPasses); 142 } 143 144 void RecoveryReproducerContext::disable() { 145 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 146 reproducerSet->remove(this); 147 if (reproducerSet->empty()) 148 llvm::CrashRecoveryContext::Disable(); 149 } 150 151 void RecoveryReproducerContext::enable() { 152 llvm::sys::SmartScopedLock<true> lock(*reproducerMutex); 153 if (reproducerSet->empty()) 154 llvm::CrashRecoveryContext::Enable(); 155 registerSignalHandler(); 156 reproducerSet->insert(this); 157 } 158 159 void RecoveryReproducerContext::crashHandler(void *) { 160 // Walk the current stack of contexts and generate a reproducer for each one. 161 // We can't know for certain which one was the cause, so we need to generate 162 // a reproducer for all of them. 163 for (RecoveryReproducerContext *context : *reproducerSet) { 164 std::string description; 165 context->generate(description); 166 167 // Emit an error using information only available within the context. 168 emitError(context->preCrashOperation->getLoc()) 169 << "A signal was caught while processing the MLIR module:" 170 << description << "; marking pass as failed"; 171 } 172 } 173 174 void RecoveryReproducerContext::registerSignalHandler() { 175 // Ensure that the handler is only registered once. 176 static bool registered = 177 (llvm::sys::AddSignalHandler(crashHandler, nullptr), false); 178 (void)registered; 179 } 180 181 //===----------------------------------------------------------------------===// 182 // PassCrashReproducerGenerator 183 //===----------------------------------------------------------------------===// 184 185 struct PassCrashReproducerGenerator::Impl { 186 Impl(ReproducerStreamFactory &streamFactory, bool localReproducer) 187 : streamFactory(streamFactory), localReproducer(localReproducer) {} 188 189 /// The factory to use when generating a crash reproducer. 190 ReproducerStreamFactory streamFactory; 191 192 /// Flag indicating if reproducer generation should be localized to the 193 /// failing pass. 194 bool localReproducer = false; 195 196 /// A record of all of the currently active reproducer contexts. 197 SmallVector<std::unique_ptr<RecoveryReproducerContext>> activeContexts; 198 199 /// The set of all currently running passes. Note: This is not populated when 200 /// `localReproducer` is true, as each pass will get its own recovery context. 201 SetVector<std::pair<Pass *, Operation *>> runningPasses; 202 203 /// Various pass manager flags that get emitted when generating a reproducer. 204 bool pmFlagVerifyPasses = false; 205 }; 206 207 PassCrashReproducerGenerator::PassCrashReproducerGenerator( 208 ReproducerStreamFactory &streamFactory, bool localReproducer) 209 : impl(std::make_unique<Impl>(streamFactory, localReproducer)) {} 210 PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default; 211 212 void PassCrashReproducerGenerator::initialize( 213 iterator_range<PassManager::pass_iterator> passes, Operation *op, 214 bool pmFlagVerifyPasses) { 215 assert((!impl->localReproducer || 216 !op->getContext()->isMultithreadingEnabled()) && 217 "expected multi-threading to be disabled when generating a local " 218 "reproducer"); 219 220 llvm::CrashRecoveryContext::Enable(); 221 impl->pmFlagVerifyPasses = pmFlagVerifyPasses; 222 223 // If we aren't generating a local reproducer, prepare a reproducer for the 224 // given top-level operation. 225 if (!impl->localReproducer) 226 prepareReproducerFor(passes, op); 227 } 228 229 static void 230 formatPassOpReproducerMessage(Diagnostic &os, 231 std::pair<Pass *, Operation *> passOpPair) { 232 os << "`" << passOpPair.first->getName() << "` on " 233 << "'" << passOpPair.second->getName() << "' operation"; 234 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(passOpPair.second)) 235 os << ": @" << symbol.getName(); 236 } 237 238 void PassCrashReproducerGenerator::finalize(Operation *rootOp, 239 LogicalResult executionResult) { 240 // Don't generate a reproducer if we have no active contexts. 241 if (impl->activeContexts.empty()) 242 return; 243 244 // If the pass manager execution succeeded, we don't generate any reproducers. 245 if (succeeded(executionResult)) 246 return impl->activeContexts.clear(); 247 248 InFlightDiagnostic diag = emitError(rootOp->getLoc()) 249 << "Failures have been detected while " 250 "processing an MLIR pass pipeline"; 251 252 // If we are generating a global reproducer, we include all of the running 253 // passes in the error message for the only active context. 254 if (!impl->localReproducer) { 255 assert(impl->activeContexts.size() == 1 && "expected one active context"); 256 257 // Generate the reproducer. 258 std::string description; 259 impl->activeContexts.front()->generate(description); 260 261 // Emit an error to the user. 262 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing ["; 263 llvm::interleaveComma(impl->runningPasses, note, 264 [&](const std::pair<Pass *, Operation *> &value) { 265 formatPassOpReproducerMessage(note, value); 266 }); 267 note << "]: " << description; 268 impl->runningPasses.clear(); 269 impl->activeContexts.clear(); 270 return; 271 } 272 273 // If we were generating a local reproducer, we generate a reproducer for the 274 // most recently executing pass using the matching entry from `runningPasses` 275 // to generate a localized diagnostic message. 276 assert(impl->activeContexts.size() == impl->runningPasses.size() && 277 "expected running passes to match active contexts"); 278 279 // Generate the reproducer. 280 RecoveryReproducerContext &reproducerContext = *impl->activeContexts.back(); 281 std::string description; 282 reproducerContext.generate(description); 283 284 // Emit an error to the user. 285 Diagnostic ¬e = diag.attachNote() << "Pipeline failed while executing "; 286 formatPassOpReproducerMessage(note, impl->runningPasses.back()); 287 note << ": " << description; 288 289 impl->activeContexts.clear(); 290 impl->runningPasses.clear(); 291 } 292 293 void PassCrashReproducerGenerator::prepareReproducerFor(Pass *pass, 294 Operation *op) { 295 // If not tracking local reproducers, we simply remember that this pass is 296 // running. 297 impl->runningPasses.insert(std::make_pair(pass, op)); 298 if (!impl->localReproducer) 299 return; 300 301 // Disable the current pass recovery context, if there is one. This may happen 302 // in the case of dynamic pass pipelines. 303 if (!impl->activeContexts.empty()) 304 impl->activeContexts.back()->disable(); 305 306 // Collect all of the parent scopes of this operation. 307 SmallVector<OperationName> scopes; 308 while (Operation *parentOp = op->getParentOp()) { 309 scopes.push_back(op->getName()); 310 op = parentOp; 311 } 312 313 // Emit a pass pipeline string for the current pass running on the current 314 // operation type. 315 std::string passStr; 316 llvm::raw_string_ostream passOS(passStr); 317 for (OperationName scope : llvm::reverse(scopes)) 318 passOS << scope << "("; 319 pass->printAsTextualPipeline(passOS); 320 for (unsigned i = 0, e = scopes.size(); i < e; ++i) 321 passOS << ")"; 322 323 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 324 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses)); 325 } 326 void PassCrashReproducerGenerator::prepareReproducerFor( 327 iterator_range<PassManager::pass_iterator> passes, Operation *op) { 328 std::string passStr; 329 llvm::raw_string_ostream passOS(passStr); 330 llvm::interleaveComma( 331 passes, passOS, [&](Pass &pass) { pass.printAsTextualPipeline(passOS); }); 332 333 impl->activeContexts.push_back(std::make_unique<RecoveryReproducerContext>( 334 passStr, op, impl->streamFactory, impl->pmFlagVerifyPasses)); 335 } 336 337 void PassCrashReproducerGenerator::removeLastReproducerFor(Pass *pass, 338 Operation *op) { 339 // We only pop the active context if we are tracking local reproducers. 340 impl->runningPasses.remove(std::make_pair(pass, op)); 341 if (impl->localReproducer) { 342 impl->activeContexts.pop_back(); 343 344 // Re-enable the previous pass recovery context, if there was one. This may 345 // happen in the case of dynamic pass pipelines. 346 if (!impl->activeContexts.empty()) 347 impl->activeContexts.back()->enable(); 348 } 349 } 350 351 //===----------------------------------------------------------------------===// 352 // CrashReproducerInstrumentation 353 //===----------------------------------------------------------------------===// 354 355 namespace { 356 struct CrashReproducerInstrumentation : public PassInstrumentation { 357 CrashReproducerInstrumentation(PassCrashReproducerGenerator &generator) 358 : generator(generator) {} 359 ~CrashReproducerInstrumentation() override = default; 360 361 void runBeforePass(Pass *pass, Operation *op) override { 362 if (!isa<OpToOpPassAdaptor>(pass)) 363 generator.prepareReproducerFor(pass, op); 364 } 365 366 void runAfterPass(Pass *pass, Operation *op) override { 367 if (!isa<OpToOpPassAdaptor>(pass)) 368 generator.removeLastReproducerFor(pass, op); 369 } 370 371 void runAfterPassFailed(Pass *pass, Operation *op) override { 372 // Only generate one reproducer per crash reproducer instrumentation. 373 if (alreadyFailed) 374 return; 375 376 alreadyFailed = true; 377 generator.finalize(op, /*executionResult=*/failure()); 378 } 379 380 private: 381 /// The generator used to create crash reproducers. 382 PassCrashReproducerGenerator &generator; 383 bool alreadyFailed = false; 384 }; 385 } // namespace 386 387 //===----------------------------------------------------------------------===// 388 // FileReproducerStream 389 //===----------------------------------------------------------------------===// 390 391 namespace { 392 /// This class represents a default instance of mlir::ReproducerStream 393 /// that is backed by a file. 394 struct FileReproducerStream : public mlir::ReproducerStream { 395 FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile) 396 : outputFile(std::move(outputFile)) {} 397 ~FileReproducerStream() override { outputFile->keep(); } 398 399 /// Returns a description of the reproducer stream. 400 StringRef description() override { return outputFile->getFilename(); } 401 402 /// Returns the stream on which to output the reproducer. 403 raw_ostream &os() override { return outputFile->os(); } 404 405 private: 406 /// ToolOutputFile corresponding to opened `filename`. 407 std::unique_ptr<llvm::ToolOutputFile> outputFile = nullptr; 408 }; 409 } // namespace 410 411 //===----------------------------------------------------------------------===// 412 // PassManager 413 //===----------------------------------------------------------------------===// 414 415 LogicalResult PassManager::runWithCrashRecovery(Operation *op, 416 AnalysisManager am) { 417 crashReproGenerator->initialize(getPasses(), op, verifyPasses); 418 419 // Safely invoke the passes within a recovery context. 420 LogicalResult passManagerResult = failure(); 421 llvm::CrashRecoveryContext recoveryContext; 422 recoveryContext.RunSafelyOnThread( 423 [&] { passManagerResult = runPasses(op, am); }); 424 crashReproGenerator->finalize(op, passManagerResult); 425 return passManagerResult; 426 } 427 428 static ReproducerStreamFactory 429 makeReproducerStreamFactory(StringRef outputFile) { 430 // Capture the filename by value in case outputFile is out of scope when 431 // invoked. 432 std::string filename = outputFile.str(); 433 return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> { 434 std::unique_ptr<llvm::ToolOutputFile> outputFile = 435 mlir::openOutputFile(filename, &error); 436 if (!outputFile) { 437 error = "Failed to create reproducer stream: " + error; 438 return nullptr; 439 } 440 return std::make_unique<FileReproducerStream>(std::move(outputFile)); 441 }; 442 } 443 444 void printAsTextualPipeline( 445 raw_ostream &os, StringRef anchorName, 446 const llvm::iterator_range<OpPassManager::pass_iterator> &passes); 447 448 std::string mlir::makeReproducer( 449 StringRef anchorName, 450 const llvm::iterator_range<OpPassManager::pass_iterator> &passes, 451 Operation *op, StringRef outputFile, bool disableThreads, 452 bool verifyPasses) { 453 454 std::string description; 455 std::string pipelineStr; 456 llvm::raw_string_ostream passOS(pipelineStr); 457 ::printAsTextualPipeline(passOS, anchorName, passes); 458 appendReproducer(description, op, makeReproducerStreamFactory(outputFile), 459 pipelineStr, disableThreads, verifyPasses); 460 return description; 461 } 462 463 void PassManager::enableCrashReproducerGeneration(StringRef outputFile, 464 bool genLocalReproducer) { 465 enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile), 466 genLocalReproducer); 467 } 468 469 void PassManager::enableCrashReproducerGeneration( 470 ReproducerStreamFactory factory, bool genLocalReproducer) { 471 assert(!crashReproGenerator && 472 "crash reproducer has already been initialized"); 473 if (genLocalReproducer && getContext()->isMultithreadingEnabled()) 474 llvm::report_fatal_error( 475 "Local crash reproduction can't be setup on a " 476 "pass-manager without disabling multi-threading first."); 477 478 crashReproGenerator = std::make_unique<PassCrashReproducerGenerator>( 479 factory, genLocalReproducer); 480 addInstrumentation( 481 std::make_unique<CrashReproducerInstrumentation>(*crashReproGenerator)); 482 } 483 484 //===----------------------------------------------------------------------===// 485 // Asm Resource 486 //===----------------------------------------------------------------------===// 487 488 void PassReproducerOptions::attachResourceParser(ParserConfig &config) { 489 auto parseFn = [this](AsmParsedResourceEntry &entry) -> LogicalResult { 490 if (entry.getKey() == "pipeline") { 491 FailureOr<std::string> value = entry.parseAsString(); 492 if (succeeded(value)) 493 this->pipeline = std::move(*value); 494 return value; 495 } 496 if (entry.getKey() == "disable_threading") { 497 FailureOr<bool> value = entry.parseAsBool(); 498 if (succeeded(value)) 499 this->disableThreading = *value; 500 return value; 501 } 502 if (entry.getKey() == "verify_each") { 503 FailureOr<bool> value = entry.parseAsBool(); 504 if (succeeded(value)) 505 this->verifyEach = *value; 506 return value; 507 } 508 return entry.emitError() << "unknown 'mlir_reproducer' resource key '" 509 << entry.getKey() << "'"; 510 }; 511 config.attachResourceParser("mlir_reproducer", parseFn); 512 } 513 514 LogicalResult PassReproducerOptions::apply(PassManager &pm) const { 515 if (pipeline.has_value()) { 516 FailureOr<OpPassManager> reproPm = parsePassPipeline(*pipeline); 517 if (failed(reproPm)) 518 return failure(); 519 static_cast<OpPassManager &>(pm) = std::move(*reproPm); 520 } 521 522 if (disableThreading.has_value()) 523 pm.getContext()->disableMultithreading(*disableThreading); 524 525 if (verifyEach.has_value()) 526 pm.enableVerifier(*verifyEach); 527 528 return success(); 529 } 530