1 //===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_MLIRCONTEXT_H 10 #define MLIR_IR_MLIRCONTEXT_H 11 12 #include "mlir/Support/LLVM.h" 13 #include "mlir/Support/TypeID.h" 14 #include "llvm/ADT/ArrayRef.h" 15 #include <functional> 16 #include <memory> 17 #include <vector> 18 19 namespace llvm { 20 class ThreadPoolInterface; 21 } // namespace llvm 22 23 namespace mlir { 24 namespace tracing { 25 class Action; 26 } 27 class DiagnosticEngine; 28 class Dialect; 29 class DialectRegistry; 30 class DynamicDialect; 31 class InFlightDiagnostic; 32 class Location; 33 class MLIRContextImpl; 34 class RegisteredOperationName; 35 class StorageUniquer; 36 class IRUnit; 37 38 /// MLIRContext is the top-level object for a collection of MLIR operations. It 39 /// holds immortal uniqued objects like types, and the tables used to unique 40 /// them. 41 /// 42 /// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with 43 /// a very generic name ("Context") and because it is uncommon for clients to 44 /// interact with it. 45 /// 46 /// The context wrap some multi-threading facilities, and in particular by 47 /// default it will implicitly create a thread pool. 48 /// This can be undesirable if multiple context exists at the same time or if a 49 /// process will be long-lived and create and destroy contexts. 50 /// To control better thread spawning, an externally owned ThreadPool can be 51 /// injected in the context. For example: 52 /// 53 /// llvm::DefaultThreadPool myThreadPool; 54 /// while (auto *request = nextCompilationRequests()) { 55 /// MLIRContext ctx(registry, MLIRContext::Threading::DISABLED); 56 /// ctx.setThreadPool(myThreadPool); 57 /// processRequest(request, cxt); 58 /// } 59 /// 60 class MLIRContext { 61 public: 62 enum class Threading { DISABLED, ENABLED }; 63 /// Create a new Context. 64 explicit MLIRContext(Threading multithreading = Threading::ENABLED); 65 explicit MLIRContext(const DialectRegistry ®istry, 66 Threading multithreading = Threading::ENABLED); 67 ~MLIRContext(); 68 69 /// Return information about all IR dialects loaded in the context. 70 std::vector<Dialect *> getLoadedDialects(); 71 72 /// Return the dialect registry associated with this context. 73 const DialectRegistry &getDialectRegistry(); 74 75 /// Append the contents of the given dialect registry to the registry 76 /// associated with this context. 77 void appendDialectRegistry(const DialectRegistry ®istry); 78 79 /// Return information about all available dialects in the registry in this 80 /// context. 81 std::vector<StringRef> getAvailableDialects(); 82 83 /// Get a registered IR dialect with the given namespace. If an exact match is 84 /// not found, then return nullptr. 85 Dialect *getLoadedDialect(StringRef name); 86 87 /// Get a registered IR dialect for the given derived dialect type. The 88 /// derived type must provide a static 'getDialectNamespace' method. 89 template <typename T> 90 T *getLoadedDialect() { 91 return static_cast<T *>(getLoadedDialect(T::getDialectNamespace())); 92 } 93 94 /// Get (or create) a dialect for the given derived dialect type. The derived 95 /// type must provide a static 'getDialectNamespace' method. 96 template <typename T> 97 T *getOrLoadDialect() { 98 return static_cast<T *>( 99 getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() { 100 std::unique_ptr<T> dialect(new T(this)); 101 return dialect; 102 })); 103 } 104 105 /// Load a dialect in the context. 106 template <typename Dialect> 107 void loadDialect() { 108 // Do not load the dialect if it is currently loading. This can happen if a 109 // dialect initializer triggers loading the same dialect recursively. 110 if (!isDialectLoading(Dialect::getDialectNamespace())) 111 getOrLoadDialect<Dialect>(); 112 } 113 114 /// Load a list dialects in the context. 115 template <typename Dialect, typename OtherDialect, typename... MoreDialects> 116 void loadDialect() { 117 loadDialect<Dialect>(); 118 loadDialect<OtherDialect, MoreDialects...>(); 119 } 120 121 /// Get (or create) a dynamic dialect for the given name. 122 DynamicDialect * 123 getOrLoadDynamicDialect(StringRef dialectNamespace, 124 function_ref<void(DynamicDialect *)> ctor); 125 126 /// Load all dialects available in the registry in this context. 127 void loadAllAvailableDialects(); 128 129 /// Get (or create) a dialect for the given derived dialect name. 130 /// The dialect will be loaded from the registry if no dialect is found. 131 /// If no dialect is loaded for this name and none is available in the 132 /// registry, returns nullptr. 133 Dialect *getOrLoadDialect(StringRef name); 134 135 /// Return true if we allow to create operation for unregistered dialects. 136 [[nodiscard]] bool allowsUnregisteredDialects(); 137 138 /// Enables creating operations in unregistered dialects. 139 /// This option is **heavily discouraged**: it is convenient during testing 140 /// but it is not a good practice to use it in production code. Some system 141 /// invariants can be broken (like loading a dialect after creating 142 /// operations) without being caught by assertions or other means. 143 void allowUnregisteredDialects(bool allow = true); 144 145 /// Return true if multi-threading is enabled by the context. 146 bool isMultithreadingEnabled(); 147 148 /// Set the flag specifying if multi-threading is disabled by the context. 149 /// The command line debugging flag `--mlir-disable-threading` is overriding 150 /// this call and making it a no-op! 151 void disableMultithreading(bool disable = true); 152 void enableMultithreading(bool enable = true) { 153 disableMultithreading(!enable); 154 } 155 156 /// Set a new thread pool to be used in this context. This method requires 157 /// that multithreading is disabled for this context prior to the call. This 158 /// allows to share a thread pool across multiple contexts, as well as 159 /// decoupling the lifetime of the threads from the contexts. The thread pool 160 /// must outlive the context. Multi-threading will be enabled as part of this 161 /// method. 162 /// The command line debugging flag `--mlir-disable-threading` will still 163 /// prevent threading from being enabled and threading won't be enabled after 164 /// this call in this case. 165 void setThreadPool(llvm::ThreadPoolInterface &pool); 166 167 /// Return the number of threads used by the thread pool in this context. The 168 /// number of computed hardware threads can change over the lifetime of a 169 /// process based on affinity changes, so users should use the number of 170 /// threads actually in the thread pool for dispatching work. Returns 1 if 171 /// multithreading is disabled. 172 unsigned getNumThreads(); 173 174 /// Return the thread pool used by this context. This method requires that 175 /// multithreading be enabled within the context, and should generally not be 176 /// used directly. Users should instead prefer the threading utilities within 177 /// Threading.h. 178 llvm::ThreadPoolInterface &getThreadPool(); 179 180 /// Return true if we should attach the operation to diagnostics emitted via 181 /// Operation::emit. 182 bool shouldPrintOpOnDiagnostic(); 183 184 /// Set the flag specifying if we should attach the operation to diagnostics 185 /// emitted via Operation::emit. 186 void printOpOnDiagnostic(bool enable); 187 188 /// Return true if we should attach the current stacktrace to diagnostics when 189 /// emitted. 190 bool shouldPrintStackTraceOnDiagnostic(); 191 192 /// Set the flag specifying if we should attach the current stacktrace when 193 /// emitting diagnostics. 194 void printStackTraceOnDiagnostic(bool enable); 195 196 /// Return a sorted array containing the information about all registered 197 /// operations. 198 ArrayRef<RegisteredOperationName> getRegisteredOperations(); 199 200 /// Return a sorted array containing the information for registered operations 201 /// filtered by dialect name. 202 ArrayRef<RegisteredOperationName> 203 getRegisteredOperationsByDialect(StringRef dialectName); 204 205 /// Return true if this operation name is registered in this context. 206 bool isOperationRegistered(StringRef name); 207 208 // This is effectively private given that only MLIRContext.cpp can see the 209 // MLIRContextImpl type. 210 MLIRContextImpl &getImpl() { return *impl; } 211 212 /// Returns the diagnostic engine for this context. 213 DiagnosticEngine &getDiagEngine(); 214 215 /// Returns the storage uniquer used for creating affine constructs. 216 StorageUniquer &getAffineUniquer(); 217 218 /// Returns the storage uniquer used for constructing type storage instances. 219 /// This should not be used directly. 220 StorageUniquer &getTypeUniquer(); 221 222 /// Returns the storage uniquer used for constructing attribute storage 223 /// instances. This should not be used directly. 224 StorageUniquer &getAttributeUniquer(); 225 226 /// These APIs are tracking whether the context will be used in a 227 /// multithreading environment: this has no effect other than enabling 228 /// assertions on misuses of some APIs. 229 void enterMultiThreadedExecution(); 230 void exitMultiThreadedExecution(); 231 232 /// Get a dialect for the provided namespace and TypeID: abort the program if 233 /// a dialect exist for this namespace with different TypeID. If a dialect has 234 /// not been loaded for this namespace/TypeID yet, use the provided ctor to 235 /// create one on the fly and load it. Returns a pointer to the dialect owned 236 /// by the context. 237 /// The use of this method is in general discouraged in favor of 238 /// 'getOrLoadDialect<DialectClass>()'. 239 Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 240 function_ref<std::unique_ptr<Dialect>()> ctor); 241 242 /// Returns a hash of the registry of the context that may be used to give 243 /// a rough indicator of if the state of the context registry has changed. The 244 /// context registry correlates to loaded dialects and their entities 245 /// (attributes, operations, types, etc.). 246 llvm::hash_code getRegistryHash(); 247 248 //===--------------------------------------------------------------------===// 249 // Action API 250 //===--------------------------------------------------------------------===// 251 252 /// Signatures for the action handler that can be registered with the context. 253 using HandlerTy = 254 std::function<void(function_ref<void()>, const tracing::Action &)>; 255 256 /// Register a handler for handling actions that are dispatched through this 257 /// context. A nullptr handler can be set to disable a previously set handler. 258 void registerActionHandler(HandlerTy handler); 259 260 /// Return true if a valid ActionHandler is set. 261 bool hasActionHandler(); 262 263 /// Dispatch the provided action to the handler if any, or just execute it. 264 void executeAction(function_ref<void()> actionFn, 265 const tracing::Action &action) { 266 if (LLVM_UNLIKELY(hasActionHandler())) 267 executeActionInternal(actionFn, action); 268 else 269 actionFn(); 270 } 271 272 /// Dispatch the provided action to the handler if any, or just execute it. 273 template <typename ActionTy, typename... Args> 274 void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits, 275 Args &&...args) { 276 if (LLVM_UNLIKELY(hasActionHandler())) 277 executeActionInternal<ActionTy, Args...>(actionFn, irUnits, 278 std::forward<Args>(args)...); 279 else 280 actionFn(); 281 } 282 283 private: 284 /// Return true if the given dialect is currently loading. 285 bool isDialectLoading(StringRef dialectNamespace); 286 287 /// Internal helper for the dispatch method. 288 void executeActionInternal(function_ref<void()> actionFn, 289 const tracing::Action &action); 290 291 /// Internal helper for the dispatch method. We get here after checking that 292 /// there is a handler, for the purpose of keeping this code out-of-line. and 293 /// avoid calling the ctor for the Action unnecessarily. 294 template <typename ActionTy, typename... Args> 295 LLVM_ATTRIBUTE_NOINLINE void 296 executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits, 297 Args &&...args) { 298 executeActionInternal(actionFn, 299 ActionTy(irUnits, std::forward<Args>(args)...)); 300 } 301 302 const std::unique_ptr<MLIRContextImpl> impl; 303 304 MLIRContext(const MLIRContext &) = delete; 305 void operator=(const MLIRContext &) = delete; 306 }; 307 308 //===----------------------------------------------------------------------===// 309 // MLIRContext CommandLine Options 310 //===----------------------------------------------------------------------===// 311 312 /// Register a set of useful command-line options that can be used to configure 313 /// various flags within the MLIRContext. These flags are used when constructing 314 /// an MLIR context for initialization. 315 void registerMLIRContextCLOptions(); 316 317 } // namespace mlir 318 319 #endif // MLIR_IR_MLIRCONTEXT_H 320