xref: /llvm-project/mlir/include/mlir/Pass/PassOptions.h (revision 1a70420ff3b972b3d9bbc1c4d1e98bfa12bfb73a)
1 //===- PassOptions.h - Pass Option Utilities --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains utilities for registering options with compiler passes and
10 // pipelines.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_PASS_PASSOPTIONS_H_
15 #define MLIR_PASS_PASSOPTIONS_H_
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/FunctionExtras.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Compiler.h"
22 #include <memory>
23 
24 namespace mlir {
25 class OpPassManager;
26 
27 namespace detail {
28 namespace pass_options {
29 /// Parse a string containing a list of comma-delimited elements, invoking the
30 /// given parser for each sub-element and passing them to the provided
31 /// element-append functor.
32 LogicalResult
33 parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
34                         StringRef optionStr,
35                         function_ref<LogicalResult(StringRef)> elementParseFn);
36 template <typename ElementParser, typename ElementAppendFn>
37 LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
38                                       StringRef optionStr,
39                                       ElementParser &elementParser,
40                                       ElementAppendFn &&appendFn) {
41   return parseCommaSeparatedList(
42       opt, argName, optionStr, [&](StringRef valueStr) {
43         typename ElementParser::parser_data_type value = {};
44         if (elementParser.parse(opt, argName, valueStr, value))
45           return failure();
46         appendFn(value);
47         return success();
48       });
49 }
50 
51 /// Trait used to detect if a type has a operator<< method.
52 template <typename T>
53 using has_stream_operator_trait =
54     decltype(std::declval<raw_ostream &>() << std::declval<T>());
55 template <typename T>
56 using has_stream_operator = llvm::is_detected<has_stream_operator_trait, T>;
57 
58 /// Utility methods for printing option values.
59 template <typename ParserT>
60 static void printOptionValue(raw_ostream &os, const bool &value) {
61   os << (value ? StringRef("true") : StringRef("false"));
62 }
63 template <typename ParserT>
64 static void printOptionValue(raw_ostream &os, const std::string &str) {
65   // Check if the string needs to be escaped before writing it to the ostream.
66   const size_t spaceIndex = str.find_first_of(' ');
67   const size_t escapeIndex =
68       std::min({str.find_first_of('{'), str.find_first_of('\''),
69                 str.find_first_of('"')});
70   const bool requiresEscape = spaceIndex < escapeIndex;
71   if (requiresEscape)
72     os << "{";
73   os << str;
74   if (requiresEscape)
75     os << "}";
76 }
77 template <typename ParserT, typename DataT>
78 static std::enable_if_t<has_stream_operator<DataT>::value>
79 printOptionValue(raw_ostream &os, const DataT &value) {
80   os << value;
81 }
82 template <typename ParserT, typename DataT>
83 static std::enable_if_t<!has_stream_operator<DataT>::value>
84 printOptionValue(raw_ostream &os, const DataT &value) {
85   // If the value can't be streamed, fallback to checking for a print in the
86   // parser.
87   ParserT::print(os, value);
88 }
89 } // namespace pass_options
90 
91 /// Base container class and manager for all pass options.
92 class PassOptions : protected llvm::cl::SubCommand {
93 private:
94   /// This is the type-erased option base class. This provides some additional
95   /// hooks into the options that are not available via llvm::cl::Option.
96   class OptionBase {
97   public:
98     virtual ~OptionBase() = default;
99 
100     /// Out of line virtual function to provide home for the class.
101     virtual void anchor();
102 
103     /// Print the name and value of this option to the given stream.
104     virtual void print(raw_ostream &os) = 0;
105 
106     /// Return the argument string of this option.
107     StringRef getArgStr() const { return getOption()->ArgStr; }
108 
109     /// Returns true if this option has any value assigned to it.
110     bool hasValue() const { return optHasValue; }
111 
112   protected:
113     /// Return the main option instance.
114     virtual const llvm::cl::Option *getOption() const = 0;
115 
116     /// Copy the value from the given option into this one.
117     virtual void copyValueFrom(const OptionBase &other) = 0;
118 
119     /// Flag indicating if this option has a value.
120     bool optHasValue = false;
121 
122     /// Allow access to private methods.
123     friend PassOptions;
124   };
125 
126   /// This is the parser that is used by pass options that use literal options.
127   /// This is a thin wrapper around the llvm::cl::parser, that exposes some
128   /// additional methods.
129   template <typename DataType>
130   struct GenericOptionParser : public llvm::cl::parser<DataType> {
131     using llvm::cl::parser<DataType>::parser;
132 
133     /// Returns an argument name that maps to the specified value.
134     std::optional<StringRef> findArgStrForValue(const DataType &value) {
135       for (auto &it : this->Values)
136         if (it.V.compare(value))
137           return it.Name;
138       return std::nullopt;
139     }
140   };
141 
142   /// This is the parser that is used by pass options that wrap PassOptions
143   /// instances. Like GenericOptionParser, this is a thin wrapper around
144   /// llvm::cl::basic_parser.
145   template <typename PassOptionsT>
146   struct PassOptionsParser : public llvm::cl::basic_parser<PassOptionsT> {
147     using llvm::cl::basic_parser<PassOptionsT>::basic_parser;
148     // Parse the options object by delegating to
149     // `PassOptionsT::parseFromString`.
150     bool parse(llvm::cl::Option &, StringRef, StringRef arg,
151                PassOptionsT &value) {
152       return failed(value.parseFromString(arg));
153     }
154 
155     // Print the options object by delegating to `PassOptionsT::print`.
156     static void print(llvm::raw_ostream &os, const PassOptionsT &value) {
157       value.print(os);
158     }
159   };
160 
161   /// Utility methods for printing option values.
162   template <typename DataT>
163   static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
164                          const DataT &value) {
165     if (std::optional<StringRef> argStr = parser.findArgStrForValue(value))
166       os << *argStr;
167     else
168       llvm_unreachable("unknown data value for option");
169   }
170   template <typename DataT, typename ParserT>
171   static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
172     detail::pass_options::printOptionValue<ParserT>(os, value);
173   }
174 
175 public:
176   /// The specific parser to use. This is necessary because we need to provide
177   /// additional methods for certain data type parsers.
178   template <typename DataType>
179   using OptionParser = std::conditional_t<
180       // If the data type is derived from PassOptions, use the
181       // PassOptionsParser.
182       std::is_base_of_v<PassOptions, DataType>, PassOptionsParser<DataType>,
183       // Otherwise, use GenericOptionParser where it is well formed, and fall
184       // back to llvm::cl::parser otherwise.
185       // TODO: We should upstream the methods in GenericOptionParser to avoid
186       // the  need to do this.
187       std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
188                                          llvm::cl::parser<DataType>>::value,
189                          GenericOptionParser<DataType>,
190                          llvm::cl::parser<DataType>>>;
191 
192   /// This class represents a specific pass option, with a provided
193   /// data type.
194   template <typename DataType, typename OptionParser = OptionParser<DataType>>
195   class Option
196       : public llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>,
197         public OptionBase {
198   public:
199     template <typename... Args>
200     Option(PassOptions &parent, StringRef arg, Args &&...args)
201         : llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
202               arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
203       assert(!this->isPositional() && !this->isSink() &&
204              "sink and positional options are not supported");
205       parent.options.push_back(this);
206 
207       // Set a callback to track if this option has a value.
208       this->setCallback([this](const auto &) { this->optHasValue = true; });
209     }
210     ~Option() override = default;
211     using llvm::cl::opt<DataType, /*ExternalStorage=*/false,
212                         OptionParser>::operator=;
213     Option &operator=(const Option &other) {
214       *this = other.getValue();
215       return *this;
216     }
217 
218   private:
219     /// Return the main option instance.
220     const llvm::cl::Option *getOption() const final { return this; }
221 
222     /// Print the name and value of this option to the given stream.
223     void print(raw_ostream &os) final {
224       os << this->ArgStr << '=';
225       printValue(os, this->getParser(), this->getValue());
226     }
227 
228     /// Copy the value from the given option into this one.
229     void copyValueFrom(const OptionBase &other) final {
230       this->setValue(static_cast<const Option<DataType, OptionParser> &>(other)
231                          .getValue());
232       optHasValue = other.optHasValue;
233     }
234   };
235 
236   /// This class represents a specific pass option that contains a list of
237   /// values of the provided data type. The elements within the textual form of
238   /// this option are parsed assuming they are comma-separated. Delimited
239   /// sub-ranges within individual elements of the list may contain commas that
240   /// are not treated as separators for the top-level list.
241   template <typename DataType, typename OptionParser = OptionParser<DataType>>
242   class ListOption
243       : public llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>,
244         public OptionBase {
245   public:
246     template <typename... Args>
247     ListOption(PassOptions &parent, StringRef arg, Args &&...args)
248         : llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
249               arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
250           elementParser(*this) {
251       assert(!this->isPositional() && !this->isSink() &&
252              "sink and positional options are not supported");
253       assert(!(this->getMiscFlags() & llvm::cl::MiscFlags::CommaSeparated) &&
254              "ListOption is implicitly comma separated, specifying "
255              "CommaSeparated is extraneous");
256 
257       // Make the default explicitly "empty" if no default was given.
258       if (!this->isDefaultAssigned())
259         this->setInitialValues({});
260 
261       parent.options.push_back(this);
262       elementParser.initialize();
263     }
264     ~ListOption() override = default;
265     ListOption<DataType, OptionParser> &
266     operator=(const ListOption<DataType, OptionParser> &other) {
267       *this = ArrayRef<DataType>(other);
268       this->optHasValue = other.optHasValue;
269       return *this;
270     }
271 
272     bool handleOccurrence(unsigned pos, StringRef argName,
273                           StringRef arg) override {
274       if (this->isDefaultAssigned()) {
275         this->clear();
276         this->overwriteDefault();
277       }
278       this->optHasValue = true;
279       return failed(detail::pass_options::parseCommaSeparatedList(
280           *this, argName, arg, elementParser,
281           [&](const DataType &value) { this->addValue(value); }));
282     }
283 
284     /// Allow assigning from an ArrayRef.
285     ListOption<DataType, OptionParser> &operator=(ArrayRef<DataType> values) {
286       ((std::vector<DataType> &)*this).assign(values.begin(), values.end());
287       optHasValue = true;
288       return *this;
289     }
290 
291     /// Allow accessing the data held by this option.
292     MutableArrayRef<DataType> operator*() {
293       return static_cast<std::vector<DataType> &>(*this);
294     }
295     ArrayRef<DataType> operator*() const {
296       return static_cast<const std::vector<DataType> &>(*this);
297     }
298 
299   private:
300     /// Return the main option instance.
301     const llvm::cl::Option *getOption() const final { return this; }
302 
303     /// Print the name and value of this option to the given stream.
304     /// Note that there is currently a limitation with regards to
305     /// `ListOption<string>`: parsing 'option=""` will result in `option` being
306     /// set to the empty list, not to a size-1 list containing an empty string.
307     void print(raw_ostream &os) final {
308       // Don't print the list if the value is the default value.
309       if (this->isDefaultAssigned() &&
310           this->getDefault().size() == (**this).size()) {
311         unsigned i = 0;
312         for (unsigned e = (**this).size(); i < e; i++) {
313           if (!this->getDefault()[i].compare((**this)[i]))
314             break;
315         }
316         if (i == (**this).size())
317           return;
318       }
319 
320       os << this->ArgStr << "={";
321       auto printElementFn = [&](const DataType &value) {
322         printValue(os, this->getParser(), value);
323       };
324       llvm::interleave(*this, os, printElementFn, ",");
325       os << "}";
326     }
327 
328     /// Copy the value from the given option into this one.
329     void copyValueFrom(const OptionBase &other) final {
330       *this = static_cast<const ListOption<DataType, OptionParser> &>(other);
331     }
332 
333     /// The parser to use for parsing the list elements.
334     OptionParser elementParser;
335   };
336 
337   PassOptions() = default;
338   /// Delete the copy constructor to avoid copying the internal options map.
339   PassOptions(const PassOptions &) = delete;
340   PassOptions(PassOptions &&) = delete;
341 
342   /// Copy the option values from 'other' into 'this', where 'other' has the
343   /// same options as 'this'.
344   void copyOptionValuesFrom(const PassOptions &other);
345 
346   /// Parse options out as key=value pairs that can then be handed off to the
347   /// `llvm::cl` command line passing infrastructure. Everything is space
348   /// separated.
349   LogicalResult parseFromString(StringRef options,
350                                 raw_ostream &errorStream = llvm::errs());
351 
352   /// Print the options held by this struct in a form that can be parsed via
353   /// 'parseFromString'.
354   void print(raw_ostream &os) const;
355 
356   /// Print the help string for the options held by this struct. `descIndent` is
357   /// the indent that the descriptions should be aligned.
358   void printHelp(size_t indent, size_t descIndent) const;
359 
360   /// Return the maximum width required when printing the help string.
361   size_t getOptionWidth() const;
362 
363 private:
364   /// A list of all of the opaque options.
365   std::vector<OptionBase *> options;
366 };
367 } // namespace detail
368 
369 //===----------------------------------------------------------------------===//
370 // PassPipelineOptions
371 //===----------------------------------------------------------------------===//
372 
373 /// Subclasses of PassPipelineOptions provide a set of options that can be used
374 /// to initialize a pass pipeline. See PassPipelineRegistration for usage
375 /// details.
376 ///
377 /// Usage:
378 ///
379 /// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
380 ///   ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
381 /// };
382 template <typename T>
383 class PassPipelineOptions : public detail::PassOptions {
384 public:
385   /// Factory that parses the provided options and returns a unique_ptr to the
386   /// struct.
387   static std::unique_ptr<T> createFromString(StringRef options) {
388     auto result = std::make_unique<T>();
389     if (failed(result->parseFromString(options)))
390       return nullptr;
391     return result;
392   }
393 };
394 
395 /// A default empty option struct to be used for passes that do not need to take
396 /// any options.
397 struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
398 };
399 } // namespace mlir
400 
401 //===----------------------------------------------------------------------===//
402 // MLIR Options
403 //===----------------------------------------------------------------------===//
404 
405 namespace llvm {
406 namespace cl {
407 //===----------------------------------------------------------------------===//
408 // std::vector+SmallVector
409 
410 namespace detail {
411 template <typename VectorT, typename ElementT>
412 class VectorParserBase : public basic_parser_impl {
413 public:
414   VectorParserBase(Option &opt) : basic_parser_impl(opt), elementParser(opt) {}
415 
416   using parser_data_type = VectorT;
417 
418   bool parse(Option &opt, StringRef argName, StringRef arg,
419              parser_data_type &vector) {
420     if (!arg.consume_front("[") || !arg.consume_back("]")) {
421       return opt.error("expected vector option to be wrapped with '[]'",
422                        argName);
423     }
424 
425     return failed(mlir::detail::pass_options::parseCommaSeparatedList(
426         opt, argName, arg, elementParser,
427         [&](const ElementT &value) { vector.push_back(value); }));
428   }
429 
430   static void print(raw_ostream &os, const VectorT &vector) {
431     llvm::interleave(
432         vector, os,
433         [&](const ElementT &value) {
434           mlir::detail::pass_options::printOptionValue<
435               llvm::cl::parser<ElementT>>(os, value);
436         },
437         ",");
438   }
439 
440   void printOptionInfo(const Option &opt, size_t globalWidth) const {
441     // Add the `vector<>` qualifier to the option info.
442     outs() << "  --" << opt.ArgStr;
443     outs() << "=<vector<" << elementParser.getValueName() << ">>";
444     Option::printHelpStr(opt.HelpStr, globalWidth, getOptionWidth(opt));
445   }
446 
447   size_t getOptionWidth(const Option &opt) const {
448     // Add the `vector<>` qualifier to the option width.
449     StringRef vectorExt("vector<>");
450     return elementParser.getOptionWidth(opt) + vectorExt.size();
451   }
452 
453 private:
454   llvm::cl::parser<ElementT> elementParser;
455 };
456 } // namespace detail
457 
458 template <typename T>
459 class parser<std::vector<T>>
460     : public detail::VectorParserBase<std::vector<T>, T> {
461 public:
462   parser(Option &opt) : detail::VectorParserBase<std::vector<T>, T>(opt) {}
463 };
464 template <typename T, unsigned N>
465 class parser<SmallVector<T, N>>
466     : public detail::VectorParserBase<SmallVector<T, N>, T> {
467 public:
468   parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
469 };
470 
471 //===----------------------------------------------------------------------===//
472 // OpPassManager: OptionValue
473 
474 template <>
475 struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
476   using WrapperType = mlir::OpPassManager;
477 
478   OptionValue();
479   OptionValue(const OptionValue<mlir::OpPassManager> &rhs);
480   OptionValue(const mlir::OpPassManager &value);
481   OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
482   ~OptionValue();
483 
484   /// Returns if the current option has a value.
485   bool hasValue() const { return value.get(); }
486 
487   /// Returns the current value of the option.
488   mlir::OpPassManager &getValue() const {
489     assert(hasValue() && "invalid option value");
490     return *value;
491   }
492 
493   /// Set the value of the option.
494   void setValue(const mlir::OpPassManager &newValue);
495   void setValue(StringRef pipelineStr);
496 
497   /// Compare the option with the provided value.
498   bool compare(const mlir::OpPassManager &rhs) const;
499   bool compare(const GenericOptionValue &rhs) const override {
500     const auto &rhsOV =
501         static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
502     if (!rhsOV.hasValue())
503       return false;
504     return compare(rhsOV.getValue());
505   }
506 
507 private:
508   void anchor() override;
509 
510   /// The underlying pass manager. We use a unique_ptr to avoid the need for the
511   /// full type definition.
512   std::unique_ptr<mlir::OpPassManager> value;
513 };
514 
515 //===----------------------------------------------------------------------===//
516 // OpPassManager: Parser
517 
518 extern template class basic_parser<mlir::OpPassManager>;
519 
520 template <>
521 class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
522 public:
523   /// A utility struct used when parsing a pass manager that prevents the need
524   /// for a default constructor on OpPassManager.
525   struct ParsedPassManager {
526     ParsedPassManager();
527     ParsedPassManager(ParsedPassManager &&);
528     ~ParsedPassManager();
529     operator const mlir::OpPassManager &() const {
530       assert(value && "parsed value was invalid");
531       return *value;
532     }
533 
534     std::unique_ptr<mlir::OpPassManager> value;
535   };
536   using parser_data_type = ParsedPassManager;
537   using OptVal = OptionValue<mlir::OpPassManager>;
538 
539   parser(Option &opt) : basic_parser(opt) {}
540 
541   bool parse(Option &, StringRef, StringRef arg, ParsedPassManager &value);
542 
543   /// Print an instance of the underling option value to the given stream.
544   static void print(raw_ostream &os, const mlir::OpPassManager &value);
545 
546   // Overload in subclass to provide a better default value.
547   StringRef getValueName() const override { return "pass-manager"; }
548 
549   void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
550                        const OptVal &defaultValue, size_t globalWidth) const;
551 
552   // An out-of-line virtual method to provide a 'home' for this class.
553   void anchor() override;
554 };
555 
556 } // namespace cl
557 } // namespace llvm
558 
559 #endif // MLIR_PASS_PASSOPTIONS_H_
560