xref: /llvm-project/mlir/lib/Pass/Pass.cpp (revision 5b21fd298cb4fc2042a95ffb9284b778f8504e04)
1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Threading.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Support/FileUtilities.h"
21 #include "llvm/ADT/Hashing.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/CrashRecoveryContext.h"
26 #include "llvm/Support/Mutex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/Support/Threading.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 #include <optional>
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 
35 //===----------------------------------------------------------------------===//
36 // PassExecutionAction
37 //===----------------------------------------------------------------------===//
38 
39 PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
40                                          const Pass &pass)
41     : Base(irUnits), pass(pass) {}
42 
43 void PassExecutionAction::print(raw_ostream &os) const {
44   os << llvm::formatv("`{0}` running `{1}` on Operation `{2}`", tag,
45                       pass.getName(), getOp()->getName());
46 }
47 
48 Operation *PassExecutionAction::getOp() const {
49   ArrayRef<IRUnit> irUnits = getContextIRUnits();
50   return irUnits.empty() ? nullptr
51                          : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // Pass
56 //===----------------------------------------------------------------------===//
57 
58 /// Out of line virtual method to ensure vtables and metadata are emitted to a
59 /// single .o file.
60 void Pass::anchor() {}
61 
62 /// Attempt to initialize the options of this pass from the given string.
63 LogicalResult Pass::initializeOptions(
64     StringRef options,
65     function_ref<LogicalResult(const Twine &)> errorHandler) {
66   std::string errStr;
67   llvm::raw_string_ostream os(errStr);
68   if (failed(passOptions.parseFromString(options, os))) {
69     return errorHandler(errStr);
70   }
71   return success();
72 }
73 
74 /// Copy the option values from 'other', which is another instance of this
75 /// pass.
76 void Pass::copyOptionValuesFrom(const Pass *other) {
77   passOptions.copyOptionValuesFrom(other->passOptions);
78 }
79 
80 /// Prints out the pass in the textual representation of pipelines. If this is
81 /// an adaptor pass, print its pass managers.
82 void Pass::printAsTextualPipeline(raw_ostream &os) {
83   // Special case for adaptors to print its pass managers.
84   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
85     llvm::interleave(
86         adaptor->getPassManagers(),
87         [&](OpPassManager &pm) { pm.printAsTextualPipeline(os); },
88         [&] { os << ","; });
89     return;
90   }
91   // Otherwise, print the pass argument followed by its options. If the pass
92   // doesn't have an argument, print the name of the pass to give some indicator
93   // of what pass was run.
94   StringRef argument = getArgument();
95   if (!argument.empty())
96     os << argument;
97   else
98     os << "unknown<" << getName() << ">";
99   passOptions.print(os);
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // OpPassManagerImpl
104 //===----------------------------------------------------------------------===//
105 
106 namespace mlir {
107 namespace detail {
108 struct OpPassManagerImpl {
109   OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
110       : name(opName.getStringRef().str()), opName(opName),
111         initializationGeneration(0), nesting(nesting) {}
112   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
113       : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()),
114         initializationGeneration(0), nesting(nesting) {}
115   OpPassManagerImpl(OpPassManager::Nesting nesting)
116       : initializationGeneration(0), nesting(nesting) {}
117   OpPassManagerImpl(const OpPassManagerImpl &rhs)
118       : name(rhs.name), opName(rhs.opName),
119         initializationGeneration(rhs.initializationGeneration),
120         nesting(rhs.nesting) {
121     for (const std::unique_ptr<Pass> &pass : rhs.passes) {
122       std::unique_ptr<Pass> newPass = pass->clone();
123       newPass->threadingSibling = pass.get();
124       passes.push_back(std::move(newPass));
125     }
126   }
127 
128   /// Merge the passes of this pass manager into the one provided.
129   void mergeInto(OpPassManagerImpl &rhs);
130 
131   /// Nest a new operation pass manager for the given operation kind under this
132   /// pass manager.
133   OpPassManager &nest(OperationName nestedName) {
134     return nest(OpPassManager(nestedName, nesting));
135   }
136   OpPassManager &nest(StringRef nestedName) {
137     return nest(OpPassManager(nestedName, nesting));
138   }
139   OpPassManager &nestAny() { return nest(OpPassManager(nesting)); }
140 
141   /// Nest the given pass manager under this pass manager.
142   OpPassManager &nest(OpPassManager &&nested);
143 
144   /// Add the given pass to this pass manager. If this pass has a concrete
145   /// operation type, it must be the same type as this pass manager.
146   void addPass(std::unique_ptr<Pass> pass);
147 
148   /// Clear the list of passes in this pass manager, other options are
149   /// preserved.
150   void clear();
151 
152   /// Finalize the pass list in preparation for execution. This includes
153   /// coalescing adjacent pass managers when possible, verifying scheduled
154   /// passes, etc.
155   LogicalResult finalizePassList(MLIRContext *ctx);
156 
157   /// Return the operation name of this pass manager.
158   std::optional<OperationName> getOpName(MLIRContext &context) {
159     if (!name.empty() && !opName)
160       opName = OperationName(name, &context);
161     return opName;
162   }
163   std::optional<StringRef> getOpName() const {
164     return name.empty() ? std::optional<StringRef>()
165                         : std::optional<StringRef>(name);
166   }
167 
168   /// Return the name used to anchor this pass manager. This is either the name
169   /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an
170   /// op-agnostic pass manager.
171   StringRef getOpAnchorName() const {
172     return getOpName().value_or(OpPassManager::getAnyOpAnchorName());
173   }
174 
175   /// Indicate if the current pass manager can be scheduled on the given
176   /// operation type.
177   bool canScheduleOn(MLIRContext &context, OperationName opName);
178 
179   /// The name of the operation that passes of this pass manager operate on.
180   std::string name;
181 
182   /// The cached OperationName (internalized in the context) for the name of the
183   /// operation that passes of this pass manager operate on.
184   std::optional<OperationName> opName;
185 
186   /// The set of passes to run as part of this pass manager.
187   std::vector<std::unique_ptr<Pass>> passes;
188 
189   /// The current initialization generation of this pass manager. This is used
190   /// to indicate when a pass manager should be reinitialized.
191   unsigned initializationGeneration;
192 
193   /// Control the implicit nesting of passes that mismatch the name set for this
194   /// OpPassManager.
195   OpPassManager::Nesting nesting;
196 };
197 } // namespace detail
198 } // namespace mlir
199 
200 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
201   assert(name == rhs.name && "merging unrelated pass managers");
202   for (auto &pass : passes)
203     rhs.passes.push_back(std::move(pass));
204   passes.clear();
205 }
206 
207 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
208   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
209   addPass(std::unique_ptr<Pass>(adaptor));
210   return adaptor->getPassManagers().front();
211 }
212 
213 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
214   // If this pass runs on a different operation than this pass manager, then
215   // implicitly nest a pass manager for this operation if enabled.
216   std::optional<StringRef> pmOpName = getOpName();
217   std::optional<StringRef> passOpName = pass->getOpName();
218   if (pmOpName && passOpName && *pmOpName != *passOpName) {
219     if (nesting == OpPassManager::Nesting::Implicit)
220       return nest(*passOpName).addPass(std::move(pass));
221     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
222                              "' restricted to '" + *passOpName +
223                              "' on a PassManager intended to run on '" +
224                              getOpAnchorName() + "', did you intend to nest?");
225   }
226 
227   passes.emplace_back(std::move(pass));
228 }
229 
230 void OpPassManagerImpl::clear() { passes.clear(); }
231 
232 LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
233   auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) {
234     for (auto &pm : adaptor->getPassManagers())
235       if (failed(pm.getImpl().finalizePassList(ctx)))
236         return failure();
237     return success();
238   };
239 
240   // Walk the pass list and merge adjacent adaptors.
241   OpToOpPassAdaptor *lastAdaptor = nullptr;
242   for (auto &pass : passes) {
243     // Check to see if this pass is an adaptor.
244     if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get())) {
245       // If it is the first adaptor in a possible chain, remember it and
246       // continue.
247       if (!lastAdaptor) {
248         lastAdaptor = currentAdaptor;
249         continue;
250       }
251 
252       // Otherwise, try to merge into the existing adaptor and delete the
253       // current one. If merging fails, just remember this as the last adaptor.
254       if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor)))
255         pass.reset();
256       else
257         lastAdaptor = currentAdaptor;
258     } else if (lastAdaptor) {
259       // If this pass isn't an adaptor, finalize it and forget the last adaptor.
260       if (failed(finalizeAdaptor(lastAdaptor)))
261         return failure();
262       lastAdaptor = nullptr;
263     }
264   }
265 
266   // If there was an adaptor at the end of the manager, finalize it as well.
267   if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor)))
268     return failure();
269 
270   // Now that the adaptors have been merged, erase any empty slots corresponding
271   // to the merged adaptors that were nulled-out in the loop above.
272   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
273 
274   // If this is a op-agnostic pass manager, there is nothing left to do.
275   std::optional<OperationName> rawOpName = getOpName(*ctx);
276   if (!rawOpName)
277     return success();
278 
279   // Otherwise, verify that all of the passes are valid for the current
280   // operation anchor.
281   std::optional<RegisteredOperationName> opName =
282       rawOpName->getRegisteredInfo();
283   for (std::unique_ptr<Pass> &pass : passes) {
284     if (opName && !pass->canScheduleOn(*opName)) {
285       return emitError(UnknownLoc::get(ctx))
286              << "unable to schedule pass '" << pass->getName()
287              << "' on a PassManager intended to run on '" << getOpAnchorName()
288              << "'!";
289     }
290   }
291   return success();
292 }
293 
294 bool OpPassManagerImpl::canScheduleOn(MLIRContext &context,
295                                       OperationName opName) {
296   // If this pass manager is op-specific, we simply check if the provided
297   // operation name is the same as this one.
298   std::optional<OperationName> pmOpName = getOpName(context);
299   if (pmOpName)
300     return pmOpName == opName;
301 
302   // Otherwise, this is an op-agnostic pass manager. Check that the operation
303   // can be scheduled on all passes within the manager.
304   std::optional<RegisteredOperationName> registeredInfo =
305       opName.getRegisteredInfo();
306   if (!registeredInfo ||
307       !registeredInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
308     return false;
309   return llvm::all_of(passes, [&](const std::unique_ptr<Pass> &pass) {
310     return pass->canScheduleOn(*registeredInfo);
311   });
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // OpPassManager
316 //===----------------------------------------------------------------------===//
317 
318 OpPassManager::OpPassManager(Nesting nesting)
319     : impl(new OpPassManagerImpl(nesting)) {}
320 OpPassManager::OpPassManager(StringRef name, Nesting nesting)
321     : impl(new OpPassManagerImpl(name, nesting)) {}
322 OpPassManager::OpPassManager(OperationName name, Nesting nesting)
323     : impl(new OpPassManagerImpl(name, nesting)) {}
324 OpPassManager::OpPassManager(OpPassManager &&rhs) { *this = std::move(rhs); }
325 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
326 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
327   impl = std::make_unique<OpPassManagerImpl>(*rhs.impl);
328   return *this;
329 }
330 OpPassManager &OpPassManager::operator=(OpPassManager &&rhs) {
331   impl = std::move(rhs.impl);
332   return *this;
333 }
334 
335 OpPassManager::~OpPassManager() = default;
336 
337 OpPassManager::pass_iterator OpPassManager::begin() {
338   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
339 }
340 OpPassManager::pass_iterator OpPassManager::end() {
341   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
342 }
343 
344 OpPassManager::const_pass_iterator OpPassManager::begin() const {
345   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
346 }
347 OpPassManager::const_pass_iterator OpPassManager::end() const {
348   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
349 }
350 
351 /// Nest a new operation pass manager for the given operation kind under this
352 /// pass manager.
353 OpPassManager &OpPassManager::nest(OperationName nestedName) {
354   return impl->nest(nestedName);
355 }
356 OpPassManager &OpPassManager::nest(StringRef nestedName) {
357   return impl->nest(nestedName);
358 }
359 OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); }
360 
361 /// Add the given pass to this pass manager. If this pass has a concrete
362 /// operation type, it must be the same type as this pass manager.
363 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
364   impl->addPass(std::move(pass));
365 }
366 
367 void OpPassManager::clear() { impl->clear(); }
368 
369 /// Returns the number of passes held by this manager.
370 size_t OpPassManager::size() const { return impl->passes.size(); }
371 
372 /// Returns the internal implementation instance.
373 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
374 
375 /// Return the operation name that this pass manager operates on.
376 std::optional<StringRef> OpPassManager::getOpName() const {
377   return impl->getOpName();
378 }
379 
380 /// Return the operation name that this pass manager operates on.
381 std::optional<OperationName>
382 OpPassManager::getOpName(MLIRContext &context) const {
383   return impl->getOpName(context);
384 }
385 
386 StringRef OpPassManager::getOpAnchorName() const {
387   return impl->getOpAnchorName();
388 }
389 
390 /// Prints out the passes of the pass manager as the textual representation
391 /// of pipelines.
392 void printAsTextualPipeline(
393     raw_ostream &os, StringRef anchorName,
394     const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
395   os << anchorName << "(";
396   llvm::interleave(
397       passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
398       [&]() { os << ","; });
399   os << ")";
400 }
401 void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
402   StringRef anchorName = getOpAnchorName();
403   ::printAsTextualPipeline(
404       os, anchorName,
405       {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
406        MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
407 }
408 
409 void OpPassManager::dump() {
410   llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";
411   printAsTextualPipeline(llvm::errs());
412   llvm::errs() << "\n";
413 }
414 
415 static void registerDialectsForPipeline(const OpPassManager &pm,
416                                         DialectRegistry &dialects) {
417   for (const Pass &pass : pm.getPasses())
418     pass.getDependentDialects(dialects);
419 }
420 
421 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
422   registerDialectsForPipeline(*this, dialects);
423 }
424 
425 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
426 
427 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
428 
429 LogicalResult OpPassManager::initialize(MLIRContext *context,
430                                         unsigned newInitGeneration) {
431   if (impl->initializationGeneration == newInitGeneration)
432     return success();
433   impl->initializationGeneration = newInitGeneration;
434   for (Pass &pass : getPasses()) {
435     // If this pass isn't an adaptor, directly initialize it.
436     auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
437     if (!adaptor) {
438       if (failed(pass.initialize(context)))
439         return failure();
440       continue;
441     }
442 
443     // Otherwise, initialize each of the adaptors pass managers.
444     for (OpPassManager &adaptorPM : adaptor->getPassManagers())
445       if (failed(adaptorPM.initialize(context, newInitGeneration)))
446         return failure();
447   }
448   return success();
449 }
450 
451 llvm::hash_code OpPassManager::hash() {
452   llvm::hash_code hashCode{};
453   for (Pass &pass : getPasses()) {
454     // If this pass isn't an adaptor, directly hash it.
455     auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
456     if (!adaptor) {
457       hashCode = llvm::hash_combine(hashCode, &pass);
458       continue;
459     }
460     // Otherwise, hash recursively each of the adaptors pass managers.
461     for (OpPassManager &adaptorPM : adaptor->getPassManagers())
462       llvm::hash_combine(hashCode, adaptorPM.hash());
463   }
464   return hashCode;
465 }
466 
467 
468 //===----------------------------------------------------------------------===//
469 // OpToOpPassAdaptor
470 //===----------------------------------------------------------------------===//
471 
472 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
473                                      AnalysisManager am, bool verifyPasses,
474                                      unsigned parentInitGeneration) {
475   std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
476   if (!opInfo)
477     return op->emitOpError()
478            << "trying to schedule a pass on an unregistered operation";
479   if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
480     return op->emitOpError() << "trying to schedule a pass on an operation not "
481                                 "marked as 'IsolatedFromAbove'";
482   if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
483     return op->emitOpError()
484            << "trying to schedule a pass on an unsupported operation";
485 
486   // Initialize the pass state with a callback for the pass to dynamically
487   // execute a pipeline on the currently visited operation.
488   PassInstrumentor *pi = am.getPassInstrumentor();
489   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
490                                                         pass};
491   auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
492                                      Operation *root) -> LogicalResult {
493     if (!op->isAncestor(root))
494       return root->emitOpError()
495              << "Trying to schedule a dynamic pipeline on an "
496                 "operation that isn't "
497                 "nested under the current operation the pass is processing";
498     assert(
499         pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName()));
500 
501     // Before running, finalize the passes held by the pipeline.
502     if (failed(pipeline.getImpl().finalizePassList(root->getContext())))
503       return failure();
504 
505     // Initialize the user provided pipeline and execute the pipeline.
506     if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
507       return failure();
508     AnalysisManager nestedAm = root == op ? am : am.nest(root);
509     return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm,
510                                           verifyPasses, parentInitGeneration,
511                                           pi, &parentInfo);
512   };
513   pass->passState.emplace(op, am, dynamicPipelineCallback);
514 
515   // Instrument before the pass has run.
516   if (pi)
517     pi->runBeforePass(pass, op);
518 
519   bool passFailed = false;
520   op->getContext()->executeAction<PassExecutionAction>(
521       [&]() {
522         // Invoke the virtual runOnOperation method.
523         if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
524           adaptor->runOnOperation(verifyPasses);
525         else
526           pass->runOnOperation();
527         passFailed = pass->passState->irAndPassFailed.getInt();
528       },
529       {op}, *pass);
530 
531   // Invalidate any non preserved analyses.
532   am.invalidate(pass->passState->preservedAnalyses);
533 
534   // When verifyPasses is specified, we run the verifier (unless the pass
535   // failed).
536   if (!passFailed && verifyPasses) {
537     bool runVerifierNow = true;
538 
539     // If the pass is an adaptor pass, we don't run the verifier recursively
540     // because the nested operations should have already been verified after
541     // nested passes had run.
542     bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
543 
544     // Reduce compile time by avoiding running the verifier if the pass didn't
545     // change the IR since the last time the verifier was run:
546     //
547     //  1) If the pass said that it preserved all analyses then it can't have
548     //     permuted the IR.
549     //
550     // We run these checks in EXPENSIVE_CHECKS mode out of caution.
551 #ifndef EXPENSIVE_CHECKS
552     runVerifierNow = !pass->passState->preservedAnalyses.isAll();
553 #endif
554     if (runVerifierNow)
555       passFailed = failed(verify(op, runVerifierRecursively));
556   }
557 
558   // Instrument after the pass has run.
559   if (pi) {
560     if (passFailed)
561       pi->runAfterPassFailed(pass, op);
562     else
563       pi->runAfterPass(pass, op);
564   }
565 
566   // Return if the pass signaled a failure.
567   return failure(passFailed);
568 }
569 
570 /// Run the given operation and analysis manager on a provided op pass manager.
571 LogicalResult OpToOpPassAdaptor::runPipeline(
572     OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
573     unsigned parentInitGeneration, PassInstrumentor *instrumentor,
574     const PassInstrumentation::PipelineParentInfo *parentInfo) {
575   assert((!instrumentor || parentInfo) &&
576          "expected parent info if instrumentor is provided");
577   auto scopeExit = llvm::make_scope_exit([&] {
578     // Clear out any computed operation analyses. These analyses won't be used
579     // any more in this pipeline, and this helps reduce the current working set
580     // of memory. If preserving these analyses becomes important in the future
581     // we can re-evaluate this.
582     am.clear();
583   });
584 
585   // Run the pipeline over the provided operation.
586   if (instrumentor) {
587     instrumentor->runBeforePipeline(pm.getOpName(*op->getContext()),
588                                     *parentInfo);
589   }
590 
591   for (Pass &pass : pm.getPasses())
592     if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
593       return failure();
594 
595   if (instrumentor) {
596     instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()),
597                                    *parentInfo);
598   }
599   return success();
600 }
601 
602 /// Find an operation pass manager with the given anchor name, or nullptr if one
603 /// does not exist.
604 static OpPassManager *
605 findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs, StringRef name) {
606   auto *it = llvm::find_if(
607       mgrs, [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; });
608   return it == mgrs.end() ? nullptr : &*it;
609 }
610 
611 /// Find an operation pass manager that can operate on an operation of the given
612 /// type, or nullptr if one does not exist.
613 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
614                                          OperationName name,
615                                          MLIRContext &context) {
616   auto *it = llvm::find_if(mgrs, [&](OpPassManager &mgr) {
617     return mgr.getImpl().canScheduleOn(context, name);
618   });
619   return it == mgrs.end() ? nullptr : &*it;
620 }
621 
622 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
623   mgrs.emplace_back(std::move(mgr));
624 }
625 
626 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
627   for (auto &pm : mgrs)
628     pm.getDependentDialects(dialects);
629 }
630 
631 LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
632                                               OpToOpPassAdaptor &rhs) {
633   // Functor used to check if a pass manager is generic, i.e. op-agnostic.
634   auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); };
635 
636   // Functor used to detect if the given generic pass manager will have a
637   // potential schedule conflict with the given `otherPMs`.
638   auto hasScheduleConflictWith = [&](OpPassManager &genericPM,
639                                      MutableArrayRef<OpPassManager> otherPMs) {
640     return llvm::any_of(otherPMs, [&](OpPassManager &pm) {
641       // If this is a non-generic pass manager, a conflict will arise if a
642       // non-generic pass manager's operation name can be scheduled on the
643       // generic passmanager.
644       if (std::optional<OperationName> pmOpName = pm.getOpName(*ctx))
645         return genericPM.getImpl().canScheduleOn(*ctx, *pmOpName);
646       // Otherwise, this is a generic pass manager. We current can't determine
647       // when generic pass managers can be merged, so conservatively assume they
648       // conflict.
649       return true;
650     });
651   };
652 
653   // Check that if either adaptor has a generic pass manager, that pm is
654   // compatible within any non-generic pass managers.
655   //
656   // Check the current adaptor.
657   auto *lhsGenericPMIt = llvm::find_if(mgrs, isGenericPM);
658   if (lhsGenericPMIt != mgrs.end() &&
659       hasScheduleConflictWith(*lhsGenericPMIt, rhs.mgrs))
660     return failure();
661   // Check the rhs adaptor.
662   auto *rhsGenericPMIt = llvm::find_if(rhs.mgrs, isGenericPM);
663   if (rhsGenericPMIt != rhs.mgrs.end() &&
664       hasScheduleConflictWith(*rhsGenericPMIt, mgrs))
665     return failure();
666 
667   for (auto &pm : mgrs) {
668     // If an existing pass manager exists, then merge the given pass manager
669     // into it.
670     if (auto *existingPM =
671             findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) {
672       pm.getImpl().mergeInto(existingPM->getImpl());
673     } else {
674       // Otherwise, add the given pass manager to the list.
675       rhs.mgrs.emplace_back(std::move(pm));
676     }
677   }
678   mgrs.clear();
679 
680   // After coalescing, sort the pass managers within rhs by name.
681   auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) {
682     // Order op-specific pass managers first and op-agnostic pass managers last.
683     if (std::optional<StringRef> lhsName = lhs->getOpName()) {
684       if (std::optional<StringRef> rhsName = rhs->getOpName())
685         return lhsName->compare(*rhsName);
686       return -1; // lhs(op-specific) < rhs(op-agnostic)
687     }
688     return 1; // lhs(op-agnostic) > rhs(op-specific)
689   };
690   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), compareFn);
691   return success();
692 }
693 
694 /// Returns the adaptor pass name.
695 std::string OpToOpPassAdaptor::getAdaptorName() {
696   std::string name = "Pipeline Collection : [";
697   llvm::raw_string_ostream os(name);
698   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
699     os << '\'' << pm.getOpAnchorName() << '\'';
700   });
701   os << ']';
702   return name;
703 }
704 
705 void OpToOpPassAdaptor::runOnOperation() {
706   llvm_unreachable(
707       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
708 }
709 
710 /// Run the held pipeline over all nested operations.
711 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
712   if (getContext().isMultithreadingEnabled())
713     runOnOperationAsyncImpl(verifyPasses);
714   else
715     runOnOperationImpl(verifyPasses);
716 }
717 
718 /// Run this pass adaptor synchronously.
719 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
720   auto am = getAnalysisManager();
721   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
722                                                         this};
723   auto *instrumentor = am.getPassInstrumentor();
724   for (auto &region : getOperation()->getRegions()) {
725     for (auto &block : region) {
726       for (auto &op : block) {
727         auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
728         if (!mgr)
729           continue;
730 
731         // Run the held pipeline over the current operation.
732         unsigned initGeneration = mgr->impl->initializationGeneration;
733         if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
734                                initGeneration, instrumentor, &parentInfo)))
735           signalPassFailure();
736       }
737     }
738   }
739 }
740 
741 /// Utility functor that checks if the two ranges of pass managers have a size
742 /// mismatch.
743 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
744                             ArrayRef<OpPassManager> rhs) {
745   return lhs.size() != rhs.size() ||
746          llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
747                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
748 }
749 
750 /// Run this pass adaptor synchronously.
751 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
752   AnalysisManager am = getAnalysisManager();
753   MLIRContext *context = &getContext();
754 
755   // Create the async executors if they haven't been created, or if the main
756   // pipeline has changed.
757   if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
758     asyncExecutors.assign(context->getThreadPool().getMaxConcurrency(), mgrs);
759 
760   // This struct represents the information for a single operation to be
761   // scheduled on a pass manager.
762   struct OpPMInfo {
763     OpPMInfo(unsigned passManagerIdx, Operation *op, AnalysisManager am)
764         : passManagerIdx(passManagerIdx), op(op), am(am) {}
765 
766     /// The index of the pass manager to schedule the operation on.
767     unsigned passManagerIdx;
768     /// The operation to schedule.
769     Operation *op;
770     /// The analysis manager for the operation.
771     AnalysisManager am;
772   };
773 
774   // Run a prepass over the operation to collect the nested operations to
775   // execute over. This ensures that an analysis manager exists for each
776   // operation, as well as providing a queue of operations to execute over.
777   std::vector<OpPMInfo> opInfos;
778   DenseMap<OperationName, std::optional<unsigned>> knownOpPMIdx;
779   for (auto &region : getOperation()->getRegions()) {
780     for (Operation &op : region.getOps()) {
781       // Get the pass manager index for this operation type.
782       auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt);
783       if (pmIdxIt.second) {
784         if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context))
785           pmIdxIt.first->second = std::distance(mgrs.begin(), mgr);
786       }
787 
788       // If this operation can be scheduled, add it to the list.
789       if (pmIdxIt.first->second)
790         opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
791     }
792   }
793 
794   // Get the current thread for this adaptor.
795   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
796                                                         this};
797   auto *instrumentor = am.getPassInstrumentor();
798 
799   // An atomic failure variable for the async executors.
800   std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
801   std::fill(activePMs.begin(), activePMs.end(), false);
802   std::atomic<bool> hasFailure = false;
803   parallelForEach(context, opInfos, [&](OpPMInfo &opInfo) {
804     // Find an executor for this operation.
805     auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
806       bool expectedInactive = false;
807       return isActive.compare_exchange_strong(expectedInactive, true);
808     });
809     unsigned pmIndex = it - activePMs.begin();
810 
811     // Get the pass manager for this operation and execute it.
812     OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx];
813     LogicalResult pipelineResult = runPipeline(
814         pm, opInfo.op, opInfo.am, verifyPasses,
815         pm.impl->initializationGeneration, instrumentor, &parentInfo);
816     if (failed(pipelineResult))
817       hasFailure.store(true);
818 
819     // Reset the active bit for this pass manager.
820     activePMs[pmIndex].store(false);
821   });
822 
823   // Signal a failure if any of the executors failed.
824   if (hasFailure)
825     signalPassFailure();
826 }
827 
828 //===----------------------------------------------------------------------===//
829 // PassManager
830 //===----------------------------------------------------------------------===//
831 
832 PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
833                          Nesting nesting)
834     : OpPassManager(operationName, nesting), context(ctx), passTiming(false),
835       verifyPasses(true) {}
836 
837 PassManager::PassManager(OperationName operationName, Nesting nesting)
838     : OpPassManager(operationName, nesting),
839       context(operationName.getContext()), passTiming(false),
840       verifyPasses(true) {}
841 
842 PassManager::~PassManager() = default;
843 
844 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
845 
846 /// Run the passes within this manager on the provided operation.
847 LogicalResult PassManager::run(Operation *op) {
848   MLIRContext *context = getContext();
849   std::optional<OperationName> anchorOp = getOpName(*context);
850   if (anchorOp && anchorOp != op->getName())
851     return emitError(op->getLoc())
852            << "can't run '" << getOpAnchorName() << "' pass manager on '"
853            << op->getName() << "' op";
854 
855   // Register all dialects for the current pipeline.
856   DialectRegistry dependentDialects;
857   getDependentDialects(dependentDialects);
858   context->appendDialectRegistry(dependentDialects);
859   for (StringRef name : dependentDialects.getDialectNames())
860     context->getOrLoadDialect(name);
861 
862   // Before running, make sure to finalize the pipeline pass list.
863   if (failed(getImpl().finalizePassList(context)))
864     return failure();
865 
866   // Notify the context that we start running a pipeline for bookkeeping.
867   context->enterMultiThreadedExecution();
868 
869   // Initialize all of the passes within the pass manager with a new generation.
870   llvm::hash_code newInitKey = context->getRegistryHash();
871   llvm::hash_code pipelineKey = hash();
872   if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
873     if (failed(initialize(context, impl->initializationGeneration + 1)))
874       return failure();
875     initializationKey = newInitKey;
876     pipelineKey = pipelineInitializationKey;
877   }
878 
879   // Construct a top level analysis manager for the pipeline.
880   ModuleAnalysisManager am(op, instrumentor.get());
881 
882   // If reproducer generation is enabled, run the pass manager with crash
883   // handling enabled.
884   LogicalResult result =
885       crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);
886 
887   // Notify the context that the run is done.
888   context->exitMultiThreadedExecution();
889 
890   // Dump all of the pass statistics if necessary.
891   if (passStatisticsMode)
892     dumpStatistics();
893   return result;
894 }
895 
896 /// Add the provided instrumentation to the pass manager.
897 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
898   if (!instrumentor)
899     instrumentor = std::make_unique<PassInstrumentor>();
900 
901   instrumentor->addInstrumentation(std::move(pi));
902 }
903 
904 LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
905   return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses,
906                                         impl->initializationGeneration);
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // AnalysisManager
911 //===----------------------------------------------------------------------===//
912 
913 /// Get an analysis manager for the given operation, which must be a proper
914 /// descendant of the current operation represented by this analysis manager.
915 AnalysisManager AnalysisManager::nest(Operation *op) {
916   Operation *currentOp = impl->getOperation();
917   assert(currentOp->isProperAncestor(op) &&
918          "expected valid descendant operation");
919 
920   // Check for the base case where the provided operation is immediately nested.
921   if (currentOp == op->getParentOp())
922     return nestImmediate(op);
923 
924   // Otherwise, we need to collect all ancestors up to the current operation.
925   SmallVector<Operation *, 4> opAncestors;
926   do {
927     opAncestors.push_back(op);
928     op = op->getParentOp();
929   } while (op != currentOp);
930 
931   AnalysisManager result = *this;
932   for (Operation *op : llvm::reverse(opAncestors))
933     result = result.nestImmediate(op);
934   return result;
935 }
936 
937 /// Get an analysis manager for the given immediately nested child operation.
938 AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
939   assert(impl->getOperation() == op->getParentOp() &&
940          "expected immediate child operation");
941 
942   auto [it, inserted] = impl->childAnalyses.try_emplace(op);
943   if (inserted)
944     it->second = std::make_unique<NestedAnalysisMap>(op, impl);
945   return {it->second.get()};
946 }
947 
948 /// Invalidate any non preserved analyses.
949 void detail::NestedAnalysisMap::invalidate(
950     const detail::PreservedAnalyses &pa) {
951   // If all analyses were preserved, then there is nothing to do here.
952   if (pa.isAll())
953     return;
954 
955   // Invalidate the analyses for the current operation directly.
956   analyses.invalidate(pa);
957 
958   // If no analyses were preserved, then just simply clear out the child
959   // analysis results.
960   if (pa.isNone()) {
961     childAnalyses.clear();
962     return;
963   }
964 
965   // Otherwise, invalidate each child analysis map.
966   SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
967   while (!mapsToInvalidate.empty()) {
968     auto *map = mapsToInvalidate.pop_back_val();
969     for (auto &analysisPair : map->childAnalyses) {
970       analysisPair.second->invalidate(pa);
971       if (!analysisPair.second->childAnalyses.empty())
972         mapsToInvalidate.push_back(analysisPair.second.get());
973     }
974   }
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // PassInstrumentation
979 //===----------------------------------------------------------------------===//
980 
981 PassInstrumentation::~PassInstrumentation() = default;
982 
983 void PassInstrumentation::runBeforePipeline(
984     std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
985 
986 void PassInstrumentation::runAfterPipeline(
987     std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
988 
989 //===----------------------------------------------------------------------===//
990 // PassInstrumentor
991 //===----------------------------------------------------------------------===//
992 
993 namespace mlir {
994 namespace detail {
995 struct PassInstrumentorImpl {
996   /// Mutex to keep instrumentation access thread-safe.
997   llvm::sys::SmartMutex<true> mutex;
998 
999   /// Set of registered instrumentations.
1000   std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
1001 };
1002 } // namespace detail
1003 } // namespace mlir
1004 
1005 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
1006 PassInstrumentor::~PassInstrumentor() = default;
1007 
1008 /// See PassInstrumentation::runBeforePipeline for details.
1009 void PassInstrumentor::runBeforePipeline(
1010     std::optional<OperationName> name,
1011     const PassInstrumentation::PipelineParentInfo &parentInfo) {
1012   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1013   for (auto &instr : impl->instrumentations)
1014     instr->runBeforePipeline(name, parentInfo);
1015 }
1016 
1017 /// See PassInstrumentation::runAfterPipeline for details.
1018 void PassInstrumentor::runAfterPipeline(
1019     std::optional<OperationName> name,
1020     const PassInstrumentation::PipelineParentInfo &parentInfo) {
1021   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1022   for (auto &instr : llvm::reverse(impl->instrumentations))
1023     instr->runAfterPipeline(name, parentInfo);
1024 }
1025 
1026 /// See PassInstrumentation::runBeforePass for details.
1027 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
1028   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1029   for (auto &instr : impl->instrumentations)
1030     instr->runBeforePass(pass, op);
1031 }
1032 
1033 /// See PassInstrumentation::runAfterPass for details.
1034 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
1035   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1036   for (auto &instr : llvm::reverse(impl->instrumentations))
1037     instr->runAfterPass(pass, op);
1038 }
1039 
1040 /// See PassInstrumentation::runAfterPassFailed for details.
1041 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
1042   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1043   for (auto &instr : llvm::reverse(impl->instrumentations))
1044     instr->runAfterPassFailed(pass, op);
1045 }
1046 
1047 /// See PassInstrumentation::runBeforeAnalysis for details.
1048 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
1049                                          Operation *op) {
1050   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1051   for (auto &instr : impl->instrumentations)
1052     instr->runBeforeAnalysis(name, id, op);
1053 }
1054 
1055 /// See PassInstrumentation::runAfterAnalysis for details.
1056 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
1057                                         Operation *op) {
1058   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1059   for (auto &instr : llvm::reverse(impl->instrumentations))
1060     instr->runAfterAnalysis(name, id, op);
1061 }
1062 
1063 /// Add the given instrumentation to the collection.
1064 void PassInstrumentor::addInstrumentation(
1065     std::unique_ptr<PassInstrumentation> pi) {
1066   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1067   impl->instrumentations.emplace_back(std::move(pi));
1068 }
1069