1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements common pass infrastructure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Pass/Pass.h" 14 #include "PassDetail.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "mlir/IR/Threading.h" 19 #include "mlir/IR/Verifier.h" 20 #include "mlir/Support/FileUtilities.h" 21 #include "llvm/ADT/Hashing.h" 22 #include "llvm/ADT/STLExtras.h" 23 #include "llvm/ADT/ScopeExit.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/CrashRecoveryContext.h" 26 #include "llvm/Support/Mutex.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/Support/Threading.h" 29 #include "llvm/Support/ToolOutputFile.h" 30 #include <optional> 31 32 using namespace mlir; 33 using namespace mlir::detail; 34 35 //===----------------------------------------------------------------------===// 36 // PassExecutionAction 37 //===----------------------------------------------------------------------===// 38 39 PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits, 40 const Pass &pass) 41 : Base(irUnits), pass(pass) {} 42 43 void PassExecutionAction::print(raw_ostream &os) const { 44 os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag, 45 pass.getName(), getOp()->getName()); 46 } 47 48 Operation *PassExecutionAction::getOp() const { 49 ArrayRef<IRUnit> irUnits = getContextIRUnits(); 50 return irUnits.empty() ? nullptr 51 : llvm::dyn_cast_if_present<Operation *>(irUnits[0]); 52 } 53 54 //===----------------------------------------------------------------------===// 55 // Pass 56 //===----------------------------------------------------------------------===// 57 58 /// Out of line virtual method to ensure vtables and metadata are emitted to a 59 /// single .o file. 60 void Pass::anchor() {} 61 62 /// Attempt to initialize the options of this pass from the given string. 63 LogicalResult Pass::initializeOptions( 64 StringRef options, 65 function_ref<LogicalResult(const Twine &)> errorHandler) { 66 std::string errStr; 67 llvm::raw_string_ostream os(errStr); 68 if (failed(passOptions.parseFromString(options, os))) { 69 return errorHandler(errStr); 70 } 71 return success(); 72 } 73 74 /// Copy the option values from 'other', which is another instance of this 75 /// pass. 76 void Pass::copyOptionValuesFrom(const Pass *other) { 77 passOptions.copyOptionValuesFrom(other->passOptions); 78 } 79 80 /// Prints out the pass in the textual representation of pipelines. If this is 81 /// an adaptor pass, print its pass managers. 82 void Pass::printAsTextualPipeline(raw_ostream &os) { 83 // Special case for adaptors to print its pass managers. 84 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) { 85 llvm::interleave( 86 adaptor->getPassManagers(), 87 [&](OpPassManager &pm) { pm.printAsTextualPipeline(os); }, 88 [&] { os << ","; }); 89 return; 90 } 91 // Otherwise, print the pass argument followed by its options. If the pass 92 // doesn't have an argument, print the name of the pass to give some indicator 93 // of what pass was run. 94 StringRef argument = getArgument(); 95 if (!argument.empty()) 96 os << argument; 97 else 98 os << "unknown<" << getName() << ">"; 99 passOptions.print(os); 100 } 101 102 //===----------------------------------------------------------------------===// 103 // OpPassManagerImpl 104 //===----------------------------------------------------------------------===// 105 106 namespace mlir { 107 namespace detail { 108 struct OpPassManagerImpl { 109 OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting) 110 : name(opName.getStringRef().str()), opName(opName), 111 initializationGeneration(0), nesting(nesting) {} 112 OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) 113 : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()), 114 initializationGeneration(0), nesting(nesting) {} 115 OpPassManagerImpl(OpPassManager::Nesting nesting) 116 : initializationGeneration(0), nesting(nesting) {} 117 OpPassManagerImpl(const OpPassManagerImpl &rhs) 118 : name(rhs.name), opName(rhs.opName), 119 initializationGeneration(rhs.initializationGeneration), 120 nesting(rhs.nesting) { 121 for (const std::unique_ptr<Pass> &pass : rhs.passes) { 122 std::unique_ptr<Pass> newPass = pass->clone(); 123 newPass->threadingSibling = pass.get(); 124 passes.push_back(std::move(newPass)); 125 } 126 } 127 128 /// Merge the passes of this pass manager into the one provided. 129 void mergeInto(OpPassManagerImpl &rhs); 130 131 /// Nest a new operation pass manager for the given operation kind under this 132 /// pass manager. 133 OpPassManager &nest(OperationName nestedName) { 134 return nest(OpPassManager(nestedName, nesting)); 135 } 136 OpPassManager &nest(StringRef nestedName) { 137 return nest(OpPassManager(nestedName, nesting)); 138 } 139 OpPassManager &nestAny() { return nest(OpPassManager(nesting)); } 140 141 /// Nest the given pass manager under this pass manager. 142 OpPassManager &nest(OpPassManager &&nested); 143 144 /// Add the given pass to this pass manager. If this pass has a concrete 145 /// operation type, it must be the same type as this pass manager. 146 void addPass(std::unique_ptr<Pass> pass); 147 148 /// Clear the list of passes in this pass manager, other options are 149 /// preserved. 150 void clear(); 151 152 /// Finalize the pass list in preparation for execution. This includes 153 /// coalescing adjacent pass managers when possible, verifying scheduled 154 /// passes, etc. 155 LogicalResult finalizePassList(MLIRContext *ctx); 156 157 /// Return the operation name of this pass manager. 158 std::optional<OperationName> getOpName(MLIRContext &context) { 159 if (!name.empty() && !opName) 160 opName = OperationName(name, &context); 161 return opName; 162 } 163 std::optional<StringRef> getOpName() const { 164 return name.empty() ? std::optional<StringRef>() 165 : std::optional<StringRef>(name); 166 } 167 168 /// Return the name used to anchor this pass manager. This is either the name 169 /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an 170 /// op-agnostic pass manager. 171 StringRef getOpAnchorName() const { 172 return getOpName().value_or(OpPassManager::getAnyOpAnchorName()); 173 } 174 175 /// Indicate if the current pass manager can be scheduled on the given 176 /// operation type. 177 bool canScheduleOn(MLIRContext &context, OperationName opName); 178 179 /// The name of the operation that passes of this pass manager operate on. 180 std::string name; 181 182 /// The cached OperationName (internalized in the context) for the name of the 183 /// operation that passes of this pass manager operate on. 184 std::optional<OperationName> opName; 185 186 /// The set of passes to run as part of this pass manager. 187 std::vector<std::unique_ptr<Pass>> passes; 188 189 /// The current initialization generation of this pass manager. This is used 190 /// to indicate when a pass manager should be reinitialized. 191 unsigned initializationGeneration; 192 193 /// Control the implicit nesting of passes that mismatch the name set for this 194 /// OpPassManager. 195 OpPassManager::Nesting nesting; 196 }; 197 } // namespace detail 198 } // namespace mlir 199 200 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) { 201 assert(name == rhs.name && "merging unrelated pass managers"); 202 for (auto &pass : passes) 203 rhs.passes.push_back(std::move(pass)); 204 passes.clear(); 205 } 206 207 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) { 208 auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); 209 addPass(std::unique_ptr<Pass>(adaptor)); 210 return adaptor->getPassManagers().front(); 211 } 212 213 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) { 214 // If this pass runs on a different operation than this pass manager, then 215 // implicitly nest a pass manager for this operation if enabled. 216 std::optional<StringRef> pmOpName = getOpName(); 217 std::optional<StringRef> passOpName = pass->getOpName(); 218 if (pmOpName && passOpName && *pmOpName != *passOpName) { 219 if (nesting == OpPassManager::Nesting::Implicit) 220 return nest(*passOpName).addPass(std::move(pass)); 221 llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() + 222 "' restricted to '" + *passOpName + 223 "' on a PassManager intended to run on '" + 224 getOpAnchorName() + "', did you intend to nest?"); 225 } 226 227 passes.emplace_back(std::move(pass)); 228 } 229 230 void OpPassManagerImpl::clear() { passes.clear(); } 231 232 LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) { 233 auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) { 234 for (auto &pm : adaptor->getPassManagers()) 235 if (failed(pm.getImpl().finalizePassList(ctx))) 236 return failure(); 237 return success(); 238 }; 239 240 // Walk the pass list and merge adjacent adaptors. 241 OpToOpPassAdaptor *lastAdaptor = nullptr; 242 for (auto &pass : passes) { 243 // Check to see if this pass is an adaptor. 244 if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) { 245 // If it is the first adaptor in a possible chain, remember it and 246 // continue. 247 if (!lastAdaptor) { 248 lastAdaptor = currentAdaptor; 249 continue; 250 } 251 252 // Otherwise, try to merge into the existing adaptor and delete the 253 // current one. If merging fails, just remember this as the last adaptor. 254 if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor))) 255 pass.reset(); 256 else 257 lastAdaptor = currentAdaptor; 258 } else if (lastAdaptor) { 259 // If this pass isn't an adaptor, finalize it and forget the last adaptor. 260 if (failed(finalizeAdaptor(lastAdaptor))) 261 return failure(); 262 lastAdaptor = nullptr; 263 } 264 } 265 266 // If there was an adaptor at the end of the manager, finalize it as well. 267 if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor))) 268 return failure(); 269 270 // Now that the adaptors have been merged, erase any empty slots corresponding 271 // to the merged adaptors that were nulled-out in the loop above. 272 llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>()); 273 274 // If this is a op-agnostic pass manager, there is nothing left to do. 275 std::optional<OperationName> rawOpName = getOpName(*ctx); 276 if (!rawOpName) 277 return success(); 278 279 // Otherwise, verify that all of the passes are valid for the current 280 // operation anchor. 281 std::optional<RegisteredOperationName> opName = 282 rawOpName->getRegisteredInfo(); 283 for (std::unique_ptr<Pass> &pass : passes) { 284 if (opName && !pass->canScheduleOn(*opName)) { 285 return emitError(UnknownLoc::get(ctx)) 286 << "unable to schedule pass '" << pass->getName() 287 << "' on a PassManager intended to run on '" << getOpAnchorName() 288 << "'!"; 289 } 290 } 291 return success(); 292 } 293 294 bool OpPassManagerImpl::canScheduleOn(MLIRContext &context, 295 OperationName opName) { 296 // If this pass manager is op-specific, we simply check if the provided 297 // operation name is the same as this one. 298 std::optional<OperationName> pmOpName = getOpName(context); 299 if (pmOpName) 300 return pmOpName == opName; 301 302 // Otherwise, this is an op-agnostic pass manager. Check that the operation 303 // can be scheduled on all passes within the manager. 304 std::optional<RegisteredOperationName> registeredInfo = 305 opName.getRegisteredInfo(); 306 if (!registeredInfo || 307 !registeredInfo->hasTrait<OpTrait::IsIsolatedFromAbove>()) 308 return false; 309 return llvm::all_of(passes, [&](const std::unique_ptr<Pass> &pass) { 310 return pass->canScheduleOn(*registeredInfo); 311 }); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // OpPassManager 316 //===----------------------------------------------------------------------===// 317 318 OpPassManager::OpPassManager(Nesting nesting) 319 : impl(new OpPassManagerImpl(nesting)) {} 320 OpPassManager::OpPassManager(StringRef name, Nesting nesting) 321 : impl(new OpPassManagerImpl(name, nesting)) {} 322 OpPassManager::OpPassManager(OperationName name, Nesting nesting) 323 : impl(new OpPassManagerImpl(name, nesting)) {} 324 OpPassManager::OpPassManager(OpPassManager &&rhs) { *this = std::move(rhs); } 325 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } 326 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { 327 impl = std::make_unique<OpPassManagerImpl>(*rhs.impl); 328 return *this; 329 } 330 OpPassManager &OpPassManager::operator=(OpPassManager &&rhs) { 331 impl = std::move(rhs.impl); 332 return *this; 333 } 334 335 OpPassManager::~OpPassManager() = default; 336 337 OpPassManager::pass_iterator OpPassManager::begin() { 338 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(); 339 } 340 OpPassManager::pass_iterator OpPassManager::end() { 341 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end(); 342 } 343 344 OpPassManager::const_pass_iterator OpPassManager::begin() const { 345 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(); 346 } 347 OpPassManager::const_pass_iterator OpPassManager::end() const { 348 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end(); 349 } 350 351 /// Nest a new operation pass manager for the given operation kind under this 352 /// pass manager. 353 OpPassManager &OpPassManager::nest(OperationName nestedName) { 354 return impl->nest(nestedName); 355 } 356 OpPassManager &OpPassManager::nest(StringRef nestedName) { 357 return impl->nest(nestedName); 358 } 359 OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); } 360 361 /// Add the given pass to this pass manager. If this pass has a concrete 362 /// operation type, it must be the same type as this pass manager. 363 void OpPassManager::addPass(std::unique_ptr<Pass> pass) { 364 impl->addPass(std::move(pass)); 365 } 366 367 void OpPassManager::clear() { impl->clear(); } 368 369 /// Returns the number of passes held by this manager. 370 size_t OpPassManager::size() const { return impl->passes.size(); } 371 372 /// Returns the internal implementation instance. 373 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; } 374 375 /// Return the operation name that this pass manager operates on. 376 std::optional<StringRef> OpPassManager::getOpName() const { 377 return impl->getOpName(); 378 } 379 380 /// Return the operation name that this pass manager operates on. 381 std::optional<OperationName> 382 OpPassManager::getOpName(MLIRContext &context) const { 383 return impl->getOpName(context); 384 } 385 386 StringRef OpPassManager::getOpAnchorName() const { 387 return impl->getOpAnchorName(); 388 } 389 390 /// Prints out the passes of the pass manager as the textual representation 391 /// of pipelines. 392 void printAsTextualPipeline( 393 raw_ostream &os, StringRef anchorName, 394 const llvm::iterator_range<OpPassManager::pass_iterator> &passes) { 395 os << anchorName << "("; 396 llvm::interleave( 397 passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); }, 398 [&]() { os << ","; }); 399 os << ")"; 400 } 401 void OpPassManager::printAsTextualPipeline(raw_ostream &os) const { 402 StringRef anchorName = getOpAnchorName(); 403 ::printAsTextualPipeline( 404 os, anchorName, 405 {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(), 406 MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()}); 407 } 408 409 void OpPassManager::dump() { 410 llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n"; 411 printAsTextualPipeline(llvm::errs()); 412 llvm::errs() << "\n"; 413 } 414 415 static void registerDialectsForPipeline(const OpPassManager &pm, 416 DialectRegistry &dialects) { 417 for (const Pass &pass : pm.getPasses()) 418 pass.getDependentDialects(dialects); 419 } 420 421 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const { 422 registerDialectsForPipeline(*this, dialects); 423 } 424 425 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; } 426 427 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; } 428 429 LogicalResult OpPassManager::initialize(MLIRContext *context, 430 unsigned newInitGeneration) { 431 if (impl->initializationGeneration == newInitGeneration) 432 return success(); 433 impl->initializationGeneration = newInitGeneration; 434 for (Pass &pass : getPasses()) { 435 // If this pass isn't an adaptor, directly initialize it. 436 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass); 437 if (!adaptor) { 438 if (failed(pass.initialize(context))) 439 return failure(); 440 continue; 441 } 442 443 // Otherwise, initialize each of the adaptors pass managers. 444 for (OpPassManager &adaptorPM : adaptor->getPassManagers()) 445 if (failed(adaptorPM.initialize(context, newInitGeneration))) 446 return failure(); 447 } 448 return success(); 449 } 450 451 llvm::hash_code OpPassManager::hash() { 452 llvm::hash_code hashCode{}; 453 for (Pass &pass : getPasses()) { 454 // If this pass isn't an adaptor, directly hash it. 455 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass); 456 if (!adaptor) { 457 hashCode = llvm::hash_combine(hashCode, &pass); 458 continue; 459 } 460 // Otherwise, hash recursively each of the adaptors pass managers. 461 for (OpPassManager &adaptorPM : adaptor->getPassManagers()) 462 llvm::hash_combine(hashCode, adaptorPM.hash()); 463 } 464 return hashCode; 465 } 466 467 468 //===----------------------------------------------------------------------===// 469 // OpToOpPassAdaptor 470 //===----------------------------------------------------------------------===// 471 472 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, 473 AnalysisManager am, bool verifyPasses, 474 unsigned parentInitGeneration) { 475 std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo(); 476 if (!opInfo) 477 return op->emitOpError() 478 << "trying to schedule a pass on an unregistered operation"; 479 if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>()) 480 return op->emitOpError() << "trying to schedule a pass on an operation not " 481 "marked as 'IsolatedFromAbove'"; 482 if (!pass->canScheduleOn(*op->getName().getRegisteredInfo())) 483 return op->emitOpError() 484 << "trying to schedule a pass on an unsupported operation"; 485 486 // Initialize the pass state with a callback for the pass to dynamically 487 // execute a pipeline on the currently visited operation. 488 PassInstrumentor *pi = am.getPassInstrumentor(); 489 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), 490 pass}; 491 auto dynamicPipelineCallback = [&](OpPassManager &pipeline, 492 Operation *root) -> LogicalResult { 493 if (!op->isAncestor(root)) 494 return root->emitOpError() 495 << "Trying to schedule a dynamic pipeline on an " 496 "operation that isn't " 497 "nested under the current operation the pass is processing"; 498 assert( 499 pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName())); 500 501 // Before running, finalize the passes held by the pipeline. 502 if (failed(pipeline.getImpl().finalizePassList(root->getContext()))) 503 return failure(); 504 505 // Initialize the user provided pipeline and execute the pipeline. 506 if (failed(pipeline.initialize(root->getContext(), parentInitGeneration))) 507 return failure(); 508 AnalysisManager nestedAm = root == op ? am : am.nest(root); 509 return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm, 510 verifyPasses, parentInitGeneration, 511 pi, &parentInfo); 512 }; 513 pass->passState.emplace(op, am, dynamicPipelineCallback); 514 515 // Instrument before the pass has run. 516 if (pi) 517 pi->runBeforePass(pass, op); 518 519 bool passFailed = false; 520 op->getContext()->executeAction<PassExecutionAction>( 521 [&]() { 522 // Invoke the virtual runOnOperation method. 523 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass)) 524 adaptor->runOnOperation(verifyPasses); 525 else 526 pass->runOnOperation(); 527 passFailed = pass->passState->irAndPassFailed.getInt(); 528 }, 529 {op}, *pass); 530 531 // Invalidate any non preserved analyses. 532 am.invalidate(pass->passState->preservedAnalyses); 533 534 // When verifyPasses is specified, we run the verifier (unless the pass 535 // failed). 536 if (!passFailed && verifyPasses) { 537 bool runVerifierNow = true; 538 539 // If the pass is an adaptor pass, we don't run the verifier recursively 540 // because the nested operations should have already been verified after 541 // nested passes had run. 542 bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass); 543 544 // Reduce compile time by avoiding running the verifier if the pass didn't 545 // change the IR since the last time the verifier was run: 546 // 547 // 1) If the pass said that it preserved all analyses then it can't have 548 // permuted the IR. 549 // 550 // We run these checks in EXPENSIVE_CHECKS mode out of caution. 551 #ifndef EXPENSIVE_CHECKS 552 runVerifierNow = !pass->passState->preservedAnalyses.isAll(); 553 #endif 554 if (runVerifierNow) 555 passFailed = failed(verify(op, runVerifierRecursively)); 556 } 557 558 // Instrument after the pass has run. 559 if (pi) { 560 if (passFailed) 561 pi->runAfterPassFailed(pass, op); 562 else 563 pi->runAfterPass(pass, op); 564 } 565 566 // Return if the pass signaled a failure. 567 return failure(passFailed); 568 } 569 570 /// Run the given operation and analysis manager on a provided op pass manager. 571 LogicalResult OpToOpPassAdaptor::runPipeline( 572 OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses, 573 unsigned parentInitGeneration, PassInstrumentor *instrumentor, 574 const PassInstrumentation::PipelineParentInfo *parentInfo) { 575 assert((!instrumentor || parentInfo) && 576 "expected parent info if instrumentor is provided"); 577 auto scopeExit = llvm::make_scope_exit([&] { 578 // Clear out any computed operation analyses. These analyses won't be used 579 // any more in this pipeline, and this helps reduce the current working set 580 // of memory. If preserving these analyses becomes important in the future 581 // we can re-evaluate this. 582 am.clear(); 583 }); 584 585 // Run the pipeline over the provided operation. 586 if (instrumentor) { 587 instrumentor->runBeforePipeline(pm.getOpName(*op->getContext()), 588 *parentInfo); 589 } 590 591 for (Pass &pass : pm.getPasses()) 592 if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration))) 593 return failure(); 594 595 if (instrumentor) { 596 instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()), 597 *parentInfo); 598 } 599 return success(); 600 } 601 602 /// Find an operation pass manager with the given anchor name, or nullptr if one 603 /// does not exist. 604 static OpPassManager * 605 findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs, StringRef name) { 606 auto *it = llvm::find_if( 607 mgrs, [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; }); 608 return it == mgrs.end() ? nullptr : &*it; 609 } 610 611 /// Find an operation pass manager that can operate on an operation of the given 612 /// type, or nullptr if one does not exist. 613 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs, 614 OperationName name, 615 MLIRContext &context) { 616 auto *it = llvm::find_if(mgrs, [&](OpPassManager &mgr) { 617 return mgr.getImpl().canScheduleOn(context, name); 618 }); 619 return it == mgrs.end() ? nullptr : &*it; 620 } 621 622 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) { 623 mgrs.emplace_back(std::move(mgr)); 624 } 625 626 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const { 627 for (auto &pm : mgrs) 628 pm.getDependentDialects(dialects); 629 } 630 631 LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx, 632 OpToOpPassAdaptor &rhs) { 633 // Functor used to check if a pass manager is generic, i.e. op-agnostic. 634 auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); }; 635 636 // Functor used to detect if the given generic pass manager will have a 637 // potential schedule conflict with the given `otherPMs`. 638 auto hasScheduleConflictWith = [&](OpPassManager &genericPM, 639 MutableArrayRef<OpPassManager> otherPMs) { 640 return llvm::any_of(otherPMs, [&](OpPassManager &pm) { 641 // If this is a non-generic pass manager, a conflict will arise if a 642 // non-generic pass manager's operation name can be scheduled on the 643 // generic passmanager. 644 if (std::optional<OperationName> pmOpName = pm.getOpName(*ctx)) 645 return genericPM.getImpl().canScheduleOn(*ctx, *pmOpName); 646 // Otherwise, this is a generic pass manager. We current can't determine 647 // when generic pass managers can be merged, so conservatively assume they 648 // conflict. 649 return true; 650 }); 651 }; 652 653 // Check that if either adaptor has a generic pass manager, that pm is 654 // compatible within any non-generic pass managers. 655 // 656 // Check the current adaptor. 657 auto *lhsGenericPMIt = llvm::find_if(mgrs, isGenericPM); 658 if (lhsGenericPMIt != mgrs.end() && 659 hasScheduleConflictWith(*lhsGenericPMIt, rhs.mgrs)) 660 return failure(); 661 // Check the rhs adaptor. 662 auto *rhsGenericPMIt = llvm::find_if(rhs.mgrs, isGenericPM); 663 if (rhsGenericPMIt != rhs.mgrs.end() && 664 hasScheduleConflictWith(*rhsGenericPMIt, mgrs)) 665 return failure(); 666 667 for (auto &pm : mgrs) { 668 // If an existing pass manager exists, then merge the given pass manager 669 // into it. 670 if (auto *existingPM = 671 findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) { 672 pm.getImpl().mergeInto(existingPM->getImpl()); 673 } else { 674 // Otherwise, add the given pass manager to the list. 675 rhs.mgrs.emplace_back(std::move(pm)); 676 } 677 } 678 mgrs.clear(); 679 680 // After coalescing, sort the pass managers within rhs by name. 681 auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) { 682 // Order op-specific pass managers first and op-agnostic pass managers last. 683 if (std::optional<StringRef> lhsName = lhs->getOpName()) { 684 if (std::optional<StringRef> rhsName = rhs->getOpName()) 685 return lhsName->compare(*rhsName); 686 return -1; // lhs(op-specific) < rhs(op-agnostic) 687 } 688 return 1; // lhs(op-agnostic) > rhs(op-specific) 689 }; 690 llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), compareFn); 691 return success(); 692 } 693 694 /// Returns the adaptor pass name. 695 std::string OpToOpPassAdaptor::getAdaptorName() { 696 std::string name = "Pipeline Collection : ["; 697 llvm::raw_string_ostream os(name); 698 llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { 699 os << '\'' << pm.getOpAnchorName() << '\''; 700 }); 701 os << ']'; 702 return name; 703 } 704 705 void OpToOpPassAdaptor::runOnOperation() { 706 llvm_unreachable( 707 "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor"); 708 } 709 710 /// Run the held pipeline over all nested operations. 711 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) { 712 if (getContext().isMultithreadingEnabled()) 713 runOnOperationAsyncImpl(verifyPasses); 714 else 715 runOnOperationImpl(verifyPasses); 716 } 717 718 /// Run this pass adaptor synchronously. 719 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) { 720 auto am = getAnalysisManager(); 721 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), 722 this}; 723 auto *instrumentor = am.getPassInstrumentor(); 724 for (auto ®ion : getOperation()->getRegions()) { 725 for (auto &block : region) { 726 for (auto &op : block) { 727 auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext()); 728 if (!mgr) 729 continue; 730 731 // Run the held pipeline over the current operation. 732 unsigned initGeneration = mgr->impl->initializationGeneration; 733 if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses, 734 initGeneration, instrumentor, &parentInfo))) 735 signalPassFailure(); 736 } 737 } 738 } 739 } 740 741 /// Utility functor that checks if the two ranges of pass managers have a size 742 /// mismatch. 743 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs, 744 ArrayRef<OpPassManager> rhs) { 745 return lhs.size() != rhs.size() || 746 llvm::any_of(llvm::seq<size_t>(0, lhs.size()), 747 [&](size_t i) { return lhs[i].size() != rhs[i].size(); }); 748 } 749 750 /// Run this pass adaptor synchronously. 751 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) { 752 AnalysisManager am = getAnalysisManager(); 753 MLIRContext *context = &getContext(); 754 755 // Create the async executors if they haven't been created, or if the main 756 // pipeline has changed. 757 if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs)) 758 asyncExecutors.assign(context->getThreadPool().getMaxConcurrency(), mgrs); 759 760 // This struct represents the information for a single operation to be 761 // scheduled on a pass manager. 762 struct OpPMInfo { 763 OpPMInfo(unsigned passManagerIdx, Operation *op, AnalysisManager am) 764 : passManagerIdx(passManagerIdx), op(op), am(am) {} 765 766 /// The index of the pass manager to schedule the operation on. 767 unsigned passManagerIdx; 768 /// The operation to schedule. 769 Operation *op; 770 /// The analysis manager for the operation. 771 AnalysisManager am; 772 }; 773 774 // Run a prepass over the operation to collect the nested operations to 775 // execute over. This ensures that an analysis manager exists for each 776 // operation, as well as providing a queue of operations to execute over. 777 std::vector<OpPMInfo> opInfos; 778 DenseMap<OperationName, std::optional<unsigned>> knownOpPMIdx; 779 for (auto ®ion : getOperation()->getRegions()) { 780 for (Operation &op : region.getOps()) { 781 // Get the pass manager index for this operation type. 782 auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt); 783 if (pmIdxIt.second) { 784 if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context)) 785 pmIdxIt.first->second = std::distance(mgrs.begin(), mgr); 786 } 787 788 // If this operation can be scheduled, add it to the list. 789 if (pmIdxIt.first->second) 790 opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op)); 791 } 792 } 793 794 // Get the current thread for this adaptor. 795 PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), 796 this}; 797 auto *instrumentor = am.getPassInstrumentor(); 798 799 // An atomic failure variable for the async executors. 800 std::vector<std::atomic<bool>> activePMs(asyncExecutors.size()); 801 std::fill(activePMs.begin(), activePMs.end(), false); 802 std::atomic<bool> hasFailure = false; 803 parallelForEach(context, opInfos, [&](OpPMInfo &opInfo) { 804 // Find an executor for this operation. 805 auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) { 806 bool expectedInactive = false; 807 return isActive.compare_exchange_strong(expectedInactive, true); 808 }); 809 unsigned pmIndex = it - activePMs.begin(); 810 811 // Get the pass manager for this operation and execute it. 812 OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx]; 813 LogicalResult pipelineResult = runPipeline( 814 pm, opInfo.op, opInfo.am, verifyPasses, 815 pm.impl->initializationGeneration, instrumentor, &parentInfo); 816 if (failed(pipelineResult)) 817 hasFailure.store(true); 818 819 // Reset the active bit for this pass manager. 820 activePMs[pmIndex].store(false); 821 }); 822 823 // Signal a failure if any of the executors failed. 824 if (hasFailure) 825 signalPassFailure(); 826 } 827 828 //===----------------------------------------------------------------------===// 829 // PassManager 830 //===----------------------------------------------------------------------===// 831 832 PassManager::PassManager(MLIRContext *ctx, StringRef operationName, 833 Nesting nesting) 834 : OpPassManager(operationName, nesting), context(ctx), passTiming(false), 835 verifyPasses(true) {} 836 837 PassManager::PassManager(OperationName operationName, Nesting nesting) 838 : OpPassManager(operationName, nesting), 839 context(operationName.getContext()), passTiming(false), 840 verifyPasses(true) {} 841 842 PassManager::~PassManager() = default; 843 844 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; } 845 846 /// Run the passes within this manager on the provided operation. 847 LogicalResult PassManager::run(Operation *op) { 848 MLIRContext *context = getContext(); 849 std::optional<OperationName> anchorOp = getOpName(*context); 850 if (anchorOp && anchorOp != op->getName()) 851 return emitError(op->getLoc()) 852 << "can't run '" << getOpAnchorName() << "' pass manager on '" 853 << op->getName() << "' op"; 854 855 // Register all dialects for the current pipeline. 856 DialectRegistry dependentDialects; 857 getDependentDialects(dependentDialects); 858 context->appendDialectRegistry(dependentDialects); 859 for (StringRef name : dependentDialects.getDialectNames()) 860 context->getOrLoadDialect(name); 861 862 // Before running, make sure to finalize the pipeline pass list. 863 if (failed(getImpl().finalizePassList(context))) 864 return failure(); 865 866 // Notify the context that we start running a pipeline for bookkeeping. 867 context->enterMultiThreadedExecution(); 868 869 // Initialize all of the passes within the pass manager with a new generation. 870 llvm::hash_code newInitKey = context->getRegistryHash(); 871 llvm::hash_code pipelineKey = hash(); 872 if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) { 873 if (failed(initialize(context, impl->initializationGeneration + 1))) 874 return failure(); 875 initializationKey = newInitKey; 876 pipelineKey = pipelineInitializationKey; 877 } 878 879 // Construct a top level analysis manager for the pipeline. 880 ModuleAnalysisManager am(op, instrumentor.get()); 881 882 // If reproducer generation is enabled, run the pass manager with crash 883 // handling enabled. 884 LogicalResult result = 885 crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am); 886 887 // Notify the context that the run is done. 888 context->exitMultiThreadedExecution(); 889 890 // Dump all of the pass statistics if necessary. 891 if (passStatisticsMode) 892 dumpStatistics(); 893 return result; 894 } 895 896 /// Add the provided instrumentation to the pass manager. 897 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) { 898 if (!instrumentor) 899 instrumentor = std::make_unique<PassInstrumentor>(); 900 901 instrumentor->addInstrumentation(std::move(pi)); 902 } 903 904 LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) { 905 return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses, 906 impl->initializationGeneration); 907 } 908 909 //===----------------------------------------------------------------------===// 910 // AnalysisManager 911 //===----------------------------------------------------------------------===// 912 913 /// Get an analysis manager for the given operation, which must be a proper 914 /// descendant of the current operation represented by this analysis manager. 915 AnalysisManager AnalysisManager::nest(Operation *op) { 916 Operation *currentOp = impl->getOperation(); 917 assert(currentOp->isProperAncestor(op) && 918 "expected valid descendant operation"); 919 920 // Check for the base case where the provided operation is immediately nested. 921 if (currentOp == op->getParentOp()) 922 return nestImmediate(op); 923 924 // Otherwise, we need to collect all ancestors up to the current operation. 925 SmallVector<Operation *, 4> opAncestors; 926 do { 927 opAncestors.push_back(op); 928 op = op->getParentOp(); 929 } while (op != currentOp); 930 931 AnalysisManager result = *this; 932 for (Operation *op : llvm::reverse(opAncestors)) 933 result = result.nestImmediate(op); 934 return result; 935 } 936 937 /// Get an analysis manager for the given immediately nested child operation. 938 AnalysisManager AnalysisManager::nestImmediate(Operation *op) { 939 assert(impl->getOperation() == op->getParentOp() && 940 "expected immediate child operation"); 941 942 auto [it, inserted] = impl->childAnalyses.try_emplace(op); 943 if (inserted) 944 it->second = std::make_unique<NestedAnalysisMap>(op, impl); 945 return {it->second.get()}; 946 } 947 948 /// Invalidate any non preserved analyses. 949 void detail::NestedAnalysisMap::invalidate( 950 const detail::PreservedAnalyses &pa) { 951 // If all analyses were preserved, then there is nothing to do here. 952 if (pa.isAll()) 953 return; 954 955 // Invalidate the analyses for the current operation directly. 956 analyses.invalidate(pa); 957 958 // If no analyses were preserved, then just simply clear out the child 959 // analysis results. 960 if (pa.isNone()) { 961 childAnalyses.clear(); 962 return; 963 } 964 965 // Otherwise, invalidate each child analysis map. 966 SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this); 967 while (!mapsToInvalidate.empty()) { 968 auto *map = mapsToInvalidate.pop_back_val(); 969 for (auto &analysisPair : map->childAnalyses) { 970 analysisPair.second->invalidate(pa); 971 if (!analysisPair.second->childAnalyses.empty()) 972 mapsToInvalidate.push_back(analysisPair.second.get()); 973 } 974 } 975 } 976 977 //===----------------------------------------------------------------------===// 978 // PassInstrumentation 979 //===----------------------------------------------------------------------===// 980 981 PassInstrumentation::~PassInstrumentation() = default; 982 983 void PassInstrumentation::runBeforePipeline( 984 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {} 985 986 void PassInstrumentation::runAfterPipeline( 987 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {} 988 989 //===----------------------------------------------------------------------===// 990 // PassInstrumentor 991 //===----------------------------------------------------------------------===// 992 993 namespace mlir { 994 namespace detail { 995 struct PassInstrumentorImpl { 996 /// Mutex to keep instrumentation access thread-safe. 997 llvm::sys::SmartMutex<true> mutex; 998 999 /// Set of registered instrumentations. 1000 std::vector<std::unique_ptr<PassInstrumentation>> instrumentations; 1001 }; 1002 } // namespace detail 1003 } // namespace mlir 1004 1005 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {} 1006 PassInstrumentor::~PassInstrumentor() = default; 1007 1008 /// See PassInstrumentation::runBeforePipeline for details. 1009 void PassInstrumentor::runBeforePipeline( 1010 std::optional<OperationName> name, 1011 const PassInstrumentation::PipelineParentInfo &parentInfo) { 1012 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1013 for (auto &instr : impl->instrumentations) 1014 instr->runBeforePipeline(name, parentInfo); 1015 } 1016 1017 /// See PassInstrumentation::runAfterPipeline for details. 1018 void PassInstrumentor::runAfterPipeline( 1019 std::optional<OperationName> name, 1020 const PassInstrumentation::PipelineParentInfo &parentInfo) { 1021 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1022 for (auto &instr : llvm::reverse(impl->instrumentations)) 1023 instr->runAfterPipeline(name, parentInfo); 1024 } 1025 1026 /// See PassInstrumentation::runBeforePass for details. 1027 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) { 1028 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1029 for (auto &instr : impl->instrumentations) 1030 instr->runBeforePass(pass, op); 1031 } 1032 1033 /// See PassInstrumentation::runAfterPass for details. 1034 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) { 1035 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1036 for (auto &instr : llvm::reverse(impl->instrumentations)) 1037 instr->runAfterPass(pass, op); 1038 } 1039 1040 /// See PassInstrumentation::runAfterPassFailed for details. 1041 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) { 1042 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1043 for (auto &instr : llvm::reverse(impl->instrumentations)) 1044 instr->runAfterPassFailed(pass, op); 1045 } 1046 1047 /// See PassInstrumentation::runBeforeAnalysis for details. 1048 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id, 1049 Operation *op) { 1050 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1051 for (auto &instr : impl->instrumentations) 1052 instr->runBeforeAnalysis(name, id, op); 1053 } 1054 1055 /// See PassInstrumentation::runAfterAnalysis for details. 1056 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id, 1057 Operation *op) { 1058 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1059 for (auto &instr : llvm::reverse(impl->instrumentations)) 1060 instr->runAfterAnalysis(name, id, op); 1061 } 1062 1063 /// Add the given instrumentation to the collection. 1064 void PassInstrumentor::addInstrumentation( 1065 std::unique_ptr<PassInstrumentation> pi) { 1066 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex); 1067 impl->instrumentations.emplace_back(std::move(pi)); 1068 } 1069