1 //===- Pass.h - Base classes for compiler passes ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_PASS_PASS_H 10 #define MLIR_PASS_PASS_H 11 12 #include "mlir/IR/Action.h" 13 #include "mlir/Pass/AnalysisManager.h" 14 #include "mlir/Pass/PassRegistry.h" 15 #include "llvm/ADT/PointerIntPair.h" 16 #include "llvm/ADT/Statistic.h" 17 #include <optional> 18 19 namespace mlir { 20 namespace detail { 21 class OpToOpPassAdaptor; 22 struct OpPassManagerImpl; 23 24 /// The state for a single execution of a pass. This provides a unified 25 /// interface for accessing and initializing necessary state for pass execution. 26 struct PassExecutionState { PassExecutionStatePassExecutionState27 PassExecutionState(Operation *ir, AnalysisManager analysisManager, 28 function_ref<LogicalResult(OpPassManager &, Operation *)> 29 pipelineExecutor) 30 : irAndPassFailed(ir, false), analysisManager(analysisManager), 31 pipelineExecutor(pipelineExecutor) {} 32 33 /// The current operation being transformed and a bool for if the pass 34 /// signaled a failure. 35 llvm::PointerIntPair<Operation *, 1, bool> irAndPassFailed; 36 37 /// The analysis manager for the operation. 38 AnalysisManager analysisManager; 39 40 /// The set of preserved analyses for the current execution. 41 detail::PreservedAnalyses preservedAnalyses; 42 43 /// This is a callback in the PassManager that allows to schedule dynamic 44 /// pipelines that will be rooted at the provided operation. 45 function_ref<LogicalResult(OpPassManager &, Operation *)> pipelineExecutor; 46 }; 47 } // namespace detail 48 49 /// The abstract base pass class. This class contains information describing the 50 /// derived pass object, e.g its kind and abstract TypeID. 51 class Pass { 52 public: 53 virtual ~Pass() = default; 54 55 /// Returns the unique identifier that corresponds to this pass. getTypeID()56 TypeID getTypeID() const { return passID; } 57 58 /// Returns the pass info for this pass, or null if unknown. lookupPassInfo()59 const PassInfo *lookupPassInfo() const { 60 return PassInfo::lookup(getArgument()); 61 } 62 63 /// Returns the derived pass name. 64 virtual StringRef getName() const = 0; 65 66 /// Register dependent dialects for the current pass. 67 /// A pass is expected to register the dialects it will create entities for 68 /// (Operations, Types, Attributes), other than dialect that exists in the 69 /// input. For example, a pass that converts from Linalg to Affine would 70 /// register the Affine dialect but does not need to register Linalg. getDependentDialects(DialectRegistry & registry)71 virtual void getDependentDialects(DialectRegistry ®istry) const {} 72 73 /// Return the command line argument used when registering this pass. Return 74 /// an empty string if one does not exist. getArgument()75 virtual StringRef getArgument() const { return ""; } 76 77 /// Return the command line description used when registering this pass. 78 /// Return an empty string if one does not exist. getDescription()79 virtual StringRef getDescription() const { return ""; } 80 81 /// Returns the name of the operation that this pass operates on, or 82 /// std::nullopt if this is a generic OperationPass. getOpName()83 std::optional<StringRef> getOpName() const { return opName; } 84 85 //===--------------------------------------------------------------------===// 86 // Options 87 //===--------------------------------------------------------------------===// 88 89 /// This class represents a specific pass option, with a provided data type. 90 template <typename DataType, 91 typename OptionParser = detail::PassOptions::OptionParser<DataType>> 92 struct Option : public detail::PassOptions::Option<DataType, OptionParser> { 93 template <typename... Args> OptionOption94 Option(Pass &parent, StringRef arg, Args &&...args) 95 : detail::PassOptions::Option<DataType, OptionParser>( 96 parent.passOptions, arg, std::forward<Args>(args)...) {} 97 using detail::PassOptions::Option<DataType, OptionParser>::operator=; 98 }; 99 /// This class represents a specific pass option that contains a list of 100 /// values of the provided data type. 101 template <typename DataType, 102 typename OptionParser = detail::PassOptions::OptionParser<DataType>> 103 struct ListOption 104 : public detail::PassOptions::ListOption<DataType, OptionParser> { 105 template <typename... Args> ListOptionListOption106 ListOption(Pass &parent, StringRef arg, Args &&...args) 107 : detail::PassOptions::ListOption<DataType, OptionParser>( 108 parent.passOptions, arg, std::forward<Args>(args)...) {} 109 using detail::PassOptions::ListOption<DataType, OptionParser>::operator=; 110 }; 111 112 /// Attempt to initialize the options of this pass from the given string. 113 /// Derived classes may override this method to hook into the point at which 114 /// options are initialized, but should generally always invoke this base 115 /// class variant. 116 virtual LogicalResult 117 initializeOptions(StringRef options, 118 function_ref<LogicalResult(const Twine &)> errorHandler); 119 120 /// Prints out the pass in the textual representation of pipelines. If this is 121 /// an adaptor pass, print its pass managers. 122 void printAsTextualPipeline(raw_ostream &os); 123 124 //===--------------------------------------------------------------------===// 125 // Statistics 126 //===--------------------------------------------------------------------===// 127 128 /// This class represents a single pass statistic. This statistic functions 129 /// similarly to an unsigned integer value, and may be updated and incremented 130 /// accordingly. This class can be used to provide additional information 131 /// about the transformations and analyses performed by a pass. 132 class Statistic : public llvm::Statistic { 133 public: 134 /// The statistic is initialized by the pass owner, a name, and a 135 /// description. 136 Statistic(Pass *owner, const char *name, const char *description); 137 138 /// Assign the statistic to the given value. 139 Statistic &operator=(unsigned value); 140 }; 141 142 /// Returns the main statistics for this pass instance. getStatistics()143 ArrayRef<Statistic *> getStatistics() const { return statistics; } getStatistics()144 MutableArrayRef<Statistic *> getStatistics() { return statistics; } 145 146 /// Returns the thread sibling of this pass. 147 /// 148 /// If this pass was cloned by the pass manager for the sake of 149 /// multi-threading, this function returns the original pass it was cloned 150 /// from. This is useful for diagnostic purposes to distinguish passes that 151 /// were replicated for threading purposes from passes instantiated by the 152 /// user. Used to collapse passes in timing statistics. getThreadingSibling()153 const Pass *getThreadingSibling() const { return threadingSibling; } 154 155 /// Returns the thread sibling of this pass, or the pass itself it has no 156 /// sibling. See `getThreadingSibling()` for details. getThreadingSiblingOrThis()157 const Pass *getThreadingSiblingOrThis() const { 158 return threadingSibling ? threadingSibling : this; 159 } 160 161 protected: 162 explicit Pass(TypeID passID, std::optional<StringRef> opName = std::nullopt) passID(passID)163 : passID(passID), opName(opName) {} Pass(const Pass & other)164 Pass(const Pass &other) : Pass(other.passID, other.opName) {} 165 Pass &operator=(const Pass &) = delete; 166 Pass(Pass &&) = delete; 167 Pass &operator=(Pass &&) = delete; 168 169 /// Returns the current pass state. getPassState()170 detail::PassExecutionState &getPassState() { 171 assert(passState && "pass state was never initialized"); 172 return *passState; 173 } 174 175 /// Return the MLIR context for the current operation being transformed. getContext()176 MLIRContext &getContext() { return *getOperation()->getContext(); } 177 178 /// The polymorphic API that runs the pass over the currently held operation. 179 virtual void runOnOperation() = 0; 180 181 /// Initialize any complex state necessary for running this pass. This hook 182 /// should not rely on any state accessible during the execution of a pass. 183 /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be 184 /// invoked within this hook. 185 /// This method is invoked after all dependent dialects for the pipeline are 186 /// loaded, and is not allowed to load any further dialects (override the 187 /// `getDependentDialects()` for this purpose instead). Returns a LogicalResult 188 /// to indicate failure, in which case the pass pipeline won't execute. initialize(MLIRContext * context)189 virtual LogicalResult initialize(MLIRContext *context) { return success(); } 190 191 /// Indicate if the current pass can be scheduled on the given operation type. 192 /// This is useful for generic operation passes to add restrictions on the 193 /// operations they operate on. 194 virtual bool canScheduleOn(RegisteredOperationName opName) const = 0; 195 196 /// Schedule an arbitrary pass pipeline on the provided operation. 197 /// This can be invoke any time in a pass to dynamic schedule more passes. 198 /// The provided operation must be the current one or one nested below. runPipeline(OpPassManager & pipeline,Operation * op)199 LogicalResult runPipeline(OpPassManager &pipeline, Operation *op) { 200 return passState->pipelineExecutor(pipeline, op); 201 } 202 203 /// A clone method to create a copy of this pass. clone()204 std::unique_ptr<Pass> clone() const { 205 auto newInst = clonePass(); 206 newInst->copyOptionValuesFrom(this); 207 return newInst; 208 } 209 210 /// Return the current operation being transformed. getOperation()211 Operation *getOperation() { 212 return getPassState().irAndPassFailed.getPointer(); 213 } 214 215 /// Signal that some invariant was broken when running. The IR is allowed to 216 /// be in an invalid state. signalPassFailure()217 void signalPassFailure() { getPassState().irAndPassFailed.setInt(true); } 218 219 /// Query an analysis for the current ir unit. 220 template <typename AnalysisT> getAnalysis()221 AnalysisT &getAnalysis() { 222 return getAnalysisManager().getAnalysis<AnalysisT>(); 223 } 224 225 /// Query an analysis for the current ir unit of a specific derived operation 226 /// type. 227 template <typename AnalysisT, typename OpT> getAnalysis()228 AnalysisT &getAnalysis() { 229 return getAnalysisManager().getAnalysis<AnalysisT, OpT>(); 230 } 231 232 /// Query a cached instance of an analysis for the current ir unit if one 233 /// exists. 234 template <typename AnalysisT> getCachedAnalysis()235 std::optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() { 236 return getAnalysisManager().getCachedAnalysis<AnalysisT>(); 237 } 238 239 /// Mark all analyses as preserved. markAllAnalysesPreserved()240 void markAllAnalysesPreserved() { 241 getPassState().preservedAnalyses.preserveAll(); 242 } 243 244 /// Mark the provided analyses as preserved. 245 template <typename... AnalysesT> markAnalysesPreserved()246 void markAnalysesPreserved() { 247 getPassState().preservedAnalyses.preserve<AnalysesT...>(); 248 } markAnalysesPreserved(TypeID id)249 void markAnalysesPreserved(TypeID id) { 250 getPassState().preservedAnalyses.preserve(id); 251 } 252 253 /// Returns the analysis for the given parent operation if it exists. 254 template <typename AnalysisT> 255 std::optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis(Operation * parent)256 getCachedParentAnalysis(Operation *parent) { 257 return getAnalysisManager().getCachedParentAnalysis<AnalysisT>(parent); 258 } 259 260 /// Returns the analysis for the parent operation if it exists. 261 template <typename AnalysisT> getCachedParentAnalysis()262 std::optional<std::reference_wrapper<AnalysisT>> getCachedParentAnalysis() { 263 return getAnalysisManager().getCachedParentAnalysis<AnalysisT>( 264 getOperation()->getParentOp()); 265 } 266 267 /// Returns the analysis for the given child operation if it exists. 268 template <typename AnalysisT> 269 std::optional<std::reference_wrapper<AnalysisT>> getCachedChildAnalysis(Operation * child)270 getCachedChildAnalysis(Operation *child) { 271 return getAnalysisManager().getCachedChildAnalysis<AnalysisT>(child); 272 } 273 274 /// Returns the analysis for the given child operation, or creates it if it 275 /// doesn't exist. 276 template <typename AnalysisT> getChildAnalysis(Operation * child)277 AnalysisT &getChildAnalysis(Operation *child) { 278 return getAnalysisManager().getChildAnalysis<AnalysisT>(child); 279 } 280 281 /// Returns the analysis for the given child operation of specific derived 282 /// operation type, or creates it if it doesn't exist. 283 template <typename AnalysisT, typename OpTy> getChildAnalysis(OpTy child)284 AnalysisT &getChildAnalysis(OpTy child) { 285 return getAnalysisManager().getChildAnalysis<AnalysisT>(child); 286 } 287 288 /// Returns the current analysis manager. getAnalysisManager()289 AnalysisManager getAnalysisManager() { 290 return getPassState().analysisManager; 291 } 292 293 /// Create a copy of this pass, ignoring statistics and options. 294 virtual std::unique_ptr<Pass> clonePass() const = 0; 295 296 /// Copy the option values from 'other', which is another instance of this 297 /// pass. 298 void copyOptionValuesFrom(const Pass *other); 299 300 private: 301 /// Out of line virtual method to ensure vtables and metadata are emitted to a 302 /// single .o file. 303 virtual void anchor(); 304 305 /// Represents a unique identifier for the pass. 306 TypeID passID; 307 308 /// The name of the operation that this pass operates on, or std::nullopt if 309 /// this is a generic OperationPass. 310 std::optional<StringRef> opName; 311 312 /// The current execution state for the pass. 313 std::optional<detail::PassExecutionState> passState; 314 315 /// The set of statistics held by this pass. 316 std::vector<Statistic *> statistics; 317 318 /// The pass options registered to this pass instance. 319 detail::PassOptions passOptions; 320 321 /// A pointer to the pass this pass was cloned from, if the clone was made by 322 /// the pass manager for the sake of multi-threading. 323 const Pass *threadingSibling = nullptr; 324 325 /// Allow access to 'clone'. 326 friend class OpPassManager; 327 328 /// Allow access to 'canScheduleOn'. 329 friend detail::OpPassManagerImpl; 330 331 /// Allow access to 'passState'. 332 friend detail::OpToOpPassAdaptor; 333 334 /// Allow access to 'passOptions'. 335 friend class PassInfo; 336 }; 337 338 //===----------------------------------------------------------------------===// 339 // Pass Model Definitions 340 //===----------------------------------------------------------------------===// 341 342 /// Pass to transform an operation of a specific type. 343 /// 344 /// Operation passes must not: 345 /// - modify any other operations within the parent region, as other threads 346 /// may be manipulating them concurrently. 347 /// - modify any state within the parent operation, this includes adding 348 /// additional operations. 349 /// 350 /// Derived operation passes are expected to provide the following: 351 /// - A 'void runOnOperation()' method. 352 /// - A 'StringRef getName() const' method. 353 /// - A 'std::unique_ptr<Pass> clonePass() const' method. 354 template <typename OpT = void> 355 class OperationPass : public Pass { 356 public: 357 ~OperationPass() override = default; 358 359 protected: OperationPass(TypeID passID)360 OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {} 361 OperationPass(const OperationPass &) = default; 362 OperationPass &operator=(const OperationPass &) = delete; 363 OperationPass(OperationPass &&) = delete; 364 OperationPass &operator=(OperationPass &&) = delete; 365 366 /// Support isa/dyn_cast functionality. classof(const Pass * pass)367 static bool classof(const Pass *pass) { 368 return pass->getOpName() == OpT::getOperationName(); 369 } 370 371 /// Indicate if the current pass can be scheduled on the given operation type. canScheduleOn(RegisteredOperationName opName)372 bool canScheduleOn(RegisteredOperationName opName) const final { 373 return opName.getStringRef() == getOpName(); 374 } 375 376 /// Return the current operation being transformed. getOperation()377 OpT getOperation() { return cast<OpT>(Pass::getOperation()); } 378 379 /// Query an analysis for the current operation of the specific derived 380 /// operation type. 381 template <typename AnalysisT> getAnalysis()382 AnalysisT &getAnalysis() { 383 return Pass::getAnalysis<AnalysisT, OpT>(); 384 } 385 }; 386 387 /// Pass to transform an operation. 388 /// 389 /// Operation passes must not: 390 /// - modify any other operations within the parent region, as other threads 391 /// may be manipulating them concurrently. 392 /// - modify any state within the parent operation, this includes adding 393 /// additional operations. 394 /// 395 /// Derived operation passes are expected to provide the following: 396 /// - A 'void runOnOperation()' method. 397 /// - A 'StringRef getName() const' method. 398 /// - A 'std::unique_ptr<Pass> clonePass() const' method. 399 template <> 400 class OperationPass<void> : public Pass { 401 public: 402 ~OperationPass() override = default; 403 404 protected: OperationPass(TypeID passID)405 OperationPass(TypeID passID) : Pass(passID) {} 406 OperationPass(const OperationPass &) = default; 407 OperationPass &operator=(const OperationPass &) = delete; 408 OperationPass(OperationPass &&) = delete; 409 OperationPass &operator=(OperationPass &&) = delete; 410 411 /// Indicate if the current pass can be scheduled on the given operation type. 412 /// By default, generic operation passes can be scheduled on any operation. canScheduleOn(RegisteredOperationName opName)413 bool canScheduleOn(RegisteredOperationName opName) const override { 414 return true; 415 } 416 }; 417 418 /// Pass to transform an operation that implements the given interface. 419 /// 420 /// Interface passes must not: 421 /// - modify any other operations within the parent region, as other threads 422 /// may be manipulating them concurrently. 423 /// - modify any state within the parent operation, this includes adding 424 /// additional operations. 425 /// 426 /// Derived interface passes are expected to provide the following: 427 /// - A 'void runOnOperation()' method. 428 /// - A 'StringRef getName() const' method. 429 /// - A 'std::unique_ptr<Pass> clonePass() const' method. 430 template <typename InterfaceT> 431 class InterfacePass : public OperationPass<> { 432 protected: 433 using OperationPass::OperationPass; 434 435 /// Indicate if the current pass can be scheduled on the given operation type. 436 /// For an InterfacePass, this checks if the operation implements the given 437 /// interface. canScheduleOn(RegisteredOperationName opName)438 bool canScheduleOn(RegisteredOperationName opName) const final { 439 return opName.hasInterface<InterfaceT>(); 440 } 441 442 /// Return the current operation being transformed. getOperation()443 InterfaceT getOperation() { return cast<InterfaceT>(Pass::getOperation()); } 444 445 /// Query an analysis for the current operation. 446 template <typename AnalysisT> getAnalysis()447 AnalysisT &getAnalysis() { 448 return Pass::getAnalysis<AnalysisT, InterfaceT>(); 449 } 450 }; 451 452 /// This class provides a CRTP wrapper around a base pass class to define 453 /// several necessary utility methods. This should only be used for passes that 454 /// are not suitably represented using the declarative pass specification(i.e. 455 /// tablegen backend). 456 template <typename PassT, typename BaseT> 457 class PassWrapper : public BaseT { 458 public: 459 /// Support isa/dyn_cast functionality for the derived pass class. classof(const Pass * pass)460 static bool classof(const Pass *pass) { 461 return pass->getTypeID() == TypeID::get<PassT>(); 462 } 463 ~PassWrapper() override = default; 464 465 protected: PassWrapper()466 PassWrapper() : BaseT(TypeID::get<PassT>()) {} 467 PassWrapper(const PassWrapper &) = default; 468 PassWrapper &operator=(const PassWrapper &) = delete; 469 PassWrapper(PassWrapper &&) = delete; 470 PassWrapper &operator=(PassWrapper &&) = delete; 471 472 /// Returns the derived pass name. getName()473 StringRef getName() const override { return llvm::getTypeName<PassT>(); } 474 475 /// A clone method to create a copy of this pass. clonePass()476 std::unique_ptr<Pass> clonePass() const override { 477 return std::make_unique<PassT>(*static_cast<const PassT *>(this)); 478 } 479 }; 480 481 /// This class encapsulates the "action" of executing a single pass. This allows 482 /// a user of the Action infrastructure to query information about an action in 483 /// (for example) a breakpoint context. You could use it like this: 484 /// 485 /// auto onBreakpoint = [&](const ActionActiveStack *backtrace) { 486 /// if (auto passExec = dyn_cast<PassExecutionAction>(anAction)) 487 /// record(passExec.getPass()); 488 /// return ExecutionContext::Apply; 489 /// }; 490 /// ExecutionContext exeCtx(onBreakpoint); 491 /// 492 class PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> { 493 using Base = tracing::ActionImpl<PassExecutionAction>; 494 495 public: 496 /// Define a TypeID for this PassExecutionAction. 497 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PassExecutionAction) 498 /// Construct a PassExecutionAction. This is called by the OpToOpPassAdaptor 499 /// when it calls `executeAction`. 500 PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass); 501 502 /// The tag required by ActionImpl to identify this action. 503 static constexpr StringLiteral tag = "pass-execution"; 504 505 /// Print a textual version of this action to `os`. 506 void print(raw_ostream &os) const override; 507 508 /// Get the pass that will be executed by this action. This is not a class of 509 /// passes, or all instances of a pass kind, this is a single pass. getPass()510 const Pass &getPass() const { return pass; } 511 512 /// Get the operation that is the base of this pass. For example, an 513 /// OperationPass<ModuleOp> would return a ModuleOp. 514 Operation *getOp() const; 515 516 public: 517 /// Reference to the pass being run. Notice that this will *not* extend the 518 /// lifetime of the pass, and so this class is therefore unsafe to keep past 519 /// the lifetime of the `executeAction` call. 520 const Pass &pass; 521 522 /// The base op for this pass. For an OperationPass<ModuleOp>, we would have a 523 /// ModuleOp here. 524 Operation *op; 525 }; 526 527 } // namespace mlir 528 529 #endif // MLIR_PASS_PASS_H 530