xref: /llvm-project/mlir/lib/Pass/PassRegistry.cpp (revision 1a70420ff3b972b3d9bbc1c4d1e98bfa12bfb73a)
1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassRegistry.h"
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Pass/PassManager.h"
13 #include "llvm/ADT/DenseMap.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/Format.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/MemoryBuffer.h"
19 #include "llvm/Support/SourceMgr.h"
20 
21 #include <optional>
22 #include <utility>
23 
24 using namespace mlir;
25 using namespace detail;
26 
/// Global registry of every pass registered through `registerPass`, keyed by
/// the pass argument (the name used in textual pipelines).
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// Maps each pass argument in `passRegistry` to the TypeID of the pass its
/// allocator creates. `registerPass` uses it to reject two different pass
/// classes registered under the same argument.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Global registry of every pass pipeline registered through
/// `registerPassPipeline`, keyed by the pipeline argument.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
37 
38 /// Utility to create a default registry function from a pass instance.
39 static PassRegistryFunction
40 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
41   return [=](OpPassManager &pm, StringRef options,
42              function_ref<LogicalResult(const Twine &)> errorHandler) {
43     std::unique_ptr<Pass> pass = allocator();
44     LogicalResult result = pass->initializeOptions(options, errorHandler);
45 
46     std::optional<StringRef> pmOpName = pm.getOpName();
47     std::optional<StringRef> passOpName = pass->getOpName();
48     if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName &&
49         passOpName && *pmOpName != *passOpName) {
50       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
51                           "' restricted to '" + *pass->getOpName() +
52                           "' on a PassManager intended to run on '" +
53                           pm.getOpAnchorName() + "', did you intend to nest?");
54     }
55     pm.addPass(std::move(pass));
56     return result;
57   };
58 }
59 
60 /// Utility to print the help string for a specific option.
61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
62                             size_t descIndent, bool isTopLevel) {
63   size_t numSpaces = descIndent - indent - 4;
64   llvm::outs().indent(indent)
65       << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // PassRegistry
70 //===----------------------------------------------------------------------===//
71 
72 /// Prints the passes that were previously registered and stored in passRegistry
73 void mlir::printRegisteredPasses() {
74   size_t maxWidth = 0;
75   for (auto &entry : *passRegistry)
76     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
77 
78   // Functor used to print the ordered entries of a registration map.
79   auto printOrderedEntries = [&](StringRef header, auto &map) {
80     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
81     for (auto &kv : map)
82       orderedEntries.push_back(&kv.second);
83     llvm::array_pod_sort(
84         orderedEntries.begin(), orderedEntries.end(),
85         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
86           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
87         });
88 
89     llvm::outs().indent(0) << header << ":\n";
90     for (PassRegistryEntry *entry : orderedEntries)
91       entry->printHelpStr(/*indent=*/2, maxWidth);
92   };
93 
94   // Print the available passes.
95   printOrderedEntries("Passes", *passRegistry);
96 }
97 
98 /// Print the help information for this pass. This includes the argument,
99 /// description, and any pass options. `descIndent` is the indent that the
100 /// descriptions should be aligned.
101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
102   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
103                   /*isTopLevel=*/true);
104   // If this entry has options, print the help for those as well.
105   optHandler([=](const PassOptions &options) {
106     options.printHelp(indent, descIndent);
107   });
108 }
109 
110 /// Return the maximum width required when printing the options of this
111 /// entry.
112 size_t PassRegistryEntry::getOptionWidth() const {
113   size_t maxLen = 0;
114   optHandler([&](const PassOptions &options) mutable {
115     maxLen = options.getOptionWidth() + 2;
116   });
117   return maxLen;
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // PassPipelineInfo
122 //===----------------------------------------------------------------------===//
123 
124 void mlir::registerPassPipeline(
125     StringRef arg, StringRef description, const PassRegistryFunction &function,
126     std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
127   PassPipelineInfo pipelineInfo(arg, description, function,
128                                 std::move(optHandler));
129   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
130 #ifndef NDEBUG
131   if (!inserted)
132     report_fatal_error("Pass pipeline " + arg + " registered multiple times");
133 #endif
134   (void)inserted;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // PassInfo
139 //===----------------------------------------------------------------------===//
140 
141 PassInfo::PassInfo(StringRef arg, StringRef description,
142                    const PassAllocatorFunction &allocator)
143     : PassRegistryEntry(
144           arg, description, buildDefaultRegistryFn(allocator),
145           // Use a temporary pass to provide an options instance.
146           [=](function_ref<void(const PassOptions &)> optHandler) {
147             optHandler(allocator()->passOptions);
148           }) {}
149 
150 void mlir::registerPass(const PassAllocatorFunction &function) {
151   std::unique_ptr<Pass> pass = function();
152   StringRef arg = pass->getArgument();
153   if (arg.empty())
154     llvm::report_fatal_error(llvm::Twine("Trying to register '") +
155                              pass->getName() +
156                              "' pass that does not override `getArgument()`");
157   StringRef description = pass->getDescription();
158   PassInfo passInfo(arg, description, function);
159   passRegistry->try_emplace(arg, passInfo);
160 
161   // Verify that the registered pass has the same ID as any registered to this
162   // arg before it.
163   TypeID entryTypeID = pass->getTypeID();
164   auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
165   if (it->second != entryTypeID)
166     llvm::report_fatal_error(
167         "pass allocator creates a different pass than previously "
168         "registered for pass " +
169         arg);
170 }
171 
172 /// Returns the pass info for the specified pass argument or null if unknown.
173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
174   auto it = passRegistry->find(passArg);
175   return it == passRegistry->end() ? nullptr : &it->second;
176 }
177 
178 /// Returns the pass pipeline info for the specified pass pipeline argument or
179 /// null if unknown.
180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
181   auto it = passPipelineRegistry->find(pipelineArg);
182   return it == passPipelineRegistry->end() ? nullptr : &it->second;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // PassOptions
187 //===----------------------------------------------------------------------===//
188 
189 /// Attempt to find the next occurance of character 'c' in the string starting
190 /// from the `index`-th position , omitting any occurances that appear within
191 /// intervening ranges or literals.
192 static size_t findChar(StringRef str, size_t index, char c) {
193   for (size_t i = index, e = str.size(); i < e; ++i) {
194     if (str[i] == c)
195       return i;
196     // Check for various range characters.
197     if (str[i] == '{')
198       i = findChar(str, i + 1, '}');
199     else if (str[i] == '(')
200       i = findChar(str, i + 1, ')');
201     else if (str[i] == '[')
202       i = findChar(str, i + 1, ']');
203     else if (str[i] == '\"')
204       i = str.find_first_of('\"', i + 1);
205     else if (str[i] == '\'')
206       i = str.find_first_of('\'', i + 1);
207     if (i == StringRef::npos)
208       return StringRef::npos;
209   }
210   return StringRef::npos;
211 }
212 
213 /// Extract an argument from 'options' and update it to point after the arg.
214 /// Returns the cleaned argument string.
215 static StringRef extractArgAndUpdateOptions(StringRef &options,
216                                             size_t argSize) {
217   StringRef str = options.take_front(argSize).trim();
218   options = options.drop_front(argSize).ltrim();
219 
220   // Early exit if there's no escape sequence.
221   if (str.size() <= 1)
222     return str;
223 
224   const auto escapePairs = {std::make_pair('\'', '\''),
225                             std::make_pair('"', '"')};
226   for (const auto &escape : escapePairs) {
227     if (str.front() == escape.first && str.back() == escape.second) {
228       // Drop the escape characters and trim.
229       // Don't process additional escape sequences.
230       return str.drop_front().drop_back().trim();
231     }
232   }
233 
234   // Arguments may be wrapped in `{...}`. Unlike the quotation markers that
235   // denote literals, we respect scoping here. The outer `{...}` should not
236   // be stripped in cases such as "arg={...},{...}", which can be used to denote
237   // lists of nested option structs.
238   if (str.front() == '{') {
239     unsigned match = findChar(str, 1, '}');
240     if (match == str.size() - 1)
241       str = str.drop_front().drop_back().trim();
242   }
243 
244   return str;
245 }
246 
247 LogicalResult detail::pass_options::parseCommaSeparatedList(
248     llvm::cl::Option &opt, StringRef argName, StringRef optionStr,
249     function_ref<LogicalResult(StringRef)> elementParseFn) {
250   if (optionStr.empty())
251     return success();
252 
253   size_t nextElePos = findChar(optionStr, 0, ',');
254   while (nextElePos != StringRef::npos) {
255     // Process the portion before the comma.
256     if (failed(
257             elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos))))
258       return failure();
259 
260     // Drop the leading ','
261     optionStr = optionStr.drop_front();
262     nextElePos = findChar(optionStr, 0, ',');
263   }
264   return elementParseFn(
265       extractArgAndUpdateOptions(optionStr, optionStr.size()));
266 }
267 
/// Out-of-line virtual function that gives `OptionBase` a home translation
/// unit (the LLVM "anchor" idiom).
void detail::PassOptions::OptionBase::anchor() {}
270 
271 /// Copy the option values from 'other'.
272 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
273   assert(options.size() == other.options.size());
274   if (options.empty())
275     return;
276   for (auto optionsIt : llvm::zip(options, other.options))
277     std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
278 }
279 
/// Parse the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, the updated
/// `options` string pointing after the parsed option].
///
/// An option is `key` or `key=value`, and options are separated by ' '. When
/// no '=' is present, the returned value is empty. Inside a value, text
/// between '...' or "..." is skipped verbatim, and a `{...}` group (with
/// nesting and nested quotes) is skipped as a unit, so spaces inside nested
/// pass options do not end the value. The key itself is not scanned for
/// quotes or braces.
static std::tuple<StringRef, StringRef, StringRef>
parseNextArg(StringRef options) {
  // Try to process the given punctuation, properly escaping any contained
  // characters.
  // If `options[currentPos]` is `punct`, move `currentPos` to the next
  // occurrence of `punct` (unchanged if there is none) and return true.
  // Otherwise return false without moving. `options` is captured by
  // reference, so the lambda sees the reassigned string in the value loop.
  auto tryProcessPunct = [&](size_t &currentPos, char punct) {
    if (options[currentPos] != punct)
      return false;
    size_t nextIt = options.find_first_of(punct, currentPos + 1);
    if (nextIt != StringRef::npos)
      currentPos = nextIt;
    return true;
  };

  // Parse the argument name of the option.
  StringRef argName;
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Check for the end of the full option.
    // A space or the end of input before any '=' means a key with no value.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, StringRef(), options);
    }

    // Check for the end of the name and the start of the value.
    if (options[argEndIt] == '=') {
      argName = extractArgAndUpdateOptions(options, argEndIt);
      // `options` now begins with the '=' separator; drop it.
      options = options.drop_front();
      break;
    }
  }

  // Parse the value of the option.
  for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) {
    // Handle the end of the options string.
    if (argEndIt == optionsE || options[argEndIt] == ' ') {
      StringRef value = extractArgAndUpdateOptions(options, argEndIt);
      return std::make_tuple(argName, value, options);
    }

    // Skip over escaped sequences.
    // After `continue`, the loop increment steps past the closing quote.
    char c = options[argEndIt];
    if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
      continue;
    // '{...}' is used to specify options to passes, properly escape it so
    // that we don't accidentally split any nested options.
    if (c == '{') {
      size_t braceCount = 1;
      for (++argEndIt; argEndIt != optionsE; ++argEndIt) {
        // Allow nested punctuation.
        if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"'))
          continue;
        if (options[argEndIt] == '{')
          ++braceCount;
        else if (options[argEndIt] == '}' && --braceCount == 0)
          break;
      }
      // Account for the increment at the top of the loop.
      // (The outer increment then lands back on the matching '}' or, if the
      // group is unterminated, on `optionsE`.)
      --argEndIt;
    }
  }
  // Both loops above exit only through `return` or `break`.
  llvm_unreachable("unexpected control flow in pass option parsing");
}
344 
345 LogicalResult detail::PassOptions::parseFromString(StringRef options,
346                                                    raw_ostream &errorStream) {
347   // NOTE: `options` is modified in place to always refer to the unprocessed
348   // part of the string.
349   while (!options.empty()) {
350     StringRef key, value;
351     std::tie(key, value, options) = parseNextArg(options);
352     if (key.empty())
353       continue;
354 
355     auto it = OptionsMap.find(key);
356     if (it == OptionsMap.end()) {
357       errorStream << "<Pass-Options-Parser>: no such option " << key << "\n";
358       return failure();
359     }
360     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
361       return failure();
362   }
363 
364   return success();
365 }
366 
367 /// Print the options held by this struct in a form that can be parsed via
368 /// 'parseFromString'.
369 void detail::PassOptions::print(raw_ostream &os) const {
370   // If there are no options, there is nothing left to do.
371   if (OptionsMap.empty())
372     return;
373 
374   // Sort the options to make the ordering deterministic.
375   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
376   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
377     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
378   };
379   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
380 
381   // Interleave the options with ' '.
382   os << '{';
383   llvm::interleave(
384       orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
385   os << '}';
386 }
387 
388 /// Print the help string for the options held by this struct. `descIndent` is
389 /// the indent within the stream that the descriptions should be aligned.
390 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
391   // Sort the options to make the ordering deterministic.
392   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
393   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
394     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
395   };
396   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
397   for (OptionBase *option : orderedOps) {
398     // TODO: printOptionInfo assumes a specific indent and will
399     // print options with values with incorrect indentation. We should add
400     // support to llvm::cl::Option for passing in a base indent to use when
401     // printing.
402     llvm::outs().indent(indent);
403     option->getOption()->printOptionInfo(descIndent - indent);
404   }
405 }
406 
407 /// Return the maximum width required when printing the help string.
408 size_t detail::PassOptions::getOptionWidth() const {
409   size_t max = 0;
410   for (auto *option : options)
411     max = std::max(max, option->getOption()->getOptionWidth());
412   return max;
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // MLIR Options
417 //===----------------------------------------------------------------------===//
418 
419 //===----------------------------------------------------------------------===//
420 // OpPassManager: OptionValue
421 
422 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default;
423 llvm::cl::OptionValue<OpPassManager>::OptionValue(
424     const mlir::OpPassManager &value) {
425   setValue(value);
426 }
427 llvm::cl::OptionValue<OpPassManager>::OptionValue(
428     const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) {
429   if (rhs.hasValue())
430     setValue(rhs.getValue());
431 }
432 llvm::cl::OptionValue<OpPassManager> &
433 llvm::cl::OptionValue<OpPassManager>::operator=(
434     const mlir::OpPassManager &rhs) {
435   setValue(rhs);
436   return *this;
437 }
438 
439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default;
440 
441 void llvm::cl::OptionValue<OpPassManager>::setValue(
442     const OpPassManager &newValue) {
443   if (hasValue())
444     *value = newValue;
445   else
446     value = std::make_unique<mlir::OpPassManager>(newValue);
447 }
448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) {
449   FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr);
450   assert(succeeded(pipeline) && "invalid pass pipeline");
451   setValue(*pipeline);
452 }
453 
454 bool llvm::cl::OptionValue<OpPassManager>::compare(
455     const mlir::OpPassManager &rhs) const {
456   std::string lhsStr, rhsStr;
457   {
458     raw_string_ostream lhsStream(lhsStr);
459     value->printAsTextualPipeline(lhsStream);
460 
461     raw_string_ostream rhsStream(rhsStr);
462     rhs.printAsTextualPipeline(rhsStream);
463   }
464 
465   // Use the textual format for pipeline comparisons.
466   return lhsStr == rhsStr;
467 }
468 
// Out-of-line anchor method that gives this OptionValue specialization a home
// translation unit (the LLVM "anchor" idiom).
void llvm::cl::OptionValue<OpPassManager>::anchor() {}
470 
471 //===----------------------------------------------------------------------===//
472 // OpPassManager: Parser
473 
// Explicitly instantiate the command-line base parser for OpPassManager in
// this translation unit.
namespace llvm {
namespace cl {
template class basic_parser<OpPassManager>;
} // namespace cl
} // namespace llvm
479 
480 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg,
481                                             ParsedPassManager &value) {
482   FailureOr<OpPassManager> pipeline = parsePassPipeline(arg);
483   if (failed(pipeline))
484     return true;
485   value.value = std::make_unique<OpPassManager>(std::move(*pipeline));
486   return false;
487 }
488 
/// Print `value` to `os` in its textual pipeline form.
void llvm::cl::parser<OpPassManager>::print(raw_ostream &os,
                                            const OpPassManager &value) {
  value.printAsTextualPipeline(os);
}
493 
494 void llvm::cl::parser<OpPassManager>::printOptionDiff(
495     const Option &opt, OpPassManager &pm, const OptVal &defaultValue,
496     size_t globalWidth) const {
497   printOptionName(opt, globalWidth);
498   outs() << "= ";
499   pm.printAsTextualPipeline(outs());
500 
501   if (defaultValue.hasValue()) {
502     outs().indent(2) << " (default: ";
503     defaultValue.getValue().printAsTextualPipeline(outs());
504     outs() << ")";
505   }
506   outs() << "\n";
507 }
508 
// Out-of-line anchor method that gives this parser specialization a home
// translation unit (the LLVM "anchor" idiom).
void llvm::cl::parser<OpPassManager>::anchor() {}
510 
// ParsedPassManager's constructors and destructor are defaulted out of line,
// presumably so they are instantiated where OpPassManager is a complete type
// (it is held through a std::unique_ptr) -- TODO confirm against the header.
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() =
    default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager(
    ParsedPassManager &&) = default;
llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() =
    default;
517 
518 //===----------------------------------------------------------------------===//
519 // TextualPassPipeline Parser
520 //===----------------------------------------------------------------------===//
521 
522 namespace {
523 /// This class represents a textual description of a pass pipeline.
524 class TextualPipeline {
525 public:
526   /// Try to initialize this pipeline with the given pipeline text.
527   /// `errorStream` is the output stream to emit errors to.
528   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
529 
530   /// Add the internal pipeline elements to the provided pass manager.
531   LogicalResult
532   addToPipeline(OpPassManager &pm,
533                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
534 
535 private:
536   /// A functor used to emit errors found during pipeline handling. The first
537   /// parameter corresponds to the raw location within the pipeline string. This
538   /// should always return failure.
539   using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
540 
541   /// A struct to capture parsed pass pipeline names.
542   ///
543   /// A pipeline is defined as a series of names, each of which may in itself
544   /// recursively contain a nested pipeline. A name is either the name of a pass
545   /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If
546   /// the name is the name of a pass, the InnerPipeline is empty, since passes
547   /// cannot contain inner pipelines.
548   struct PipelineElement {
549     PipelineElement(StringRef name) : name(name) {}
550 
551     StringRef name;
552     StringRef options;
553     const PassRegistryEntry *registryEntry = nullptr;
554     std::vector<PipelineElement> innerPipeline;
555   };
556 
557   /// Parse the given pipeline text into the internal pipeline vector. This
558   /// function only parses the structure of the pipeline, and does not resolve
559   /// its elements.
560   LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
561 
562   /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
563   /// the corresponding registry entry.
564   LogicalResult
565   resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
566                           ErrorHandlerT errorHandler);
567 
568   /// Resolve a single element of the pipeline.
569   LogicalResult resolvePipelineElement(PipelineElement &element,
570                                        ErrorHandlerT errorHandler);
571 
572   /// Add the given pipeline elements to the provided pass manager.
573   LogicalResult
574   addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
575                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
576 
577   std::vector<PipelineElement> pipeline;
578 };
579 
580 } // namespace
581 
582 /// Try to initialize this pipeline with the given pipeline text. An option is
583 /// given to enable accurate error reporting.
584 LogicalResult TextualPipeline::initialize(StringRef text,
585                                           raw_ostream &errorStream) {
586   if (text.empty())
587     return success();
588 
589   // Build a source manager to use for error reporting.
590   llvm::SourceMgr pipelineMgr;
591   pipelineMgr.AddNewSourceBuffer(
592       llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
593                                        /*RequiresNullTerminator=*/false),
594       SMLoc());
595   auto errorHandler = [&](const char *rawLoc, Twine msg) {
596     pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc),
597                              llvm::SourceMgr::DK_Error, msg);
598     return failure();
599   };
600 
601   // Parse the provided pipeline string.
602   if (failed(parsePipelineText(text, errorHandler)))
603     return failure();
604   return resolvePipelineElements(pipeline, errorHandler);
605 }
606 
607 /// Add the internal pipeline elements to the provided pass manager.
608 LogicalResult TextualPipeline::addToPipeline(
609     OpPassManager &pm,
610     function_ref<LogicalResult(const Twine &)> errorHandler) const {
611   // Temporarily disable implicit nesting while we append to the pipeline. We
612   // want the created pipeline to exactly match the parsed text pipeline, so
613   // it's preferrable to just error out if implicit nesting would be required.
614   OpPassManager::Nesting nesting = pm.getNesting();
615   pm.setNesting(OpPassManager::Nesting::Explicit);
616   auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); });
617 
618   return addToPipeline(pipeline, pm, errorHandler);
619 }
620 
621 /// Parse the given pipeline text into the internal pipeline vector. This
622 /// function only parses the structure of the pipeline, and does not resolve
623 /// its elements.
624 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
625                                                  ErrorHandlerT errorHandler) {
626   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
627   for (;;) {
628     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
629     size_t pos = text.find_first_of(",(){");
630     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
631 
632     // If we have a single terminating name, we're done.
633     if (pos == StringRef::npos)
634       break;
635 
636     text = text.substr(pos);
637     char sep = text[0];
638 
639     // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
640     if (sep == '{') {
641       text = text.substr(1);
642 
643       // Skip over everything until the closing '}' and store as options.
644       size_t close = StringRef::npos;
645       for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
646         if (text[i] == '{') {
647           ++braceCount;
648           continue;
649         }
650         if (text[i] == '}' && --braceCount == 0) {
651           close = i;
652           break;
653         }
654       }
655 
656       // Check to see if a closing options brace was found.
657       if (close == StringRef::npos) {
658         return errorHandler(
659             /*rawLoc=*/text.data() - 1,
660             "missing closing '}' while processing pass options");
661       }
662       pipeline.back().options = text.substr(0, close);
663       text = text.substr(close + 1);
664 
665       // Consume space characters that an user might add for readability.
666       text = text.ltrim();
667 
668       // Skip checking for '(' because nested pipelines cannot have options.
669     } else if (sep == '(') {
670       text = text.substr(1);
671 
672       // Push the inner pipeline onto the stack to continue processing.
673       pipelineStack.push_back(&pipeline.back().innerPipeline);
674       continue;
675     }
676 
677     // When handling the close parenthesis, we greedily consume them to avoid
678     // empty strings in the pipeline.
679     while (text.consume_front(")")) {
680       // If we try to pop the outer pipeline we have unbalanced parentheses.
681       if (pipelineStack.size() == 1)
682         return errorHandler(/*rawLoc=*/text.data() - 1,
683                             "encountered extra closing ')' creating unbalanced "
684                             "parentheses while parsing pipeline");
685 
686       pipelineStack.pop_back();
687       // Consume space characters that an user might add for readability.
688       text = text.ltrim();
689     }
690 
691     // Check if we've finished parsing.
692     if (text.empty())
693       break;
694 
695     // Otherwise, the end of an inner pipeline always has to be followed by
696     // a comma, and then we can continue.
697     if (!text.consume_front(","))
698       return errorHandler(text.data(), "expected ',' after parsing pipeline");
699   }
700 
701   // Check for unbalanced parentheses.
702   if (pipelineStack.size() > 1)
703     return errorHandler(
704         text.data(),
705         "encountered unbalanced parentheses while parsing pipeline");
706 
707   assert(pipelineStack.back() == &pipeline &&
708          "wrong pipeline at the bottom of the stack");
709   return success();
710 }
711 
712 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
713 /// the corresponding registry entry.
714 LogicalResult TextualPipeline::resolvePipelineElements(
715     MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
716   for (auto &elt : elements)
717     if (failed(resolvePipelineElement(elt, errorHandler)))
718       return failure();
719   return success();
720 }
721 
722 /// Resolve a single element of the pipeline.
723 LogicalResult
724 TextualPipeline::resolvePipelineElement(PipelineElement &element,
725                                         ErrorHandlerT errorHandler) {
726   // If the inner pipeline of this element is not empty, this is an operation
727   // pipeline.
728   if (!element.innerPipeline.empty())
729     return resolvePipelineElements(element.innerPipeline, errorHandler);
730 
731   // Otherwise, this must be a pass or pass pipeline.
732   // Check to see if a pipeline was registered with this name.
733   if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
734     return success();
735 
736   // If not, then this must be a specific pass name.
737   if ((element.registryEntry = PassInfo::lookup(element.name)))
738     return success();
739 
740   // Emit an error for the unknown pass.
741   auto *rawLoc = element.name.data();
742   return errorHandler(rawLoc, "'" + element.name +
743                                   "' does not refer to a "
744                                   "registered pass or pass pipeline");
745 }
746 
747 /// Add the given pipeline elements to the provided pass manager.
748 LogicalResult TextualPipeline::addToPipeline(
749     ArrayRef<PipelineElement> elements, OpPassManager &pm,
750     function_ref<LogicalResult(const Twine &)> errorHandler) const {
751   for (auto &elt : elements) {
752     if (elt.registryEntry) {
753       if (failed(elt.registryEntry->addToPipeline(pm, elt.options,
754                                                   errorHandler))) {
755         return errorHandler("failed to add `" + elt.name + "` with options `" +
756                             elt.options + "`");
757       }
758     } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
759                                     errorHandler))) {
760       return errorHandler("failed to add `" + elt.name + "` with options `" +
761                           elt.options + "` to inner pipeline");
762     }
763   }
764   return success();
765 }
766 
767 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
768                                       raw_ostream &errorStream) {
769   TextualPipeline pipelineParser;
770   if (failed(pipelineParser.initialize(pipeline, errorStream)))
771     return failure();
772   auto errorHandler = [&](Twine msg) {
773     errorStream << msg << "\n";
774     return failure();
775   };
776   if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
777     return failure();
778   return success();
779 }
780 
781 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
782                                                  raw_ostream &errorStream) {
783   pipeline = pipeline.trim();
784   // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
785   size_t pipelineStart = pipeline.find_first_of('(');
786   if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
787       !pipeline.consume_back(")")) {
788     errorStream << "expected pass pipeline to be wrapped with the anchor "
789                    "operation type, e.g. 'builtin.module(...)'";
790     return failure();
791   }
792 
793   StringRef opName = pipeline.take_front(pipelineStart).rtrim();
794   OpPassManager pm(opName);
795   if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
796                                errorStream)))
797     return failure();
798   return pm;
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // PassNameParser
803 //===----------------------------------------------------------------------===//
804 
805 namespace {
806 /// This struct represents the possible data entries in a parsed pass pipeline
807 /// list.
808 struct PassArgData {
809   PassArgData() = default;
810   PassArgData(const PassRegistryEntry *registryEntry)
811       : registryEntry(registryEntry) {}
812 
813   /// This field is used when the parsed option corresponds to a registered pass
814   /// or pass pipeline.
815   const PassRegistryEntry *registryEntry{nullptr};
816 
817   /// This field is set when instance specific pass options have been provided
818   /// on the command line.
819   StringRef options;
820 };
821 } // namespace
822 
823 namespace llvm {
824 namespace cl {
825 /// Define a valid OptionValue for the command line pass argument.
826 template <>
827 struct OptionValue<PassArgData> final
828     : OptionValueBase<PassArgData, /*isClass=*/true> {
829   OptionValue(const PassArgData &value) { this->setValue(value); }
830   OptionValue() = default;
831   void anchor() override {}
832 
833   bool hasValue() const { return true; }
834   const PassArgData &getValue() const { return value; }
835   void setValue(const PassArgData &value) { this->value = value; }
836 
837   PassArgData value;
838 };
839 } // namespace cl
840 } // namespace llvm
841 
842 namespace {
843 
844 /// The name for the command line option used for parsing the textual pass
845 /// pipeline.
846 #define PASS_PIPELINE_ARG "pass-pipeline"
847 
848 /// Adds command line option for each registered pass or pass pipeline, as well
849 /// as textual pass pipelines.
850 struct PassNameParser : public llvm::cl::parser<PassArgData> {
851   PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
852 
853   void initialize();
854   void printOptionInfo(const llvm::cl::Option &opt,
855                        size_t globalWidth) const override;
856   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
857   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
858              PassArgData &value);
859 
860   /// If true, this parser only parses entries that correspond to a concrete
861   /// pass registry entry, and does not include pipeline entries or the options
862   /// for pass entries.
863   bool passNamesOnly = false;
864 };
865 } // namespace
866 
867 void PassNameParser::initialize() {
868   llvm::cl::parser<PassArgData>::initialize();
869 
870   /// Add the pass entries.
871   for (const auto &kv : *passRegistry) {
872     addLiteralOption(kv.second.getPassArgument(), &kv.second,
873                      kv.second.getPassDescription());
874   }
875   /// Add the pass pipeline entries.
876   if (!passNamesOnly) {
877     for (const auto &kv : *passPipelineRegistry) {
878       addLiteralOption(kv.second.getPassArgument(), &kv.second,
879                        kv.second.getPassDescription());
880     }
881   }
882 }
883 
884 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
885                                      size_t globalWidth) const {
886   // If this parser is just parsing pass names, print a simplified option
887   // string.
888   if (passNamesOnly) {
889     llvm::outs() << "  --" << opt.ArgStr << "=<pass-arg>";
890     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18);
891     return;
892   }
893 
894   // Print the information for the top-level option.
895   if (opt.hasArgStr()) {
896     llvm::outs() << "  --" << opt.ArgStr;
897     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
898   } else {
899     llvm::outs() << "  " << opt.HelpStr << '\n';
900   }
901 
902   // Functor used to print the ordered entries of a registration map.
903   auto printOrderedEntries = [&](StringRef header, auto &map) {
904     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
905     for (auto &kv : map)
906       orderedEntries.push_back(&kv.second);
907     llvm::array_pod_sort(
908         orderedEntries.begin(), orderedEntries.end(),
909         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
910           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
911         });
912 
913     llvm::outs().indent(4) << header << ":\n";
914     for (PassRegistryEntry *entry : orderedEntries)
915       entry->printHelpStr(/*indent=*/6, globalWidth);
916   };
917 
918   // Print the available passes.
919   printOrderedEntries("Passes", *passRegistry);
920 
921   // Print the available pass pipelines.
922   if (!passPipelineRegistry->empty())
923     printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
924 }
925 
926 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
927   size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
928 
929   // Check for any wider pass or pipeline options.
930   for (auto &entry : *passRegistry)
931     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
932   for (auto &entry : *passPipelineRegistry)
933     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
934   return maxWidth;
935 }
936 
937 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
938                            StringRef arg, PassArgData &value) {
939   if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
940     return true;
941   value.options = arg;
942   return false;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // PassPipelineCLParser
947 //===----------------------------------------------------------------------===//
948 
949 namespace mlir {
950 namespace detail {
951 struct PassPipelineCLParserImpl {
952   PassPipelineCLParserImpl(StringRef arg, StringRef description,
953                            bool passNamesOnly)
954       : passList(arg, llvm::cl::desc(description)) {
955     passList.getParser().passNamesOnly = passNamesOnly;
956     passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
957   }
958 
959   /// Returns true if the given pass registry entry was registered at the
960   /// top-level of the parser, i.e. not within an explicit textual pipeline.
961   bool contains(const PassRegistryEntry *entry) const {
962     return llvm::any_of(passList, [&](const PassArgData &data) {
963       return data.registryEntry == entry;
964     });
965   }
966 
967   /// The set of passes and pass pipelines to run.
968   llvm::cl::list<PassArgData, bool, PassNameParser> passList;
969 };
970 } // namespace detail
971 } // namespace mlir
972 
973 /// Construct a pass pipeline parser with the given command line description.
974 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
975     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
976           arg, description, /*passNamesOnly=*/false)),
977       passPipeline(
978           PASS_PIPELINE_ARG,
979           llvm::cl::desc("Textual description of the pass pipeline to run")) {}
980 
981 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description,
982                                            StringRef alias)
983     : PassPipelineCLParser(arg, description) {
984   passPipelineAlias.emplace(alias,
985                             llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG),
986                             llvm::cl::aliasopt(passPipeline));
987 }
988 
989 PassPipelineCLParser::~PassPipelineCLParser() = default;
990 
991 /// Returns true if this parser contains any valid options to add.
992 bool PassPipelineCLParser::hasAnyOccurrences() const {
993   return passPipeline.getNumOccurrences() != 0 ||
994          impl->passList.getNumOccurrences() != 0;
995 }
996 
997 /// Returns true if the given pass registry entry was registered at the
998 /// top-level of the parser, i.e. not within an explicit textual pipeline.
999 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
1000   return impl->contains(entry);
1001 }
1002 
1003 /// Adds the passes defined by this parser entry to the given pass manager.
1004 LogicalResult PassPipelineCLParser::addToPipeline(
1005     OpPassManager &pm,
1006     function_ref<LogicalResult(const Twine &)> errorHandler) const {
1007   if (passPipeline.getNumOccurrences()) {
1008     if (impl->passList.getNumOccurrences())
1009       return errorHandler(
1010           "'-" PASS_PIPELINE_ARG
1011           "' option can't be used with individual pass options");
1012     std::string errMsg;
1013     llvm::raw_string_ostream os(errMsg);
1014     FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os);
1015     if (failed(parsed))
1016       return errorHandler(errMsg);
1017     pm = std::move(*parsed);
1018     return success();
1019   }
1020 
1021   for (auto &passIt : impl->passList) {
1022     if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
1023                                                    errorHandler)))
1024       return failure();
1025   }
1026   return success();
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // PassNameCLParser
1031 
1032 /// Construct a pass pipeline parser with the given command line description.
1033 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description)
1034     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(
1035           arg, description, /*passNamesOnly=*/true)) {
1036   impl->passList.setMiscFlag(llvm::cl::CommaSeparated);
1037 }
1038 PassNameCLParser::~PassNameCLParser() = default;
1039 
1040 /// Returns true if this parser contains any valid options to add.
1041 bool PassNameCLParser::hasAnyOccurrences() const {
1042   return impl->passList.getNumOccurrences() != 0;
1043 }
1044 
1045 /// Returns true if the given pass registry entry was registered at the
1046 /// top-level of the parser, i.e. not within an explicit textual pipeline.
1047 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const {
1048   return impl->contains(entry);
1049 }
1050