xref: /llvm-project/mlir/include/mlir/Pass/Pass.h (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PASS_PASS_H
10 #define MLIR_PASS_PASS_H
11 
12 #include "mlir/IR/Action.h"
13 #include "mlir/Pass/AnalysisManager.h"
14 #include "mlir/Pass/PassRegistry.h"
15 #include "llvm/ADT/PointerIntPair.h"
16 #include "llvm/ADT/Statistic.h"
17 #include <optional>
18 
19 namespace mlir {
20 namespace detail {
21 class OpToOpPassAdaptor;
22 struct OpPassManagerImpl;
23 
24 /// The state for a single execution of a pass. This provides a unified
25 /// interface for accessing and initializing necessary state for pass execution.
26 struct PassExecutionState {
PassExecutionStatePassExecutionState27   PassExecutionState(Operation *ir, AnalysisManager analysisManager,
28                      function_ref<LogicalResult(OpPassManager &, Operation *)>
29                          pipelineExecutor)
30       : irAndPassFailed(ir, false), analysisManager(analysisManager),
31         pipelineExecutor(pipelineExecutor) {}
32 
33   /// The current operation being transformed and a bool for if the pass
34   /// signaled a failure.
35   llvm::PointerIntPair<Operation *, 1, bool> irAndPassFailed;
36 
37   /// The analysis manager for the operation.
38   AnalysisManager analysisManager;
39 
40   /// The set of preserved analyses for the current execution.
41   detail::PreservedAnalyses preservedAnalyses;
42 
43   /// This is a callback in the PassManager that allows to schedule dynamic
44   /// pipelines that will be rooted at the provided operation.
45   function_ref<LogicalResult(OpPassManager &, Operation *)> pipelineExecutor;
46 };
47 } // namespace detail
48 
49 /// The abstract base pass class. This class contains information describing the
50 /// derived pass object, e.g its kind and abstract TypeID.
51 class Pass {
52 public:
53   virtual ~Pass() = default;
54 
55   /// Returns the unique identifier that corresponds to this pass.
getTypeID()56   TypeID getTypeID() const { return passID; }
57 
58   /// Returns the pass info for this pass, or null if unknown.
lookupPassInfo()59   const PassInfo *lookupPassInfo() const {
60     return PassInfo::lookup(getArgument());
61   }
62 
63   /// Returns the derived pass name.
64   virtual StringRef getName() const = 0;
65 
66   /// Register dependent dialects for the current pass.
67   /// A pass is expected to register the dialects it will create entities for
68   /// (Operations, Types, Attributes), other than dialect that exists in the
69   /// input. For example, a pass that converts from Linalg to Affine would
70   /// register the Affine dialect but does not need to register Linalg.
getDependentDialects(DialectRegistry & registry)71   virtual void getDependentDialects(DialectRegistry &registry) const {}
72 
73   /// Return the command line argument used when registering this pass. Return
74   /// an empty string if one does not exist.
getArgument()75   virtual StringRef getArgument() const { return ""; }
76 
77   /// Return the command line description used when registering this pass.
78   /// Return an empty string if one does not exist.
getDescription()79   virtual StringRef getDescription() const { return ""; }
80 
81   /// Returns the name of the operation that this pass operates on, or
82   /// std::nullopt if this is a generic OperationPass.
getOpName()83   std::optional<StringRef> getOpName() const { return opName; }
84 
85   //===--------------------------------------------------------------------===//
86   // Options
87   //===--------------------------------------------------------------------===//
88 
89   /// This class represents a specific pass option, with a provided data type.
90   template <typename DataType,
91             typename OptionParser = detail::PassOptions::OptionParser<DataType>>
92   struct Option : public detail::PassOptions::Option<DataType, OptionParser> {
93     template <typename... Args>
OptionOption94     Option(Pass &parent, StringRef arg, Args &&...args)
95         : detail::PassOptions::Option<DataType, OptionParser>(
96               parent.passOptions, arg, std::forward<Args>(args)...) {}
97     using detail::PassOptions::Option<DataType, OptionParser>::operator=;
98   };
99   /// This class represents a specific pass option that contains a list of
100   /// values of the provided data type.
101   template <typename DataType,
102             typename OptionParser = detail::PassOptions::OptionParser<DataType>>
103   struct ListOption
104       : public detail::PassOptions::ListOption<DataType, OptionParser> {
105     template <typename... Args>
ListOptionListOption106     ListOption(Pass &parent, StringRef arg, Args &&...args)
107         : detail::PassOptions::ListOption<DataType, OptionParser>(
108               parent.passOptions, arg, std::forward<Args>(args)...) {}
109     using detail::PassOptions::ListOption<DataType, OptionParser>::operator=;
110   };
111 
112   /// Attempt to initialize the options of this pass from the given string.
113   /// Derived classes may override this method to hook into the point at which
114   /// options are initialized, but should generally always invoke this base
115   /// class variant.
116   virtual LogicalResult
117   initializeOptions(StringRef options,
118                     function_ref<LogicalResult(const Twine &)> errorHandler);
119 
120   /// Prints out the pass in the textual representation of pipelines. If this is
121   /// an adaptor pass, print its pass managers.
122   void printAsTextualPipeline(raw_ostream &os);
123 
124   //===--------------------------------------------------------------------===//
125   // Statistics
126   //===--------------------------------------------------------------------===//
127 
128   /// This class represents a single pass statistic. This statistic functions
129   /// similarly to an unsigned integer value, and may be updated and incremented
130   /// accordingly. This class can be used to provide additional information
131   /// about the transformations and analyses performed by a pass.
132   class Statistic : public llvm::Statistic {
133   public:
134     /// The statistic is initialized by the pass owner, a name, and a
135     /// description.
136     Statistic(Pass *owner, const char *name, const char *description);
137 
138     /// Assign the statistic to the given value.
139     Statistic &operator=(unsigned value);
140   };
141 
142   /// Returns the main statistics for this pass instance.
getStatistics()143   ArrayRef<Statistic *> getStatistics() const { return statistics; }
getStatistics()144   MutableArrayRef<Statistic *> getStatistics() { return statistics; }
145 
146   /// Returns the thread sibling of this pass.
147   ///
148   /// If this pass was cloned by the pass manager for the sake of
149   /// multi-threading, this function returns the original pass it was cloned
150   /// from. This is useful for diagnostic purposes to distinguish passes that
151   /// were replicated for threading purposes from passes instantiated by the
152   /// user. Used to collapse passes in timing statistics.
getThreadingSibling()153   const Pass *getThreadingSibling() const { return threadingSibling; }
154 
155   /// Returns the thread sibling of this pass, or the pass itself it has no
156   /// sibling. See `getThreadingSibling()` for details.
getThreadingSiblingOrThis()157   const Pass *getThreadingSiblingOrThis() const {
158     return threadingSibling ? threadingSibling : this;
159   }
160 
161 protected:
162   explicit Pass(TypeID passID, std::optional<StringRef> opName = std::nullopt)
passID(passID)163       : passID(passID), opName(opName) {}
Pass(const Pass & other)164   Pass(const Pass &other) : Pass(other.passID, other.opName) {}
165   Pass &operator=(const Pass &) = delete;
166   Pass(Pass &&) = delete;
167   Pass &operator=(Pass &&) = delete;
168 
169   /// Returns the current pass state.
getPassState()170   detail::PassExecutionState &getPassState() {
171     assert(passState && "pass state was never initialized");
172     return *passState;
173   }
174 
175   /// Return the MLIR context for the current operation being transformed.
getContext()176   MLIRContext &getContext() { return *getOperation()->getContext(); }
177 
178   /// The polymorphic API that runs the pass over the currently held operation.
179   virtual void runOnOperation() = 0;
180 
181   /// Initialize any complex state necessary for running this pass. This hook
182   /// should not rely on any state accessible during the execution of a pass.
183   /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be
184   /// invoked within this hook.
185   /// This method is invoked after all dependent dialects for the pipeline are
186   /// loaded, and is not allowed to load any further dialects (override the
187   /// `getDependentDialects()` for this purpose instead). Returns a LogicalResult
188   /// to indicate failure, in which case the pass pipeline won't execute.
initialize(MLIRContext * context)189   virtual LogicalResult initialize(MLIRContext *context) { return success(); }
190 
191   /// Indicate if the current pass can be scheduled on the given operation type.
192   /// This is useful for generic operation passes to add restrictions on the
193   /// operations they operate on.
194   virtual bool canScheduleOn(RegisteredOperationName opName) const = 0;
195 
196   /// Schedule an arbitrary pass pipeline on the provided operation.
197   /// This can be invoke any time in a pass to dynamic schedule more passes.
198   /// The provided operation must be the current one or one nested below.
runPipeline(OpPassManager & pipeline,Operation * op)199   LogicalResult runPipeline(OpPassManager &pipeline, Operation *op) {
200     return passState->pipelineExecutor(pipeline, op);
201   }
202 
203   /// A clone method to create a copy of this pass.
clone()204   std::unique_ptr<Pass> clone() const {
205     auto newInst = clonePass();
206     newInst->copyOptionValuesFrom(this);
207     return newInst;
208   }
209 
210   /// Return the current operation being transformed.
getOperation()211   Operation *getOperation() {
212     return getPassState().irAndPassFailed.getPointer();
213   }
214 
215   /// Signal that some invariant was broken when running. The IR is allowed to
216   /// be in an invalid state.
signalPassFailure()217   void signalPassFailure() { getPassState().irAndPassFailed.setInt(true); }
218 
219   /// Query an analysis for the current ir unit.
220   template <typename AnalysisT>
getAnalysis()221   AnalysisT &getAnalysis() {
222     return getAnalysisManager().getAnalysis<AnalysisT>();
223   }
224 
225   /// Query an analysis for the current ir unit of a specific derived operation
226   /// type.
227   template <typename AnalysisT, typename OpT>
getAnalysis()228   AnalysisT &getAnalysis() {
229     return getAnalysisManager().getAnalysis<AnalysisT, OpT>();
230   }
231 
232   /// Query a cached instance of an analysis for the current ir unit if one
233   /// exists.
234   template <typename AnalysisT>
getCachedAnalysis()235   std::optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() {
236     return getAnalysisManager().getCachedAnalysis<AnalysisT>();
237   }
238 
239   /// Mark all analyses as preserved.
markAllAnalysesPreserved()240   void markAllAnalysesPreserved() {
241     getPassState().preservedAnalyses.preserveAll();
242   }
243 
244   /// Mark the provided analyses as preserved.
245   template <typename... AnalysesT>
markAnalysesPreserved()246   void markAnalysesPreserved() {
247     getPassState().preservedAnalyses.preserve<AnalysesT...>();
248   }
markAnalysesPreserved(TypeID id)249   void markAnalysesPreserved(TypeID id) {
250     getPassState().preservedAnalyses.preserve(id);
251   }
252 
253   /// Returns the analysis for the given parent operation if it exists.
254   template <typename AnalysisT>
255   std::optional<std::reference_wrapper<AnalysisT>>
getCachedParentAnalysis(Operation * parent)256   getCachedParentAnalysis(Operation *parent) {
257     return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(parent);
258   }
259 
260   /// Returns the analysis for the parent operation if it exists.
261   template <typename AnalysisT>
getCachedParentAnalysis()262   std::optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() {
263     return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(
264         getOperation()->getParentOp());
265   }
266 
267   /// Returns the analysis for the given child operation if it exists.
268   template <typename AnalysisT>
269   std::optional<std::reference_wrapper<AnalysisT>>
getCachedChildAnalysis(Operation * child)270   getCachedChildAnalysis(Operation *child) {
271     return getAnalysisManager().getCachedChildAnalysis<AnalysisT>(child);
272   }
273 
274   /// Returns the analysis for the given child operation, or creates it if it
275   /// doesn't exist.
276   template <typename AnalysisT>
getChildAnalysis(Operation * child)277   AnalysisT &getChildAnalysis(Operation *child) {
278     return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
279   }
280 
281   /// Returns the analysis for the given child operation of specific derived
282   /// operation type, or creates it if it doesn't exist.
283   template <typename AnalysisT, typename OpTy>
getChildAnalysis(OpTy child)284   AnalysisT &getChildAnalysis(OpTy child) {
285     return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
286   }
287 
288   /// Returns the current analysis manager.
getAnalysisManager()289   AnalysisManager getAnalysisManager() {
290     return getPassState().analysisManager;
291   }
292 
293   /// Create a copy of this pass, ignoring statistics and options.
294   virtual std::unique_ptr<Pass> clonePass() const = 0;
295 
296   /// Copy the option values from 'other', which is another instance of this
297   /// pass.
298   void copyOptionValuesFrom(const Pass *other);
299 
300 private:
301   /// Out of line virtual method to ensure vtables and metadata are emitted to a
302   /// single .o file.
303   virtual void anchor();
304 
305   /// Represents a unique identifier for the pass.
306   TypeID passID;
307 
308   /// The name of the operation that this pass operates on, or std::nullopt if
309   /// this is a generic OperationPass.
310   std::optional<StringRef> opName;
311 
312   /// The current execution state for the pass.
313   std::optional<detail::PassExecutionState> passState;
314 
315   /// The set of statistics held by this pass.
316   std::vector<Statistic *> statistics;
317 
318   /// The pass options registered to this pass instance.
319   detail::PassOptions passOptions;
320 
321   /// A pointer to the pass this pass was cloned from, if the clone was made by
322   /// the pass manager for the sake of multi-threading.
323   const Pass *threadingSibling = nullptr;
324 
325   /// Allow access to 'clone'.
326   friend class OpPassManager;
327 
328   /// Allow access to 'canScheduleOn'.
329   friend detail::OpPassManagerImpl;
330 
331   /// Allow access to 'passState'.
332   friend detail::OpToOpPassAdaptor;
333 
334   /// Allow access to 'passOptions'.
335   friend class PassInfo;
336 };
337 
338 //===----------------------------------------------------------------------===//
339 // Pass Model Definitions
340 //===----------------------------------------------------------------------===//
341 
342 /// Pass to transform an operation of a specific type.
343 ///
344 /// Operation passes must not:
345 ///   - modify any other operations within the parent region, as other threads
346 ///     may be manipulating them concurrently.
347 ///   - modify any state within the parent operation, this includes adding
348 ///     additional operations.
349 ///
350 /// Derived operation passes are expected to provide the following:
351 ///   - A 'void runOnOperation()' method.
352 ///   - A 'StringRef getName() const' method.
353 ///   - A 'std::unique_ptr<Pass> clonePass() const' method.
354 template <typename OpT = void>
355 class OperationPass : public Pass {
356 public:
357   ~OperationPass() override = default;
358 
359 protected:
OperationPass(TypeID passID)360   OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
361   OperationPass(const OperationPass &) = default;
362   OperationPass &operator=(const OperationPass &) = delete;
363   OperationPass(OperationPass &&) = delete;
364   OperationPass &operator=(OperationPass &&) = delete;
365 
366   /// Support isa/dyn_cast functionality.
classof(const Pass * pass)367   static bool classof(const Pass *pass) {
368     return pass->getOpName() == OpT::getOperationName();
369   }
370 
371   /// Indicate if the current pass can be scheduled on the given operation type.
canScheduleOn(RegisteredOperationName opName)372   bool canScheduleOn(RegisteredOperationName opName) const final {
373     return opName.getStringRef() == getOpName();
374   }
375 
376   /// Return the current operation being transformed.
getOperation()377   OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
378 
379   /// Query an analysis for the current operation of the specific derived
380   /// operation type.
381   template <typename AnalysisT>
getAnalysis()382   AnalysisT &getAnalysis() {
383     return Pass::getAnalysis<AnalysisT, OpT>();
384   }
385 };
386 
387 /// Pass to transform an operation.
388 ///
389 /// Operation passes must not:
390 ///   - modify any other operations within the parent region, as other threads
391 ///     may be manipulating them concurrently.
392 ///   - modify any state within the parent operation, this includes adding
393 ///     additional operations.
394 ///
395 /// Derived operation passes are expected to provide the following:
396 ///   - A 'void runOnOperation()' method.
397 ///   - A 'StringRef getName() const' method.
398 ///   - A 'std::unique_ptr<Pass> clonePass() const' method.
399 template <>
400 class OperationPass<void> : public Pass {
401 public:
402   ~OperationPass() override = default;
403 
404 protected:
OperationPass(TypeID passID)405   OperationPass(TypeID passID) : Pass(passID) {}
406   OperationPass(const OperationPass &) = default;
407   OperationPass &operator=(const OperationPass &) = delete;
408   OperationPass(OperationPass &&) = delete;
409   OperationPass &operator=(OperationPass &&) = delete;
410 
411   /// Indicate if the current pass can be scheduled on the given operation type.
412   /// By default, generic operation passes can be scheduled on any operation.
canScheduleOn(RegisteredOperationName opName)413   bool canScheduleOn(RegisteredOperationName opName) const override {
414     return true;
415   }
416 };
417 
418 /// Pass to transform an operation that implements the given interface.
419 ///
420 /// Interface passes must not:
421 ///   - modify any other operations within the parent region, as other threads
422 ///     may be manipulating them concurrently.
423 ///   - modify any state within the parent operation, this includes adding
424 ///     additional operations.
425 ///
426 /// Derived interface passes are expected to provide the following:
427 ///   - A 'void runOnOperation()' method.
428 ///   - A 'StringRef getName() const' method.
429 ///   - A 'std::unique_ptr<Pass> clonePass() const' method.
430 template <typename InterfaceT>
431 class InterfacePass : public OperationPass<> {
432 protected:
433   using OperationPass::OperationPass;
434 
435   /// Indicate if the current pass can be scheduled on the given operation type.
436   /// For an InterfacePass, this checks if the operation implements the given
437   /// interface.
canScheduleOn(RegisteredOperationName opName)438   bool canScheduleOn(RegisteredOperationName opName) const final {
439     return opName.hasInterface<InterfaceT>();
440   }
441 
442   /// Return the current operation being transformed.
getOperation()443   InterfaceT getOperation() { return cast<InterfaceT>(Pass::getOperation()); }
444 
445   /// Query an analysis for the current operation.
446   template <typename AnalysisT>
getAnalysis()447   AnalysisT &getAnalysis() {
448     return Pass::getAnalysis<AnalysisT, InterfaceT>();
449   }
450 };
451 
452 /// This class provides a CRTP wrapper around a base pass class to define
453 /// several necessary utility methods. This should only be used for passes that
454 /// are not suitably represented using the declarative pass specification(i.e.
455 /// tablegen backend).
456 template <typename PassT, typename BaseT>
457 class PassWrapper : public BaseT {
458 public:
459   /// Support isa/dyn_cast functionality for the derived pass class.
classof(const Pass * pass)460   static bool classof(const Pass *pass) {
461     return pass->getTypeID() == TypeID::get<PassT>();
462   }
463   ~PassWrapper() override = default;
464 
465 protected:
PassWrapper()466   PassWrapper() : BaseT(TypeID::get<PassT>()) {}
467   PassWrapper(const PassWrapper &) = default;
468   PassWrapper &operator=(const PassWrapper &) = delete;
469   PassWrapper(PassWrapper &&) = delete;
470   PassWrapper &operator=(PassWrapper &&) = delete;
471 
472   /// Returns the derived pass name.
getName()473   StringRef getName() const override { return llvm::getTypeName<PassT>(); }
474 
475   /// A clone method to create a copy of this pass.
clonePass()476   std::unique_ptr<Pass> clonePass() const override {
477     return std::make_unique<PassT>(*static_cast<const PassT *>(this));
478   }
479 };
480 
481 /// This class encapsulates the "action" of executing a single pass. This allows
482 /// a user of the Action infrastructure to query information about an action in
483 /// (for example) a breakpoint context. You could use it like this:
484 ///
485 ///  auto onBreakpoint = [&](const ActionActiveStack *backtrace) {
486 ///    if (auto passExec = dyn_cast<PassExecutionAction>(anAction))
487 ///      record(passExec.getPass());
488 ///    return ExecutionContext::Apply;
489 ///  };
490 ///  ExecutionContext exeCtx(onBreakpoint);
491 ///
492 class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
493   using Base = tracing::ActionImpl<PassExecutionAction>;
494 
495 public:
496   /// Define a TypeID for this PassExecutionAction.
497   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction)
498   /// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor
499   /// when it calls `executeAction`.
500   PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass);
501 
502   /// The tag required by ActionImpl to identify this action.
503   static constexpr StringLiteral tag = "pass-execution";
504 
505   /// Print a textual version of this action to `os`.
506   void print(raw_ostream &os) const override;
507 
508   /// Get the pass that will be executed by this action. This is not a class of
509   /// passes, or all instances of a pass kind, this is a single pass.
getPass()510   const Pass &getPass() const { return pass; }
511 
512   /// Get the operation that is the base of this pass. For example, an
513   /// OperationPass<ModuleOp> would return a ModuleOp.
514   Operation *getOp() const;
515 
516 public:
517   /// Reference to the pass being run. Notice that this will *not* extend the
518   /// lifetime of the pass, and so this class is therefore unsafe to keep past
519   /// the lifetime of the `executeAction` call.
520   const Pass &pass;
521 
522   /// The base op for this pass. For an OperationPass<ModuleOp>, we would have a
523   /// ModuleOp here.
524   Operation *op;
525 };
526 
527 } // namespace mlir
528 
529 #endif // MLIR_PASS_PASS_H
530