//===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/PassRegistry.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

#include <optional>
#include <utility>

using namespace mlir;
using namespace detail;

/// Static mapping of all of the registered passes, keyed by the pass argument
/// (the string returned by `Pass::getArgument()` when the pass is registered).
/// Wrapped in llvm::ManagedStatic so the map is created on first use.
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;

/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate. `registerPass` uses this to reject an
/// allocator that produces a different pass type under an argument that is
/// already registered.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline argument given to `registerPassPipeline`.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;

/// Utility to create a default registry function from a pass allocator. Each
/// invocation of the returned function allocates a new pass, initializes its
/// options from the given string, and appends it to the given pass manager.
39 static PassRegistryFunction 40 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) { 41 return [=](OpPassManager &pm, StringRef options, 42 function_ref<LogicalResult(const Twine &)> errorHandler) { 43 std::unique_ptr<Pass> pass = allocator(); 44 LogicalResult result = pass->initializeOptions(options, errorHandler); 45 46 std::optional<StringRef> pmOpName = pm.getOpName(); 47 std::optional<StringRef> passOpName = pass->getOpName(); 48 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) && pmOpName && 49 passOpName && *pmOpName != *passOpName) { 50 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() + 51 "' restricted to '" + *pass->getOpName() + 52 "' on a PassManager intended to run on '" + 53 pm.getOpAnchorName() + "', did you intend to nest?"); 54 } 55 pm.addPass(std::move(pass)); 56 return result; 57 }; 58 } 59 60 /// Utility to print the help string for a specific option. 61 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent, 62 size_t descIndent, bool isTopLevel) { 63 size_t numSpaces = descIndent - indent - 4; 64 llvm::outs().indent(indent) 65 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n'; 66 } 67 68 //===----------------------------------------------------------------------===// 69 // PassRegistry 70 //===----------------------------------------------------------------------===// 71 72 /// Prints the passes that were previously registered and stored in passRegistry 73 void mlir::printRegisteredPasses() { 74 size_t maxWidth = 0; 75 for (auto &entry : *passRegistry) 76 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 77 78 // Functor used to print the ordered entries of a registration map. 
79 auto printOrderedEntries = [&](StringRef header, auto &map) { 80 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; 81 for (auto &kv : map) 82 orderedEntries.push_back(&kv.second); 83 llvm::array_pod_sort( 84 orderedEntries.begin(), orderedEntries.end(), 85 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { 86 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument()); 87 }); 88 89 llvm::outs().indent(0) << header << ":\n"; 90 for (PassRegistryEntry *entry : orderedEntries) 91 entry->printHelpStr(/*indent=*/2, maxWidth); 92 }; 93 94 // Print the available passes. 95 printOrderedEntries("Passes", *passRegistry); 96 } 97 98 /// Print the help information for this pass. This includes the argument, 99 /// description, and any pass options. `descIndent` is the indent that the 100 /// descriptions should be aligned. 101 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const { 102 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent, 103 /*isTopLevel=*/true); 104 // If this entry has options, print the help for those as well. 105 optHandler([=](const PassOptions &options) { 106 options.printHelp(indent, descIndent); 107 }); 108 } 109 110 /// Return the maximum width required when printing the options of this 111 /// entry. 
112 size_t PassRegistryEntry::getOptionWidth() const { 113 size_t maxLen = 0; 114 optHandler([&](const PassOptions &options) mutable { 115 maxLen = options.getOptionWidth() + 2; 116 }); 117 return maxLen; 118 } 119 120 //===----------------------------------------------------------------------===// 121 // PassPipelineInfo 122 //===----------------------------------------------------------------------===// 123 124 void mlir::registerPassPipeline( 125 StringRef arg, StringRef description, const PassRegistryFunction &function, 126 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) { 127 PassPipelineInfo pipelineInfo(arg, description, function, 128 std::move(optHandler)); 129 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second; 130 #ifndef NDEBUG 131 if (!inserted) 132 report_fatal_error("Pass pipeline " + arg + " registered multiple times"); 133 #endif 134 (void)inserted; 135 } 136 137 //===----------------------------------------------------------------------===// 138 // PassInfo 139 //===----------------------------------------------------------------------===// 140 141 PassInfo::PassInfo(StringRef arg, StringRef description, 142 const PassAllocatorFunction &allocator) 143 : PassRegistryEntry( 144 arg, description, buildDefaultRegistryFn(allocator), 145 // Use a temporary pass to provide an options instance. 
146 [=](function_ref<void(const PassOptions &)> optHandler) { 147 optHandler(allocator()->passOptions); 148 }) {} 149 150 void mlir::registerPass(const PassAllocatorFunction &function) { 151 std::unique_ptr<Pass> pass = function(); 152 StringRef arg = pass->getArgument(); 153 if (arg.empty()) 154 llvm::report_fatal_error(llvm::Twine("Trying to register '") + 155 pass->getName() + 156 "' pass that does not override `getArgument()`"); 157 StringRef description = pass->getDescription(); 158 PassInfo passInfo(arg, description, function); 159 passRegistry->try_emplace(arg, passInfo); 160 161 // Verify that the registered pass has the same ID as any registered to this 162 // arg before it. 163 TypeID entryTypeID = pass->getTypeID(); 164 auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first; 165 if (it->second != entryTypeID) 166 llvm::report_fatal_error( 167 "pass allocator creates a different pass than previously " 168 "registered for pass " + 169 arg); 170 } 171 172 /// Returns the pass info for the specified pass argument or null if unknown. 173 const PassInfo *mlir::PassInfo::lookup(StringRef passArg) { 174 auto it = passRegistry->find(passArg); 175 return it == passRegistry->end() ? nullptr : &it->second; 176 } 177 178 /// Returns the pass pipeline info for the specified pass pipeline argument or 179 /// null if unknown. 180 const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) { 181 auto it = passPipelineRegistry->find(pipelineArg); 182 return it == passPipelineRegistry->end() ? nullptr : &it->second; 183 } 184 185 //===----------------------------------------------------------------------===// 186 // PassOptions 187 //===----------------------------------------------------------------------===// 188 189 /// Attempt to find the next occurance of character 'c' in the string starting 190 /// from the `index`-th position , omitting any occurances that appear within 191 /// intervening ranges or literals. 
192 static size_t findChar(StringRef str, size_t index, char c) { 193 for (size_t i = index, e = str.size(); i < e; ++i) { 194 if (str[i] == c) 195 return i; 196 // Check for various range characters. 197 if (str[i] == '{') 198 i = findChar(str, i + 1, '}'); 199 else if (str[i] == '(') 200 i = findChar(str, i + 1, ')'); 201 else if (str[i] == '[') 202 i = findChar(str, i + 1, ']'); 203 else if (str[i] == '\"') 204 i = str.find_first_of('\"', i + 1); 205 else if (str[i] == '\'') 206 i = str.find_first_of('\'', i + 1); 207 if (i == StringRef::npos) 208 return StringRef::npos; 209 } 210 return StringRef::npos; 211 } 212 213 /// Extract an argument from 'options' and update it to point after the arg. 214 /// Returns the cleaned argument string. 215 static StringRef extractArgAndUpdateOptions(StringRef &options, 216 size_t argSize) { 217 StringRef str = options.take_front(argSize).trim(); 218 options = options.drop_front(argSize).ltrim(); 219 220 // Early exit if there's no escape sequence. 221 if (str.size() <= 1) 222 return str; 223 224 const auto escapePairs = {std::make_pair('\'', '\''), 225 std::make_pair('"', '"')}; 226 for (const auto &escape : escapePairs) { 227 if (str.front() == escape.first && str.back() == escape.second) { 228 // Drop the escape characters and trim. 229 // Don't process additional escape sequences. 230 return str.drop_front().drop_back().trim(); 231 } 232 } 233 234 // Arguments may be wrapped in `{...}`. Unlike the quotation markers that 235 // denote literals, we respect scoping here. The outer `{...}` should not 236 // be stripped in cases such as "arg={...},{...}", which can be used to denote 237 // lists of nested option structs. 
238 if (str.front() == '{') { 239 unsigned match = findChar(str, 1, '}'); 240 if (match == str.size() - 1) 241 str = str.drop_front().drop_back().trim(); 242 } 243 244 return str; 245 } 246 247 LogicalResult detail::pass_options::parseCommaSeparatedList( 248 llvm::cl::Option &opt, StringRef argName, StringRef optionStr, 249 function_ref<LogicalResult(StringRef)> elementParseFn) { 250 if (optionStr.empty()) 251 return success(); 252 253 size_t nextElePos = findChar(optionStr, 0, ','); 254 while (nextElePos != StringRef::npos) { 255 // Process the portion before the comma. 256 if (failed( 257 elementParseFn(extractArgAndUpdateOptions(optionStr, nextElePos)))) 258 return failure(); 259 260 // Drop the leading ',' 261 optionStr = optionStr.drop_front(); 262 nextElePos = findChar(optionStr, 0, ','); 263 } 264 return elementParseFn( 265 extractArgAndUpdateOptions(optionStr, optionStr.size())); 266 } 267 268 /// Out of line virtual function to provide home for the class. 269 void detail::PassOptions::OptionBase::anchor() {} 270 271 /// Copy the option values from 'other'. 272 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) { 273 assert(options.size() == other.options.size()); 274 if (options.empty()) 275 return; 276 for (auto optionsIt : llvm::zip(options, other.options)) 277 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt)); 278 } 279 280 /// Parse in the next argument from the given options string. Returns a tuple 281 /// containing [the key of the option, the value of the option, updated 282 /// `options` string pointing after the parsed option]. 283 static std::tuple<StringRef, StringRef, StringRef> 284 parseNextArg(StringRef options) { 285 // Try to process the given punctuation, properly escaping any contained 286 // characters. 
287 auto tryProcessPunct = [&](size_t ¤tPos, char punct) { 288 if (options[currentPos] != punct) 289 return false; 290 size_t nextIt = options.find_first_of(punct, currentPos + 1); 291 if (nextIt != StringRef::npos) 292 currentPos = nextIt; 293 return true; 294 }; 295 296 // Parse the argument name of the option. 297 StringRef argName; 298 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { 299 // Check for the end of the full option. 300 if (argEndIt == optionsE || options[argEndIt] == ' ') { 301 argName = extractArgAndUpdateOptions(options, argEndIt); 302 return std::make_tuple(argName, StringRef(), options); 303 } 304 305 // Check for the end of the name and the start of the value. 306 if (options[argEndIt] == '=') { 307 argName = extractArgAndUpdateOptions(options, argEndIt); 308 options = options.drop_front(); 309 break; 310 } 311 } 312 313 // Parse the value of the option. 314 for (size_t argEndIt = 0, optionsE = options.size();; ++argEndIt) { 315 // Handle the end of the options string. 316 if (argEndIt == optionsE || options[argEndIt] == ' ') { 317 StringRef value = extractArgAndUpdateOptions(options, argEndIt); 318 return std::make_tuple(argName, value, options); 319 } 320 321 // Skip over escaped sequences. 322 char c = options[argEndIt]; 323 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) 324 continue; 325 // '{...}' is used to specify options to passes, properly escape it so 326 // that we don't accidentally split any nested options. 327 if (c == '{') { 328 size_t braceCount = 1; 329 for (++argEndIt; argEndIt != optionsE; ++argEndIt) { 330 // Allow nested punctuation. 331 if (tryProcessPunct(argEndIt, '\'') || tryProcessPunct(argEndIt, '"')) 332 continue; 333 if (options[argEndIt] == '{') 334 ++braceCount; 335 else if (options[argEndIt] == '}' && --braceCount == 0) 336 break; 337 } 338 // Account for the increment at the top of the loop. 
339 --argEndIt; 340 } 341 } 342 llvm_unreachable("unexpected control flow in pass option parsing"); 343 } 344 345 LogicalResult detail::PassOptions::parseFromString(StringRef options, 346 raw_ostream &errorStream) { 347 // NOTE: `options` is modified in place to always refer to the unprocessed 348 // part of the string. 349 while (!options.empty()) { 350 StringRef key, value; 351 std::tie(key, value, options) = parseNextArg(options); 352 if (key.empty()) 353 continue; 354 355 auto it = OptionsMap.find(key); 356 if (it == OptionsMap.end()) { 357 errorStream << "<Pass-Options-Parser>: no such option " << key << "\n"; 358 return failure(); 359 } 360 if (llvm::cl::ProvidePositionalOption(it->second, value, 0)) 361 return failure(); 362 } 363 364 return success(); 365 } 366 367 /// Print the options held by this struct in a form that can be parsed via 368 /// 'parseFromString'. 369 void detail::PassOptions::print(raw_ostream &os) const { 370 // If there are no options, there is nothing left to do. 371 if (OptionsMap.empty()) 372 return; 373 374 // Sort the options to make the ordering deterministic. 375 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 376 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 377 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 378 }; 379 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 380 381 // Interleave the options with ' '. 382 os << '{'; 383 llvm::interleave( 384 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " "); 385 os << '}'; 386 } 387 388 /// Print the help string for the options held by this struct. `descIndent` is 389 /// the indent within the stream that the descriptions should be aligned. 390 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const { 391 // Sort the options to make the ordering deterministic. 
392 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end()); 393 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) { 394 return (*lhs)->getArgStr().compare((*rhs)->getArgStr()); 395 }; 396 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs); 397 for (OptionBase *option : orderedOps) { 398 // TODO: printOptionInfo assumes a specific indent and will 399 // print options with values with incorrect indentation. We should add 400 // support to llvm::cl::Option for passing in a base indent to use when 401 // printing. 402 llvm::outs().indent(indent); 403 option->getOption()->printOptionInfo(descIndent - indent); 404 } 405 } 406 407 /// Return the maximum width required when printing the help string. 408 size_t detail::PassOptions::getOptionWidth() const { 409 size_t max = 0; 410 for (auto *option : options) 411 max = std::max(max, option->getOption()->getOptionWidth()); 412 return max; 413 } 414 415 //===----------------------------------------------------------------------===// 416 // MLIR Options 417 //===----------------------------------------------------------------------===// 418 419 //===----------------------------------------------------------------------===// 420 // OpPassManager: OptionValue 421 422 llvm::cl::OptionValue<OpPassManager>::OptionValue() = default; 423 llvm::cl::OptionValue<OpPassManager>::OptionValue( 424 const mlir::OpPassManager &value) { 425 setValue(value); 426 } 427 llvm::cl::OptionValue<OpPassManager>::OptionValue( 428 const llvm::cl::OptionValue<mlir::OpPassManager> &rhs) { 429 if (rhs.hasValue()) 430 setValue(rhs.getValue()); 431 } 432 llvm::cl::OptionValue<OpPassManager> & 433 llvm::cl::OptionValue<OpPassManager>::operator=( 434 const mlir::OpPassManager &rhs) { 435 setValue(rhs); 436 return *this; 437 } 438 439 llvm::cl::OptionValue<OpPassManager>::~OptionValue<OpPassManager>() = default; 440 441 void llvm::cl::OptionValue<OpPassManager>::setValue( 442 const 
OpPassManager &newValue) { 443 if (hasValue()) 444 *value = newValue; 445 else 446 value = std::make_unique<mlir::OpPassManager>(newValue); 447 } 448 void llvm::cl::OptionValue<OpPassManager>::setValue(StringRef pipelineStr) { 449 FailureOr<OpPassManager> pipeline = parsePassPipeline(pipelineStr); 450 assert(succeeded(pipeline) && "invalid pass pipeline"); 451 setValue(*pipeline); 452 } 453 454 bool llvm::cl::OptionValue<OpPassManager>::compare( 455 const mlir::OpPassManager &rhs) const { 456 std::string lhsStr, rhsStr; 457 { 458 raw_string_ostream lhsStream(lhsStr); 459 value->printAsTextualPipeline(lhsStream); 460 461 raw_string_ostream rhsStream(rhsStr); 462 rhs.printAsTextualPipeline(rhsStream); 463 } 464 465 // Use the textual format for pipeline comparisons. 466 return lhsStr == rhsStr; 467 } 468 469 void llvm::cl::OptionValue<OpPassManager>::anchor() {} 470 471 //===----------------------------------------------------------------------===// 472 // OpPassManager: Parser 473 474 namespace llvm { 475 namespace cl { 476 template class basic_parser<OpPassManager>; 477 } // namespace cl 478 } // namespace llvm 479 480 bool llvm::cl::parser<OpPassManager>::parse(Option &, StringRef, StringRef arg, 481 ParsedPassManager &value) { 482 FailureOr<OpPassManager> pipeline = parsePassPipeline(arg); 483 if (failed(pipeline)) 484 return true; 485 value.value = std::make_unique<OpPassManager>(std::move(*pipeline)); 486 return false; 487 } 488 489 void llvm::cl::parser<OpPassManager>::print(raw_ostream &os, 490 const OpPassManager &value) { 491 value.printAsTextualPipeline(os); 492 } 493 494 void llvm::cl::parser<OpPassManager>::printOptionDiff( 495 const Option &opt, OpPassManager &pm, const OptVal &defaultValue, 496 size_t globalWidth) const { 497 printOptionName(opt, globalWidth); 498 outs() << "= "; 499 pm.printAsTextualPipeline(outs()); 500 501 if (defaultValue.hasValue()) { 502 outs().indent(2) << " (default: "; 503 
defaultValue.getValue().printAsTextualPipeline(outs()); 504 outs() << ")"; 505 } 506 outs() << "\n"; 507 } 508 509 void llvm::cl::parser<OpPassManager>::anchor() {} 510 511 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager() = 512 default; 513 llvm::cl::parser<OpPassManager>::ParsedPassManager::ParsedPassManager( 514 ParsedPassManager &&) = default; 515 llvm::cl::parser<OpPassManager>::ParsedPassManager::~ParsedPassManager() = 516 default; 517 518 //===----------------------------------------------------------------------===// 519 // TextualPassPipeline Parser 520 //===----------------------------------------------------------------------===// 521 522 namespace { 523 /// This class represents a textual description of a pass pipeline. 524 class TextualPipeline { 525 public: 526 /// Try to initialize this pipeline with the given pipeline text. 527 /// `errorStream` is the output stream to emit errors to. 528 LogicalResult initialize(StringRef text, raw_ostream &errorStream); 529 530 /// Add the internal pipeline elements to the provided pass manager. 531 LogicalResult 532 addToPipeline(OpPassManager &pm, 533 function_ref<LogicalResult(const Twine &)> errorHandler) const; 534 535 private: 536 /// A functor used to emit errors found during pipeline handling. The first 537 /// parameter corresponds to the raw location within the pipeline string. This 538 /// should always return failure. 539 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>; 540 541 /// A struct to capture parsed pass pipeline names. 542 /// 543 /// A pipeline is defined as a series of names, each of which may in itself 544 /// recursively contain a nested pipeline. A name is either the name of a pass 545 /// (e.g. "cse") or the name of an operation type (e.g. "buitin.module"). If 546 /// the name is the name of a pass, the InnerPipeline is empty, since passes 547 /// cannot contain inner pipelines. 
548 struct PipelineElement { 549 PipelineElement(StringRef name) : name(name) {} 550 551 StringRef name; 552 StringRef options; 553 const PassRegistryEntry *registryEntry = nullptr; 554 std::vector<PipelineElement> innerPipeline; 555 }; 556 557 /// Parse the given pipeline text into the internal pipeline vector. This 558 /// function only parses the structure of the pipeline, and does not resolve 559 /// its elements. 560 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler); 561 562 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 563 /// the corresponding registry entry. 564 LogicalResult 565 resolvePipelineElements(MutableArrayRef<PipelineElement> elements, 566 ErrorHandlerT errorHandler); 567 568 /// Resolve a single element of the pipeline. 569 LogicalResult resolvePipelineElement(PipelineElement &element, 570 ErrorHandlerT errorHandler); 571 572 /// Add the given pipeline elements to the provided pass manager. 573 LogicalResult 574 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm, 575 function_ref<LogicalResult(const Twine &)> errorHandler) const; 576 577 std::vector<PipelineElement> pipeline; 578 }; 579 580 } // namespace 581 582 /// Try to initialize this pipeline with the given pipeline text. An option is 583 /// given to enable accurate error reporting. 584 LogicalResult TextualPipeline::initialize(StringRef text, 585 raw_ostream &errorStream) { 586 if (text.empty()) 587 return success(); 588 589 // Build a source manager to use for error reporting. 
590 llvm::SourceMgr pipelineMgr; 591 pipelineMgr.AddNewSourceBuffer( 592 llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser", 593 /*RequiresNullTerminator=*/false), 594 SMLoc()); 595 auto errorHandler = [&](const char *rawLoc, Twine msg) { 596 pipelineMgr.PrintMessage(errorStream, SMLoc::getFromPointer(rawLoc), 597 llvm::SourceMgr::DK_Error, msg); 598 return failure(); 599 }; 600 601 // Parse the provided pipeline string. 602 if (failed(parsePipelineText(text, errorHandler))) 603 return failure(); 604 return resolvePipelineElements(pipeline, errorHandler); 605 } 606 607 /// Add the internal pipeline elements to the provided pass manager. 608 LogicalResult TextualPipeline::addToPipeline( 609 OpPassManager &pm, 610 function_ref<LogicalResult(const Twine &)> errorHandler) const { 611 // Temporarily disable implicit nesting while we append to the pipeline. We 612 // want the created pipeline to exactly match the parsed text pipeline, so 613 // it's preferrable to just error out if implicit nesting would be required. 614 OpPassManager::Nesting nesting = pm.getNesting(); 615 pm.setNesting(OpPassManager::Nesting::Explicit); 616 auto restore = llvm::make_scope_exit([&]() { pm.setNesting(nesting); }); 617 618 return addToPipeline(pipeline, pm, errorHandler); 619 } 620 621 /// Parse the given pipeline text into the internal pipeline vector. This 622 /// function only parses the structure of the pipeline, and does not resolve 623 /// its elements. 624 LogicalResult TextualPipeline::parsePipelineText(StringRef text, 625 ErrorHandlerT errorHandler) { 626 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline}; 627 for (;;) { 628 std::vector<PipelineElement> &pipeline = *pipelineStack.back(); 629 size_t pos = text.find_first_of(",(){"); 630 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim()); 631 632 // If we have a single terminating name, we're done. 
633 if (pos == StringRef::npos) 634 break; 635 636 text = text.substr(pos); 637 char sep = text[0]; 638 639 // Handle pulling ... from 'pass{...}' out as PipelineElement.options. 640 if (sep == '{') { 641 text = text.substr(1); 642 643 // Skip over everything until the closing '}' and store as options. 644 size_t close = StringRef::npos; 645 for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) { 646 if (text[i] == '{') { 647 ++braceCount; 648 continue; 649 } 650 if (text[i] == '}' && --braceCount == 0) { 651 close = i; 652 break; 653 } 654 } 655 656 // Check to see if a closing options brace was found. 657 if (close == StringRef::npos) { 658 return errorHandler( 659 /*rawLoc=*/text.data() - 1, 660 "missing closing '}' while processing pass options"); 661 } 662 pipeline.back().options = text.substr(0, close); 663 text = text.substr(close + 1); 664 665 // Consume space characters that an user might add for readability. 666 text = text.ltrim(); 667 668 // Skip checking for '(' because nested pipelines cannot have options. 669 } else if (sep == '(') { 670 text = text.substr(1); 671 672 // Push the inner pipeline onto the stack to continue processing. 673 pipelineStack.push_back(&pipeline.back().innerPipeline); 674 continue; 675 } 676 677 // When handling the close parenthesis, we greedily consume them to avoid 678 // empty strings in the pipeline. 679 while (text.consume_front(")")) { 680 // If we try to pop the outer pipeline we have unbalanced parentheses. 681 if (pipelineStack.size() == 1) 682 return errorHandler(/*rawLoc=*/text.data() - 1, 683 "encountered extra closing ')' creating unbalanced " 684 "parentheses while parsing pipeline"); 685 686 pipelineStack.pop_back(); 687 // Consume space characters that an user might add for readability. 688 text = text.ltrim(); 689 } 690 691 // Check if we've finished parsing. 
692 if (text.empty()) 693 break; 694 695 // Otherwise, the end of an inner pipeline always has to be followed by 696 // a comma, and then we can continue. 697 if (!text.consume_front(",")) 698 return errorHandler(text.data(), "expected ',' after parsing pipeline"); 699 } 700 701 // Check for unbalanced parentheses. 702 if (pipelineStack.size() > 1) 703 return errorHandler( 704 text.data(), 705 "encountered unbalanced parentheses while parsing pipeline"); 706 707 assert(pipelineStack.back() == &pipeline && 708 "wrong pipeline at the bottom of the stack"); 709 return success(); 710 } 711 712 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to 713 /// the corresponding registry entry. 714 LogicalResult TextualPipeline::resolvePipelineElements( 715 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) { 716 for (auto &elt : elements) 717 if (failed(resolvePipelineElement(elt, errorHandler))) 718 return failure(); 719 return success(); 720 } 721 722 /// Resolve a single element of the pipeline. 723 LogicalResult 724 TextualPipeline::resolvePipelineElement(PipelineElement &element, 725 ErrorHandlerT errorHandler) { 726 // If the inner pipeline of this element is not empty, this is an operation 727 // pipeline. 728 if (!element.innerPipeline.empty()) 729 return resolvePipelineElements(element.innerPipeline, errorHandler); 730 731 // Otherwise, this must be a pass or pass pipeline. 732 // Check to see if a pipeline was registered with this name. 733 if ((element.registryEntry = PassPipelineInfo::lookup(element.name))) 734 return success(); 735 736 // If not, then this must be a specific pass name. 737 if ((element.registryEntry = PassInfo::lookup(element.name))) 738 return success(); 739 740 // Emit an error for the unknown pass. 
741 auto *rawLoc = element.name.data(); 742 return errorHandler(rawLoc, "'" + element.name + 743 "' does not refer to a " 744 "registered pass or pass pipeline"); 745 } 746 747 /// Add the given pipeline elements to the provided pass manager. 748 LogicalResult TextualPipeline::addToPipeline( 749 ArrayRef<PipelineElement> elements, OpPassManager &pm, 750 function_ref<LogicalResult(const Twine &)> errorHandler) const { 751 for (auto &elt : elements) { 752 if (elt.registryEntry) { 753 if (failed(elt.registryEntry->addToPipeline(pm, elt.options, 754 errorHandler))) { 755 return errorHandler("failed to add `" + elt.name + "` with options `" + 756 elt.options + "`"); 757 } 758 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name), 759 errorHandler))) { 760 return errorHandler("failed to add `" + elt.name + "` with options `" + 761 elt.options + "` to inner pipeline"); 762 } 763 } 764 return success(); 765 } 766 767 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm, 768 raw_ostream &errorStream) { 769 TextualPipeline pipelineParser; 770 if (failed(pipelineParser.initialize(pipeline, errorStream))) 771 return failure(); 772 auto errorHandler = [&](Twine msg) { 773 errorStream << msg << "\n"; 774 return failure(); 775 }; 776 if (failed(pipelineParser.addToPipeline(pm, errorHandler))) 777 return failure(); 778 return success(); 779 } 780 781 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline, 782 raw_ostream &errorStream) { 783 pipeline = pipeline.trim(); 784 // Pipelines are expected to be of the form `<op-name>(<pipeline>)`. 785 size_t pipelineStart = pipeline.find_first_of('('); 786 if (pipelineStart == 0 || pipelineStart == StringRef::npos || 787 !pipeline.consume_back(")")) { 788 errorStream << "expected pass pipeline to be wrapped with the anchor " 789 "operation type, e.g. 
'builtin.module(...)'"; 790 return failure(); 791 } 792 793 StringRef opName = pipeline.take_front(pipelineStart).rtrim(); 794 OpPassManager pm(opName); 795 if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm, 796 errorStream))) 797 return failure(); 798 return pm; 799 } 800 801 //===----------------------------------------------------------------------===// 802 // PassNameParser 803 //===----------------------------------------------------------------------===// 804 805 namespace { 806 /// This struct represents the possible data entries in a parsed pass pipeline 807 /// list. 808 struct PassArgData { 809 PassArgData() = default; 810 PassArgData(const PassRegistryEntry *registryEntry) 811 : registryEntry(registryEntry) {} 812 813 /// This field is used when the parsed option corresponds to a registered pass 814 /// or pass pipeline. 815 const PassRegistryEntry *registryEntry{nullptr}; 816 817 /// This field is set when instance specific pass options have been provided 818 /// on the command line. 819 StringRef options; 820 }; 821 } // namespace 822 823 namespace llvm { 824 namespace cl { 825 /// Define a valid OptionValue for the command line pass argument. 826 template <> 827 struct OptionValue<PassArgData> final 828 : OptionValueBase<PassArgData, /*isClass=*/true> { 829 OptionValue(const PassArgData &value) { this->setValue(value); } 830 OptionValue() = default; 831 void anchor() override {} 832 833 bool hasValue() const { return true; } 834 const PassArgData &getValue() const { return value; } 835 void setValue(const PassArgData &value) { this->value = value; } 836 837 PassArgData value; 838 }; 839 } // namespace cl 840 } // namespace llvm 841 842 namespace { 843 844 /// The name for the command line option used for parsing the textual pass 845 /// pipeline. 846 #define PASS_PIPELINE_ARG "pass-pipeline" 847 848 /// Adds command line option for each registered pass or pass pipeline, as well 849 /// as textual pass pipelines. 
850 struct PassNameParser : public llvm::cl::parser<PassArgData> { 851 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {} 852 853 void initialize(); 854 void printOptionInfo(const llvm::cl::Option &opt, 855 size_t globalWidth) const override; 856 size_t getOptionWidth(const llvm::cl::Option &opt) const override; 857 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, 858 PassArgData &value); 859 860 /// If true, this parser only parses entries that correspond to a concrete 861 /// pass registry entry, and does not include pipeline entries or the options 862 /// for pass entries. 863 bool passNamesOnly = false; 864 }; 865 } // namespace 866 867 void PassNameParser::initialize() { 868 llvm::cl::parser<PassArgData>::initialize(); 869 870 /// Add the pass entries. 871 for (const auto &kv : *passRegistry) { 872 addLiteralOption(kv.second.getPassArgument(), &kv.second, 873 kv.second.getPassDescription()); 874 } 875 /// Add the pass pipeline entries. 876 if (!passNamesOnly) { 877 for (const auto &kv : *passPipelineRegistry) { 878 addLiteralOption(kv.second.getPassArgument(), &kv.second, 879 kv.second.getPassDescription()); 880 } 881 } 882 } 883 884 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt, 885 size_t globalWidth) const { 886 // If this parser is just parsing pass names, print a simplified option 887 // string. 888 if (passNamesOnly) { 889 llvm::outs() << " --" << opt.ArgStr << "=<pass-arg>"; 890 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 18); 891 return; 892 } 893 894 // Print the information for the top-level option. 895 if (opt.hasArgStr()) { 896 llvm::outs() << " --" << opt.ArgStr; 897 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7); 898 } else { 899 llvm::outs() << " " << opt.HelpStr << '\n'; 900 } 901 902 // Functor used to print the ordered entries of a registration map. 
903 auto printOrderedEntries = [&](StringRef header, auto &map) { 904 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries; 905 for (auto &kv : map) 906 orderedEntries.push_back(&kv.second); 907 llvm::array_pod_sort( 908 orderedEntries.begin(), orderedEntries.end(), 909 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) { 910 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument()); 911 }); 912 913 llvm::outs().indent(4) << header << ":\n"; 914 for (PassRegistryEntry *entry : orderedEntries) 915 entry->printHelpStr(/*indent=*/6, globalWidth); 916 }; 917 918 // Print the available passes. 919 printOrderedEntries("Passes", *passRegistry); 920 921 // Print the available pass pipelines. 922 if (!passPipelineRegistry->empty()) 923 printOrderedEntries("Pass Pipelines", *passPipelineRegistry); 924 } 925 926 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const { 927 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2; 928 929 // Check for any wider pass or pipeline options. 
930 for (auto &entry : *passRegistry) 931 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 932 for (auto &entry : *passPipelineRegistry) 933 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4); 934 return maxWidth; 935 } 936 937 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName, 938 StringRef arg, PassArgData &value) { 939 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value)) 940 return true; 941 value.options = arg; 942 return false; 943 } 944 945 //===----------------------------------------------------------------------===// 946 // PassPipelineCLParser 947 //===----------------------------------------------------------------------===// 948 949 namespace mlir { 950 namespace detail { 951 struct PassPipelineCLParserImpl { 952 PassPipelineCLParserImpl(StringRef arg, StringRef description, 953 bool passNamesOnly) 954 : passList(arg, llvm::cl::desc(description)) { 955 passList.getParser().passNamesOnly = passNamesOnly; 956 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional); 957 } 958 959 /// Returns true if the given pass registry entry was registered at the 960 /// top-level of the parser, i.e. not within an explicit textual pipeline. 961 bool contains(const PassRegistryEntry *entry) const { 962 return llvm::any_of(passList, [&](const PassArgData &data) { 963 return data.registryEntry == entry; 964 }); 965 } 966 967 /// The set of passes and pass pipelines to run. 968 llvm::cl::list<PassArgData, bool, PassNameParser> passList; 969 }; 970 } // namespace detail 971 } // namespace mlir 972 973 /// Construct a pass pipeline parser with the given command line description. 
974 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description) 975 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 976 arg, description, /*passNamesOnly=*/false)), 977 passPipeline( 978 PASS_PIPELINE_ARG, 979 llvm::cl::desc("Textual description of the pass pipeline to run")) {} 980 981 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description, 982 StringRef alias) 983 : PassPipelineCLParser(arg, description) { 984 passPipelineAlias.emplace(alias, 985 llvm::cl::desc("Alias for --" PASS_PIPELINE_ARG), 986 llvm::cl::aliasopt(passPipeline)); 987 } 988 989 PassPipelineCLParser::~PassPipelineCLParser() = default; 990 991 /// Returns true if this parser contains any valid options to add. 992 bool PassPipelineCLParser::hasAnyOccurrences() const { 993 return passPipeline.getNumOccurrences() != 0 || 994 impl->passList.getNumOccurrences() != 0; 995 } 996 997 /// Returns true if the given pass registry entry was registered at the 998 /// top-level of the parser, i.e. not within an explicit textual pipeline. 999 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const { 1000 return impl->contains(entry); 1001 } 1002 1003 /// Adds the passes defined by this parser entry to the given pass manager. 
1004 LogicalResult PassPipelineCLParser::addToPipeline( 1005 OpPassManager &pm, 1006 function_ref<LogicalResult(const Twine &)> errorHandler) const { 1007 if (passPipeline.getNumOccurrences()) { 1008 if (impl->passList.getNumOccurrences()) 1009 return errorHandler( 1010 "'-" PASS_PIPELINE_ARG 1011 "' option can't be used with individual pass options"); 1012 std::string errMsg; 1013 llvm::raw_string_ostream os(errMsg); 1014 FailureOr<OpPassManager> parsed = parsePassPipeline(passPipeline, os); 1015 if (failed(parsed)) 1016 return errorHandler(errMsg); 1017 pm = std::move(*parsed); 1018 return success(); 1019 } 1020 1021 for (auto &passIt : impl->passList) { 1022 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options, 1023 errorHandler))) 1024 return failure(); 1025 } 1026 return success(); 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // PassNameCLParser 1031 1032 /// Construct a pass pipeline parser with the given command line description. 1033 PassNameCLParser::PassNameCLParser(StringRef arg, StringRef description) 1034 : impl(std::make_unique<detail::PassPipelineCLParserImpl>( 1035 arg, description, /*passNamesOnly=*/true)) { 1036 impl->passList.setMiscFlag(llvm::cl::CommaSeparated); 1037 } 1038 PassNameCLParser::~PassNameCLParser() = default; 1039 1040 /// Returns true if this parser contains any valid options to add. 1041 bool PassNameCLParser::hasAnyOccurrences() const { 1042 return impl->passList.getNumOccurrences() != 0; 1043 } 1044 1045 /// Returns true if the given pass registry entry was registered at the 1046 /// top-level of the parser, i.e. not within an explicit textual pipeline. 1047 bool PassNameCLParser::contains(const PassRegistryEntry *entry) const { 1048 return impl->contains(entry); 1049 } 1050