xref: /llvm-project/mlir/include/mlir/IR/MLIRContext.h (revision b091701d0190912578ac3fe91ee8fd29e9b6de6e)
1 //===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_MLIRCONTEXT_H
10 #define MLIR_IR_MLIRCONTEXT_H
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/TypeID.h"
14 #include "llvm/ADT/ArrayRef.h"
15 #include <functional>
16 #include <memory>
17 #include <vector>
18 
// Only a forward declaration is needed: this header refers to the thread pool
// exclusively through references (MLIRContext::setThreadPool/getThreadPool).
namespace llvm {
class ThreadPoolInterface;
} // namespace llvm
22 
23 namespace mlir {
// Forward declarations of types named by the declarations in this header; the
// full definitions are not required here.
namespace tracing {
class Action;
} // namespace tracing
class DiagnosticEngine;
class Dialect;
class DialectRegistry;
class DynamicDialect;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class RegisteredOperationName;
class StorageUniquer;
class IRUnit;
37 
38 /// MLIRContext is the top-level object for a collection of MLIR operations. It
39 /// holds immortal uniqued objects like types, and the tables used to unique
40 /// them.
41 ///
42 /// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
43 /// a very generic name ("Context") and because it is uncommon for clients to
44 /// interact with it.
45 ///
46 /// The context wrap some multi-threading facilities, and in particular by
47 /// default it will implicitly create a thread pool.
48 /// This can be undesirable if multiple context exists at the same time or if a
49 /// process will be long-lived and create and destroy contexts.
50 /// To control better thread spawning, an externally owned ThreadPool can be
51 /// injected in the context. For example:
52 ///
53 ///  llvm::DefaultThreadPool myThreadPool;
54 ///  while (auto *request = nextCompilationRequests()) {
55 ///    MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
56 ///    ctx.setThreadPool(myThreadPool);
57 ///    processRequest(request, cxt);
58 ///  }
59 ///
60 class MLIRContext {
61 public:
62   enum class Threading { DISABLED, ENABLED };
63   /// Create a new Context.
64   explicit MLIRContext(Threading multithreading = Threading::ENABLED);
65   explicit MLIRContext(const DialectRegistry &registry,
66                        Threading multithreading = Threading::ENABLED);
67   ~MLIRContext();
68 
69   /// Return information about all IR dialects loaded in the context.
70   std::vector<Dialect *> getLoadedDialects();
71 
72   /// Return the dialect registry associated with this context.
73   const DialectRegistry &getDialectRegistry();
74 
75   /// Append the contents of the given dialect registry to the registry
76   /// associated with this context.
77   void appendDialectRegistry(const DialectRegistry &registry);
78 
79   /// Return information about all available dialects in the registry in this
80   /// context.
81   std::vector<StringRef> getAvailableDialects();
82 
83   /// Get a registered IR dialect with the given namespace. If an exact match is
84   /// not found, then return nullptr.
85   Dialect *getLoadedDialect(StringRef name);
86 
87   /// Get a registered IR dialect for the given derived dialect type. The
88   /// derived type must provide a static 'getDialectNamespace' method.
89   template <typename T>
90   T *getLoadedDialect() {
91     return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
92   }
93 
94   /// Get (or create) a dialect for the given derived dialect type. The derived
95   /// type must provide a static 'getDialectNamespace' method.
96   template <typename T>
97   T *getOrLoadDialect() {
98     return static_cast<T *>(
99         getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
100           std::unique_ptr<T> dialect(new T(this));
101           return dialect;
102         }));
103   }
104 
105   /// Load a dialect in the context.
106   template <typename Dialect>
107   void loadDialect() {
108     // Do not load the dialect if it is currently loading. This can happen if a
109     // dialect initializer triggers loading the same dialect recursively.
110     if (!isDialectLoading(Dialect::getDialectNamespace()))
111       getOrLoadDialect<Dialect>();
112   }
113 
114   /// Load a list dialects in the context.
115   template <typename Dialect, typename OtherDialect, typename... MoreDialects>
116   void loadDialect() {
117     loadDialect<Dialect>();
118     loadDialect<OtherDialect, MoreDialects...>();
119   }
120 
121   /// Get (or create) a dynamic dialect for the given name.
122   DynamicDialect *
123   getOrLoadDynamicDialect(StringRef dialectNamespace,
124                           function_ref<void(DynamicDialect *)> ctor);
125 
126   /// Load all dialects available in the registry in this context.
127   void loadAllAvailableDialects();
128 
129   /// Get (or create) a dialect for the given derived dialect name.
130   /// The dialect will be loaded from the registry if no dialect is found.
131   /// If no dialect is loaded for this name and none is available in the
132   /// registry, returns nullptr.
133   Dialect *getOrLoadDialect(StringRef name);
134 
135   /// Return true if we allow to create operation for unregistered dialects.
136   [[nodiscard]] bool allowsUnregisteredDialects();
137 
138   /// Enables creating operations in unregistered dialects.
139   /// This option is **heavily discouraged**: it is convenient during testing
140   /// but it is not a good practice to use it in production code. Some system
141   /// invariants can be broken (like loading a dialect after creating
142   ///  operations) without being caught by assertions or other means.
143   void allowUnregisteredDialects(bool allow = true);
144 
145   /// Return true if multi-threading is enabled by the context.
146   bool isMultithreadingEnabled();
147 
148   /// Set the flag specifying if multi-threading is disabled by the context.
149   /// The command line debugging flag `--mlir-disable-threading` is overriding
150   /// this call and making it a no-op!
151   void disableMultithreading(bool disable = true);
152   void enableMultithreading(bool enable = true) {
153     disableMultithreading(!enable);
154   }
155 
156   /// Set a new thread pool to be used in this context. This method requires
157   /// that multithreading is disabled for this context prior to the call. This
158   /// allows to share a thread pool across multiple contexts, as well as
159   /// decoupling the lifetime of the threads from the contexts. The thread pool
160   /// must outlive the context. Multi-threading will be enabled as part of this
161   /// method.
162   /// The command line debugging flag `--mlir-disable-threading` will still
163   /// prevent threading from being enabled and threading won't be enabled after
164   /// this call in this case.
165   void setThreadPool(llvm::ThreadPoolInterface &pool);
166 
167   /// Return the number of threads used by the thread pool in this context. The
168   /// number of computed hardware threads can change over the lifetime of a
169   /// process based on affinity changes, so users should use the number of
170   /// threads actually in the thread pool for dispatching work. Returns 1 if
171   /// multithreading is disabled.
172   unsigned getNumThreads();
173 
174   /// Return the thread pool used by this context. This method requires that
175   /// multithreading be enabled within the context, and should generally not be
176   /// used directly. Users should instead prefer the threading utilities within
177   /// Threading.h.
178   llvm::ThreadPoolInterface &getThreadPool();
179 
180   /// Return true if we should attach the operation to diagnostics emitted via
181   /// Operation::emit.
182   bool shouldPrintOpOnDiagnostic();
183 
184   /// Set the flag specifying if we should attach the operation to diagnostics
185   /// emitted via Operation::emit.
186   void printOpOnDiagnostic(bool enable);
187 
188   /// Return true if we should attach the current stacktrace to diagnostics when
189   /// emitted.
190   bool shouldPrintStackTraceOnDiagnostic();
191 
192   /// Set the flag specifying if we should attach the current stacktrace when
193   /// emitting diagnostics.
194   void printStackTraceOnDiagnostic(bool enable);
195 
196   /// Return a sorted array containing the information about all registered
197   /// operations.
198   ArrayRef<RegisteredOperationName> getRegisteredOperations();
199 
200   /// Return a sorted array containing the information for registered operations
201   /// filtered by dialect name.
202   ArrayRef<RegisteredOperationName>
203   getRegisteredOperationsByDialect(StringRef dialectName);
204 
205   /// Return true if this operation name is registered in this context.
206   bool isOperationRegistered(StringRef name);
207 
208   // This is effectively private given that only MLIRContext.cpp can see the
209   // MLIRContextImpl type.
210   MLIRContextImpl &getImpl() { return *impl; }
211 
212   /// Returns the diagnostic engine for this context.
213   DiagnosticEngine &getDiagEngine();
214 
215   /// Returns the storage uniquer used for creating affine constructs.
216   StorageUniquer &getAffineUniquer();
217 
218   /// Returns the storage uniquer used for constructing type storage instances.
219   /// This should not be used directly.
220   StorageUniquer &getTypeUniquer();
221 
222   /// Returns the storage uniquer used for constructing attribute storage
223   /// instances. This should not be used directly.
224   StorageUniquer &getAttributeUniquer();
225 
226   /// These APIs are tracking whether the context will be used in a
227   /// multithreading environment: this has no effect other than enabling
228   /// assertions on misuses of some APIs.
229   void enterMultiThreadedExecution();
230   void exitMultiThreadedExecution();
231 
232   /// Get a dialect for the provided namespace and TypeID: abort the program if
233   /// a dialect exist for this namespace with different TypeID. If a dialect has
234   /// not been loaded for this namespace/TypeID yet, use the provided ctor to
235   /// create one on the fly and load it. Returns a pointer to the dialect owned
236   /// by the context.
237   /// The use of this method is in general discouraged in favor of
238   /// 'getOrLoadDialect<DialectClass>()'.
239   Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
240                             function_ref<std::unique_ptr<Dialect>()> ctor);
241 
242   /// Returns a hash of the registry of the context that may be used to give
243   /// a rough indicator of if the state of the context registry has changed. The
244   /// context registry correlates to loaded dialects and their entities
245   /// (attributes, operations, types, etc.).
246   llvm::hash_code getRegistryHash();
247 
248   //===--------------------------------------------------------------------===//
249   // Action API
250   //===--------------------------------------------------------------------===//
251 
252   /// Signatures for the action handler that can be registered with the context.
253   using HandlerTy =
254       std::function<void(function_ref<void()>, const tracing::Action &)>;
255 
256   /// Register a handler for handling actions that are dispatched through this
257   /// context. A nullptr handler can be set to disable a previously set handler.
258   void registerActionHandler(HandlerTy handler);
259 
260   /// Return true if a valid ActionHandler is set.
261   bool hasActionHandler();
262 
263   /// Dispatch the provided action to the handler if any, or just execute it.
264   void executeAction(function_ref<void()> actionFn,
265                      const tracing::Action &action) {
266     if (LLVM_UNLIKELY(hasActionHandler()))
267       executeActionInternal(actionFn, action);
268     else
269       actionFn();
270   }
271 
272   /// Dispatch the provided action to the handler if any, or just execute it.
273   template <typename ActionTy, typename... Args>
274   void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
275                      Args &&...args) {
276     if (LLVM_UNLIKELY(hasActionHandler()))
277       executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
278                                                std::forward<Args>(args)...);
279     else
280       actionFn();
281   }
282 
283 private:
284   /// Return true if the given dialect is currently loading.
285   bool isDialectLoading(StringRef dialectNamespace);
286 
287   /// Internal helper for the dispatch method.
288   void executeActionInternal(function_ref<void()> actionFn,
289                              const tracing::Action &action);
290 
291   /// Internal helper for the dispatch method. We get here after checking that
292   /// there is a handler, for the purpose of keeping this code out-of-line. and
293   /// avoid calling the ctor for the Action unnecessarily.
294   template <typename ActionTy, typename... Args>
295   LLVM_ATTRIBUTE_NOINLINE void
296   executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
297                         Args &&...args) {
298     executeActionInternal(actionFn,
299                           ActionTy(irUnits, std::forward<Args>(args)...));
300   }
301 
302   const std::unique_ptr<MLIRContextImpl> impl;
303 
304   MLIRContext(const MLIRContext &) = delete;
305   void operator=(const MLIRContext &) = delete;
306 };
307 
308 //===----------------------------------------------------------------------===//
309 // MLIRContext CommandLine Options
310 //===----------------------------------------------------------------------===//
311 
/// Registers command-line options that configure various MLIRContext flags.
/// The registered flags are consulted when an MLIRContext is constructed, to
/// initialize it.
void registerMLIRContextCLOptions();
316 
317 } // namespace mlir
318 
319 #endif // MLIR_IR_MLIRCONTEXT_H
320