xref: /llvm-project/mlir/lib/Bindings/Python/IRModule.h (revision acde3f722ff3766f6f793884108d342b78623fe4)
1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
8 //===----------------------------------------------------------------------===//
9 
10 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
11 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
12 
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>

#include "Globals.h"
#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseMap.h"
28 
29 namespace mlir {
30 namespace python {
31 
// Forward declarations so that the declarations below can refer to these
// wrapper classes before their full definitions appear.
class DefaultingPyLocation;
class DefaultingPyMlirContext;
class PyBlock;
class PyDiagnostic;
class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
class PyMlirContext;
class PyModule;
class PyOperation;
class PyOperationBase;
class PySymbolTable;
class PyType;
class PyValue;
46 
47 /// Template for a reference to a concrete type which captures a python
48 /// reference to its underlying python object.
49 template <typename T>
50 class PyObjectRef {
51 public:
52   PyObjectRef(T *referrent, nanobind::object object)
53       : referrent(referrent), object(std::move(object)) {
54     assert(this->referrent &&
55            "cannot construct PyObjectRef with null referrent");
56     assert(this->object && "cannot construct PyObjectRef with null object");
57   }
58   PyObjectRef(PyObjectRef &&other) noexcept
59       : referrent(other.referrent), object(std::move(other.object)) {
60     other.referrent = nullptr;
61     assert(!other.object);
62   }
63   PyObjectRef(const PyObjectRef &other)
64       : referrent(other.referrent), object(other.object /* copies */) {}
65   ~PyObjectRef() = default;
66 
67   int getRefCount() {
68     if (!object)
69       return 0;
70     return Py_REFCNT(object.ptr());
71   }
72 
73   /// Releases the object held by this instance, returning it.
74   /// This is the proper thing to return from a function that wants to return
75   /// the reference. Note that this does not work from initializers.
76   nanobind::object releaseObject() {
77     assert(referrent && object);
78     referrent = nullptr;
79     auto stolen = std::move(object);
80     return stolen;
81   }
82 
83   T *get() { return referrent; }
84   T *operator->() {
85     assert(referrent && object);
86     return referrent;
87   }
88   nanobind::object getObject() {
89     assert(referrent && object);
90     return object;
91   }
92   operator bool() const { return referrent && object; }
93 
94 private:
95   T *referrent;
96   nanobind::object object;
97 };
98 
99 /// Tracks an entry in the thread context stack. New entries are pushed onto
100 /// here for each with block that activates a new InsertionPoint, Context or
101 /// Location.
102 ///
103 /// Pushing either a Location or InsertionPoint also pushes its associated
104 /// Context. Pushing a Context will not modify the Location or InsertionPoint
105 /// unless if they are from a different context, in which case, they are
106 /// cleared.
107 class PyThreadContextEntry {
108 public:
109   enum class FrameKind {
110     Context,
111     InsertionPoint,
112     Location,
113   };
114 
115   PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
116                        nanobind::object insertionPoint,
117                        nanobind::object location)
118       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
119         location(std::move(location)), frameKind(frameKind) {}
120 
121   /// Gets the top of stack context and return nullptr if not defined.
122   static PyMlirContext *getDefaultContext();
123 
124   /// Gets the top of stack insertion point and return nullptr if not defined.
125   static PyInsertionPoint *getDefaultInsertionPoint();
126 
127   /// Gets the top of stack location and returns nullptr if not defined.
128   static PyLocation *getDefaultLocation();
129 
130   PyMlirContext *getContext();
131   PyInsertionPoint *getInsertionPoint();
132   PyLocation *getLocation();
133   FrameKind getFrameKind() { return frameKind; }
134 
135   /// Stack management.
136   static PyThreadContextEntry *getTopOfStack();
137   static nanobind::object pushContext(nanobind::object context);
138   static void popContext(PyMlirContext &context);
139   static nanobind::object pushInsertionPoint(nanobind::object insertionPoint);
140   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
141   static nanobind::object pushLocation(nanobind::object location);
142   static void popLocation(PyLocation &location);
143 
144   /// Gets the thread local stack.
145   static std::vector<PyThreadContextEntry> &getStack();
146 
147 private:
148   static void push(FrameKind frameKind, nanobind::object context,
149                    nanobind::object insertionPoint, nanobind::object location);
150 
151   /// An object reference to the PyContext.
152   nanobind::object context;
153   /// An object reference to the current insertion point.
154   nanobind::object insertionPoint;
155   /// An object reference to the current location.
156   nanobind::object location;
157   // The kind of push that was performed.
158   FrameKind frameKind;
159 };
160 
161 /// Wrapper around MlirContext.
162 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
163 class PyMlirContext {
164 public:
165   PyMlirContext() = delete;
166   PyMlirContext(MlirContext context);
167   PyMlirContext(const PyMlirContext &) = delete;
168   PyMlirContext(PyMlirContext &&) = delete;
169 
170   /// For the case of a python __init__ (nanobind::init) method, pybind11 is
171   /// quite strict about needing to return a pointer that is not yet associated
172   /// to an nanobind::object. Since the forContext() method acts like a pool,
173   /// possibly returning a recycled context, it does not satisfy this need. The
174   /// usual way in python to accomplish such a thing is to override __new__, but
175   /// that is also not supported by pybind11. Instead, we use this entry
176   /// point which always constructs a fresh context (which cannot alias an
177   /// existing one because it is fresh).
178   static PyMlirContext *createNewContextForInit();
179 
180   /// Returns a context reference for the singleton PyMlirContext wrapper for
181   /// the given context.
182   static PyMlirContextRef forContext(MlirContext context);
183   ~PyMlirContext();
184 
185   /// Accesses the underlying MlirContext.
186   MlirContext get() { return context; }
187 
188   /// Gets a strong reference to this context, which will ensure it is kept
189   /// alive for the life of the reference.
190   PyMlirContextRef getRef() {
191     return PyMlirContextRef(this, nanobind::cast(this));
192   }
193 
194   /// Gets a capsule wrapping the void* within the MlirContext.
195   nanobind::object getCapsule();
196 
197   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
198   /// Note that PyMlirContext instances are uniqued, so the returned object
199   /// may be a pre-existing object. Ownership of the underlying MlirContext
200   /// is taken by calling this function.
201   static nanobind::object createFromCapsule(nanobind::object capsule);
202 
203   /// Gets the count of live context objects. Used for testing.
204   static size_t getLiveCount();
205 
206   /// Get a list of Python objects which are still in the live context map.
207   std::vector<PyOperation *> getLiveOperationObjects();
208 
209   /// Gets the count of live operations associated with this context.
210   /// Used for testing.
211   size_t getLiveOperationCount();
212 
213   /// Clears the live operations map, returning the number of entries which were
214   /// invalidated. To be used as a safety mechanism so that API end-users can't
215   /// corrupt by holding references they shouldn't have accessed in the first
216   /// place.
217   size_t clearLiveOperations();
218 
219   /// Removes an operation from the live operations map and sets it invalid.
220   /// This is useful for when some non-bindings code destroys the operation and
221   /// the bindings need to made aware. For example, in the case when pass
222   /// manager is run.
223   ///
224   /// Note that this does *NOT* clear the nested operations.
225   void clearOperation(MlirOperation op);
226 
227   /// Clears all operations nested inside the given op using
228   /// `clearOperation(MlirOperation)`.
229   void clearOperationsInside(PyOperationBase &op);
230   void clearOperationsInside(MlirOperation op);
231 
232   /// Clears the operaiton _and_ all operations inside using
233   /// `clearOperation(MlirOperation)`.
234   void clearOperationAndInside(PyOperationBase &op);
235 
236   /// Gets the count of live modules associated with this context.
237   /// Used for testing.
238   size_t getLiveModuleCount();
239 
240   /// Enter and exit the context manager.
241   static nanobind::object contextEnter(nanobind::object context);
242   void contextExit(const nanobind::object &excType,
243                    const nanobind::object &excVal,
244                    const nanobind::object &excTb);
245 
246   /// Attaches a Python callback as a diagnostic handler, returning a
247   /// registration object (internally a PyDiagnosticHandler).
248   nanobind::object attachDiagnosticHandler(nanobind::object callback);
249 
250   /// Controls whether error diagnostics should be propagated to diagnostic
251   /// handlers, instead of being captured by `ErrorCapture`.
252   void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; }
253   struct ErrorCapture;
254 
255 private:
256   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
257   // preserving the relationship that an MlirContext maps to a single
258   // PyMlirContext wrapper. This could be replaced in the future with an
259   // extension mechanism on the MlirContext for stashing user pointers.
260   // Note that this holds a handle, which does not imply ownership.
261   // Mappings will be removed when the context is destructed.
262   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
263   static nanobind::ft_mutex live_contexts_mutex;
264   static LiveContextMap &getLiveContexts();
265 
266   // Interns all live modules associated with this context. Modules tracked
267   // in this map are valid. When a module is invalidated, it is removed
268   // from this map, and while it still exists as an instance, any
269   // attempt to access it will raise an error.
270   using LiveModuleMap =
271       llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
272   LiveModuleMap liveModules;
273 
274   // Interns all live operations associated with this context. Operations
275   // tracked in this map are valid. When an operation is invalidated, it is
276   // removed from this map, and while it still exists as an instance, any
277   // attempt to access it will raise an error.
278   using LiveOperationMap =
279       llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
280   nanobind::ft_mutex liveOperationsMutex;
281 
282   // Guarded by liveOperationsMutex in free-threading mode.
283   LiveOperationMap liveOperations;
284 
285   bool emitErrorDiagnostics = false;
286 
287   MlirContext context;
288   friend class PyModule;
289   friend class PyOperation;
290 };
291 
292 /// Used in function arguments when None should resolve to the current context
293 /// manager set instance.
294 class DefaultingPyMlirContext
295     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
296 public:
297   using Defaulting::Defaulting;
298   static constexpr const char kTypeDescription[] = "mlir.ir.Context";
299   static PyMlirContext &resolve();
300 };
301 
302 /// Base class for all objects that directly or indirectly depend on an
303 /// MlirContext. The lifetime of the context will extend at least to the
304 /// lifetime of these instances.
305 /// Immutable objects that depend on a context extend this directly.
306 class BaseContextObject {
307 public:
308   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
309     assert(this->contextRef &&
310            "context object constructed with null context ref");
311   }
312 
313   /// Accesses the context reference.
314   PyMlirContextRef &getContext() { return contextRef; }
315 
316 private:
317   PyMlirContextRef contextRef;
318 };
319 
320 /// Wrapper around an MlirLocation.
321 class PyLocation : public BaseContextObject {
322 public:
323   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
324       : BaseContextObject(std::move(contextRef)), loc(loc) {}
325 
326   operator MlirLocation() const { return loc; }
327   MlirLocation get() const { return loc; }
328 
329   /// Enter and exit the context manager.
330   static nanobind::object contextEnter(nanobind::object location);
331   void contextExit(const nanobind::object &excType,
332                    const nanobind::object &excVal,
333                    const nanobind::object &excTb);
334 
335   /// Gets a capsule wrapping the void* within the MlirLocation.
336   nanobind::object getCapsule();
337 
338   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
339   /// Note that PyLocation instances are uniqued, so the returned object
340   /// may be a pre-existing object. Ownership of the underlying MlirLocation
341   /// is taken by calling this function.
342   static PyLocation createFromCapsule(nanobind::object capsule);
343 
344 private:
345   MlirLocation loc;
346 };
347 
348 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
349 /// are only valid for the duration of a diagnostic callback and attempting
350 /// to access them outside of that will raise an exception. This applies to
351 /// nested diagnostics (in the notes) as well.
352 class PyDiagnostic {
353 public:
354   PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
355   void invalidate();
356   bool isValid() { return valid; }
357   MlirDiagnosticSeverity getSeverity();
358   PyLocation getLocation();
359   nanobind::str getMessage();
360   nanobind::tuple getNotes();
361 
362   /// Materialized diagnostic information. This is safe to access outside the
363   /// diagnostic callback.
364   struct DiagnosticInfo {
365     MlirDiagnosticSeverity severity;
366     PyLocation location;
367     std::string message;
368     std::vector<DiagnosticInfo> notes;
369   };
370   DiagnosticInfo getInfo();
371 
372 private:
373   MlirDiagnostic diagnostic;
374 
375   void checkValid();
376   /// If notes have been materialized from the diagnostic, then this will
377   /// be populated with the corresponding objects (all castable to
378   /// PyDiagnostic).
379   std::optional<nanobind::tuple> materializedNotes;
380   bool valid = true;
381 };
382 
383 /// Represents a diagnostic handler attached to the context. The handler's
384 /// callback will be invoked with PyDiagnostic instances until the detach()
385 /// method is called or the context is destroyed. A diagnostic handler can be
386 /// the subject of a `with` block, which will detach it when the block exits.
387 ///
388 /// Since diagnostic handlers can call back into Python code which can do
389 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
390 /// etc), this is generally not deemed to be a great user-level API. Users
391 /// should generally use some form of DiagnosticCollector. If the handler raises
392 /// any exceptions, they will just be emitted to stderr and dropped.
393 ///
394 /// The unique usage of this class means that its lifetime management is
395 /// different from most other parts of the API. Instances are always created
396 /// in an attached state and can transition to a detached state by either:
397 ///   a) The context being destroyed and unregistering all handlers.
398 ///   b) An explicit call to detach().
399 /// The object may remain live from a Python perspective for an arbitrary time
400 /// after detachment, but there is nothing the user can do with it (since there
401 /// is no way to attach an existing handler object).
402 class PyDiagnosticHandler {
403 public:
404   PyDiagnosticHandler(MlirContext context, nanobind::object callback);
405   ~PyDiagnosticHandler();
406 
407   bool isAttached() { return registeredID.has_value(); }
408   bool getHadError() { return hadError; }
409 
410   /// Detaches the handler. Does nothing if not attached.
411   void detach();
412 
413   nanobind::object contextEnter() { return nanobind::cast(this); }
414   void contextExit(const nanobind::object &excType,
415                    const nanobind::object &excVal,
416                    const nanobind::object &excTb) {
417     detach();
418   }
419 
420 private:
421   MlirContext context;
422   nanobind::object callback;
423   std::optional<MlirDiagnosticHandlerID> registeredID;
424   bool hadError = false;
425   friend class PyMlirContext;
426 };
427 
428 /// RAII object that captures any error diagnostics emitted to the provided
429 /// context.
430 struct PyMlirContext::ErrorCapture {
431   ErrorCapture(PyMlirContextRef ctx)
432       : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
433                       ctx->get(), handler, /*userData=*/this,
434                       /*deleteUserData=*/nullptr)) {}
435   ~ErrorCapture() {
436     mlirContextDetachDiagnosticHandler(ctx->get(), handlerID);
437     assert(errors.empty() && "unhandled captured errors");
438   }
439 
440   std::vector<PyDiagnostic::DiagnosticInfo> take() {
441     return std::move(errors);
442   };
443 
444 private:
445   PyMlirContextRef ctx;
446   MlirDiagnosticHandlerID handlerID;
447   std::vector<PyDiagnostic::DiagnosticInfo> errors;
448 
449   static MlirLogicalResult handler(MlirDiagnostic diag, void *userData);
450 };
451 
452 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
453 /// order to differentiate it from the `Dialect` base class which is extended by
454 /// plugins which extend dialect functionality through extension python code.
455 /// This should be seen as the "low-level" object and `Dialect` as the
456 /// high-level, user facing object.
457 class PyDialectDescriptor : public BaseContextObject {
458 public:
459   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
460       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
461 
462   MlirDialect get() { return dialect; }
463 
464 private:
465   MlirDialect dialect;
466 };
467 
468 /// User-level object for accessing dialects with dotted syntax such as:
469 ///   ctx.dialect.std
470 class PyDialects : public BaseContextObject {
471 public:
472   PyDialects(PyMlirContextRef contextRef)
473       : BaseContextObject(std::move(contextRef)) {}
474 
475   MlirDialect getDialectForKey(const std::string &key, bool attrError);
476 };
477 
478 /// User-level dialect object. For dialects that have a registered extension,
479 /// this will be the base class of the extension dialect type. For un-extended,
480 /// objects of this type will be returned directly.
481 class PyDialect {
482 public:
483   PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
484 
485   nanobind::object getDescriptor() { return descriptor; }
486 
487 private:
488   nanobind::object descriptor;
489 };
490 
491 /// Wrapper around an MlirDialectRegistry.
492 /// Upon construction, the Python wrapper takes ownership of the
493 /// underlying MlirDialectRegistry.
494 class PyDialectRegistry {
495 public:
496   PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
497   PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}
498   ~PyDialectRegistry() {
499     if (!mlirDialectRegistryIsNull(registry))
500       mlirDialectRegistryDestroy(registry);
501   }
502   PyDialectRegistry(PyDialectRegistry &) = delete;
503   PyDialectRegistry(PyDialectRegistry &&other) noexcept
504       : registry(other.registry) {
505     other.registry = {nullptr};
506   }
507 
508   operator MlirDialectRegistry() const { return registry; }
509   MlirDialectRegistry get() const { return registry; }
510 
511   nanobind::object getCapsule();
512   static PyDialectRegistry createFromCapsule(nanobind::object capsule);
513 
514 private:
515   MlirDialectRegistry registry;
516 };
517 
518 /// Used in function arguments when None should resolve to the current context
519 /// manager set instance.
520 class DefaultingPyLocation
521     : public Defaulting<DefaultingPyLocation, PyLocation> {
522 public:
523   using Defaulting::Defaulting;
524   static constexpr const char kTypeDescription[] = "mlir.ir.Location";
525   static PyLocation &resolve();
526 
527   operator MlirLocation() const { return *get(); }
528 };
529 
530 /// Wrapper around MlirModule.
531 /// This is the top-level, user-owned object that contains regions/ops/blocks.
532 class PyModule;
533 using PyModuleRef = PyObjectRef<PyModule>;
534 class PyModule : public BaseContextObject {
535 public:
536   /// Returns a PyModule reference for the given MlirModule. This may return
537   /// a pre-existing or new object.
538   static PyModuleRef forModule(MlirModule module);
539   PyModule(PyModule &) = delete;
540   PyModule(PyMlirContext &&) = delete;
541   ~PyModule();
542 
543   /// Gets the backing MlirModule.
544   MlirModule get() { return module; }
545 
546   /// Gets a strong reference to this module.
547   PyModuleRef getRef() {
548     return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle));
549   }
550 
551   /// Gets a capsule wrapping the void* within the MlirModule.
552   /// Note that the module does not (yet) provide a corresponding factory for
553   /// constructing from a capsule as that would require uniquing PyModule
554   /// instances, which is not currently done.
555   nanobind::object getCapsule();
556 
557   /// Creates a PyModule from the MlirModule wrapped by a capsule.
558   /// Note that PyModule instances are uniqued, so the returned object
559   /// may be a pre-existing object. Ownership of the underlying MlirModule
560   /// is taken by calling this function.
561   static nanobind::object createFromCapsule(nanobind::object capsule);
562 
563 private:
564   PyModule(PyMlirContextRef contextRef, MlirModule module);
565   MlirModule module;
566   nanobind::handle handle;
567 };
568 
569 class PyAsmState;
570 
571 /// Base class for PyOperation and PyOpView which exposes the primary, user
572 /// visible methods for manipulating it.
573 class PyOperationBase {
574 public:
575   virtual ~PyOperationBase() = default;
576   /// Implements the bound 'print' method and helps with others.
577   void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
578              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
579              bool assumeVerified, nanobind::object fileObject, bool binary,
580              bool skipRegions);
581   void print(PyAsmState &state, nanobind::object fileObject, bool binary);
582 
583   nanobind::object getAsm(bool binary,
584                           std::optional<int64_t> largeElementsLimit,
585                           bool enableDebugInfo, bool prettyDebugInfo,
586                           bool printGenericOpForm, bool useLocalScope,
587                           bool assumeVerified, bool skipRegions);
588 
589   // Implement the bound 'writeBytecode' method.
590   void writeBytecode(const nanobind::object &fileObject,
591                      std::optional<int64_t> bytecodeVersion);
592 
593   // Implement the walk method.
594   void walk(std::function<MlirWalkResult(MlirOperation)> callback,
595             MlirWalkOrder walkOrder);
596 
597   /// Moves the operation before or after the other operation.
598   void moveAfter(PyOperationBase &other);
599   void moveBefore(PyOperationBase &other);
600 
601   /// Verify the operation. Throws `MLIRError` if verification fails, and
602   /// returns `true` otherwise.
603   bool verify();
604 
605   /// Each must provide access to the raw Operation.
606   virtual PyOperation &getOperation() = 0;
607 };
608 
609 /// Wrapper around PyOperation.
610 /// Operations exist in either an attached (dependent) or detached (top-level)
611 /// state. In the detached state (as on creation), an operation is owned by
612 /// the creator and its lifetime extends either until its reference count
613 /// drops to zero or it is attached to a parent, at which point its lifetime
614 /// is bounded by its top-level parent reference.
615 class PyOperation;
616 using PyOperationRef = PyObjectRef<PyOperation>;
617 class PyOperation : public PyOperationBase, public BaseContextObject {
618 public:
619   ~PyOperation() override;
620   PyOperation &getOperation() override { return *this; }
621 
622   /// Returns a PyOperation for the given MlirOperation, optionally associating
623   /// it with a parentKeepAlive.
624   static PyOperationRef
625   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
626                nanobind::object parentKeepAlive = nanobind::object());
627 
628   /// Creates a detached operation. The operation must not be associated with
629   /// any existing live operation.
630   static PyOperationRef
631   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
632                  nanobind::object parentKeepAlive = nanobind::object());
633 
634   /// Parses a source string (either text assembly or bytecode), creating a
635   /// detached operation.
636   static PyOperationRef parse(PyMlirContextRef contextRef,
637                               const std::string &sourceStr,
638                               const std::string &sourceName);
639 
640   /// Detaches the operation from its parent block and updates its state
641   /// accordingly.
642   void detachFromParent() {
643     mlirOperationRemoveFromParent(getOperation());
644     setDetached();
645     parentKeepAlive = nanobind::object();
646   }
647 
648   /// Gets the backing operation.
649   operator MlirOperation() const { return get(); }
650   MlirOperation get() const {
651     checkValid();
652     return operation;
653   }
654 
655   PyOperationRef getRef() {
656     return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
657   }
658 
659   bool isAttached() { return attached; }
660   void setAttached(const nanobind::object &parent = nanobind::object()) {
661     assert(!attached && "operation already attached");
662     attached = true;
663   }
664   void setDetached() {
665     assert(attached && "operation already detached");
666     attached = false;
667   }
668   void checkValid() const;
669 
670   /// Gets the owning block or raises an exception if the operation has no
671   /// owning block.
672   PyBlock getBlock();
673 
674   /// Gets the parent operation or raises an exception if the operation has
675   /// no parent.
676   std::optional<PyOperationRef> getParentOperation();
677 
678   /// Gets a capsule wrapping the void* within the MlirOperation.
679   nanobind::object getCapsule();
680 
681   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
682   /// Ownership of the underlying MlirOperation is taken by calling this
683   /// function.
684   static nanobind::object createFromCapsule(nanobind::object capsule);
685 
686   /// Creates an operation. See corresponding python docstring.
687   static nanobind::object
688   create(std::string_view name, std::optional<std::vector<PyType *>> results,
689          llvm::ArrayRef<MlirValue> operands,
690          std::optional<nanobind::dict> attributes,
691          std::optional<std::vector<PyBlock *>> successors, int regions,
692          DefaultingPyLocation location, const nanobind::object &ip,
693          bool inferType);
694 
695   /// Creates an OpView suitable for this operation.
696   nanobind::object createOpView();
697 
698   /// Erases the underlying MlirOperation, removes its pointer from the
699   /// parent context's live operations map, and sets the valid bit false.
700   void erase();
701 
702   /// Invalidate the operation.
703   void setInvalid() { valid = false; }
704 
705   /// Clones this operation.
706   nanobind::object clone(const nanobind::object &ip);
707 
708   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
709 
710 private:
711   static PyOperationRef createInstance(PyMlirContextRef contextRef,
712                                        MlirOperation operation,
713                                        nanobind::object parentKeepAlive);
714 
715   MlirOperation operation;
716   nanobind::handle handle;
717   // Keeps the parent alive, regardless of whether it is an Operation or
718   // Module.
719   // TODO: As implemented, this facility is only sufficient for modeling the
720   // trivial module parent back-reference. Generalize this to also account for
721   // transitions from detached to attached and address TODOs in the
722   // ir_operation.py regarding testing corresponding lifetime guarantees.
723   nanobind::object parentKeepAlive;
724   bool attached = true;
725   bool valid = true;
726 
727   friend class PyOperationBase;
728   friend class PySymbolTable;
729 };
730 
731 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
732 /// providing more instance-specific accessors and serve as the base class for
733 /// custom ODS-style operation classes. Since this class is subclass on the
734 /// python side, it must present an __init__ method that operates in pure
735 /// python types.
736 class PyOpView : public PyOperationBase {
737 public:
738   PyOpView(const nanobind::object &operationObject);
739   PyOperation &getOperation() override { return operation; }
740 
741   nanobind::object getOperationObject() { return operationObject; }
742 
743   static nanobind::object
744   buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
745                nanobind::object operandSegmentSpecObj,
746                nanobind::object resultSegmentSpecObj,
747                std::optional<nanobind::list> resultTypeList,
748                nanobind::list operandList,
749                std::optional<nanobind::dict> attributes,
750                std::optional<std::vector<PyBlock *>> successors,
751                std::optional<int> regions, DefaultingPyLocation location,
752                const nanobind::object &maybeIp);
753 
754   /// Construct an instance of a class deriving from OpView, bypassing its
755   /// `__init__` method. The derived class will typically define a constructor
756   /// that provides a convenient builder, but we need to side-step this when
757   /// constructing an `OpView` for an already-built operation.
758   ///
759   /// The caller is responsible for verifying that `operation` is a valid
760   /// operation to construct `cls` with.
761   static nanobind::object constructDerived(const nanobind::object &cls,
762                                            const nanobind::object &operation);
763 
764 private:
765   PyOperation &operation;           // For efficient, cast-free access from C++
766   nanobind::object operationObject; // Holds the reference.
767 };
768 
769 /// Wrapper around an MlirRegion.
770 /// Regions are managed completely by their containing operation. Unlike the
771 /// C++ API, the python API does not support detached regions.
772 class PyRegion {
773 public:
774   PyRegion(PyOperationRef parentOperation, MlirRegion region)
775       : parentOperation(std::move(parentOperation)), region(region) {
776     assert(!mlirRegionIsNull(region) && "python region cannot be null");
777   }
778   operator MlirRegion() const { return region; }
779 
780   MlirRegion get() { return region; }
781   PyOperationRef &getParentOperation() { return parentOperation; }
782 
783   void checkValid() { return parentOperation->checkValid(); }
784 
785 private:
786   PyOperationRef parentOperation;
787   MlirRegion region;
788 };
789 
790 /// Wrapper around an MlirAsmState.
791 class PyAsmState {
792 public:
793   PyAsmState(MlirValue value, bool useLocalScope) {
794     flags = mlirOpPrintingFlagsCreate();
795     // The OpPrintingFlags are not exposed Python side, create locally and
796     // associate lifetime with the state.
797     if (useLocalScope)
798       mlirOpPrintingFlagsUseLocalScope(flags);
799     state = mlirAsmStateCreateForValue(value, flags);
800   }
801 
802   PyAsmState(PyOperationBase &operation, bool useLocalScope) {
803     flags = mlirOpPrintingFlagsCreate();
804     // The OpPrintingFlags are not exposed Python side, create locally and
805     // associate lifetime with the state.
806     if (useLocalScope)
807       mlirOpPrintingFlagsUseLocalScope(flags);
808     state =
809         mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
810   }
811   ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
812   // Delete copy constructors.
813   PyAsmState(PyAsmState &other) = delete;
814   PyAsmState(const PyAsmState &other) = delete;
815 
816   MlirAsmState get() { return state; }
817 
818 private:
819   MlirAsmState state;
820   MlirOpPrintingFlags flags;
821 };
822 
823 /// Wrapper around an MlirBlock.
824 /// Blocks are managed completely by their containing operation. Unlike the
825 /// C++ API, the python API does not support detached blocks.
826 class PyBlock {
827 public:
828   PyBlock(PyOperationRef parentOperation, MlirBlock block)
829       : parentOperation(std::move(parentOperation)), block(block) {
830     assert(!mlirBlockIsNull(block) && "python block cannot be null");
831   }
832 
833   MlirBlock get() { return block; }
834   PyOperationRef &getParentOperation() { return parentOperation; }
835 
836   void checkValid() { return parentOperation->checkValid(); }
837 
838   /// Gets a capsule wrapping the void* within the MlirBlock.
839   nanobind::object getCapsule();
840 
841 private:
842   PyOperationRef parentOperation;
843   MlirBlock block;
844 };
845 
846 /// An insertion point maintains a pointer to a Block and a reference operation.
847 /// Calls to insert() will insert a new operation before the
848 /// reference operation. If the reference operation is null, then appends to
849 /// the end of the block.
850 class PyInsertionPoint {
851 public:
852   /// Creates an insertion point positioned after the last operation in the
853   /// block, but still inside the block.
854   PyInsertionPoint(PyBlock &block);
855   /// Creates an insertion point positioned before a reference operation.
856   PyInsertionPoint(PyOperationBase &beforeOperationBase);
857 
858   /// Shortcut to create an insertion point at the beginning of the block.
859   static PyInsertionPoint atBlockBegin(PyBlock &block);
860   /// Shortcut to create an insertion point before the block terminator.
861   static PyInsertionPoint atBlockTerminator(PyBlock &block);
862 
863   /// Inserts an operation.
864   void insert(PyOperationBase &operationBase);
865 
866   /// Enter and exit the context manager.
867   static nanobind::object contextEnter(nanobind::object insertionPoint);
868   void contextExit(const nanobind::object &excType,
869                    const nanobind::object &excVal,
870                    const nanobind::object &excTb);
871 
872   PyBlock &getBlock() { return block; }
873   std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
874 
875 private:
876   // Trampoline constructor that avoids null initializing members while
877   // looking up parents.
878   PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation)
879       : refOperation(std::move(refOperation)), block(std::move(block)) {}
880 
881   std::optional<PyOperationRef> refOperation;
882   PyBlock block;
883 };
884 /// Wrapper around the generic MlirType.
885 /// The lifetime of a type is bound by the PyContext that created it.
886 class PyType : public BaseContextObject {
887 public:
888   PyType(PyMlirContextRef contextRef, MlirType type)
889       : BaseContextObject(std::move(contextRef)), type(type) {}
890   bool operator==(const PyType &other) const;
891   operator MlirType() const { return type; }
892   MlirType get() const { return type; }
893 
894   /// Gets a capsule wrapping the void* within the MlirType.
895   nanobind::object getCapsule();
896 
897   /// Creates a PyType from the MlirType wrapped by a capsule.
898   /// Note that PyType instances are uniqued, so the returned object
899   /// may be a pre-existing object. Ownership of the underlying MlirType
900   /// is taken by calling this function.
901   static PyType createFromCapsule(nanobind::object capsule);
902 
903 private:
904   MlirType type;
905 };
906 
907 /// A TypeID provides an efficient and unique identifier for a specific C++
908 /// type. This allows for a C++ type to be compared, hashed, and stored in an
909 /// opaque context. This class wraps around the generic MlirTypeID.
910 class PyTypeID {
911 public:
912   PyTypeID(MlirTypeID typeID) : typeID(typeID) {}
913   // Note, this tests whether the underlying TypeIDs are the same,
914   // not whether the wrapper MlirTypeIDs are the same, nor whether
915   // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
916   bool operator==(const PyTypeID &other) const;
917   operator MlirTypeID() const { return typeID; }
918   MlirTypeID get() { return typeID; }
919 
920   /// Gets a capsule wrapping the void* within the MlirTypeID.
921   nanobind::object getCapsule();
922 
923   /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
924   static PyTypeID createFromCapsule(nanobind::object capsule);
925 
926 private:
927   MlirTypeID typeID;
928 };
929 
930 /// CRTP base classes for Python types that subclass Type and should be
931 /// castable from it (i.e. via something like IntegerType(t)).
932 /// By default, type class hierarchies are one level deep (i.e. a
933 /// concrete type class extends PyType); however, intermediate python-visible
934 /// base classes can be modeled by specifying a BaseTy.
935 template <typename DerivedTy, typename BaseTy = PyType>
936 class PyConcreteType : public BaseTy {
937 public:
938   // Derived classes must define statics for:
939   //   IsAFunctionTy isaFunction
940   //   const char *pyClassName
941   using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
942   using IsAFunctionTy = bool (*)(MlirType);
943   using GetTypeIDFunctionTy = MlirTypeID (*)();
944   static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
945 
946   PyConcreteType() = default;
947   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
948       : BaseTy(std::move(contextRef), t) {}
949   PyConcreteType(PyType &orig)
950       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
951 
952   static MlirType castFrom(PyType &orig) {
953     if (!DerivedTy::isaFunction(orig)) {
954       auto origRepr =
955           nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
956       throw nanobind::value_error((llvm::Twine("Cannot cast type to ") +
957                                    DerivedTy::pyClassName + " (from " +
958                                    origRepr + ")")
959                                       .str()
960                                       .c_str());
961     }
962     return orig;
963   }
964 
965   static void bind(nanobind::module_ &m) {
966     auto cls = ClassTy(m, DerivedTy::pyClassName);
967     cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
968             nanobind::arg("cast_from_type"));
969     cls.def_static(
970         "isinstance",
971         [](PyType &otherType) -> bool {
972           return DerivedTy::isaFunction(otherType);
973         },
974         nanobind::arg("other"));
975     cls.def_prop_ro_static(
976         "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
977           if (DerivedTy::getTypeIdFunction)
978             return DerivedTy::getTypeIdFunction();
979           throw nanobind::attribute_error(
980               (DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
981                   .str()
982                   .c_str());
983         });
984     cls.def_prop_ro("typeid", [](PyType &self) {
985       return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
986     });
987     cls.def("__repr__", [](DerivedTy &self) {
988       PyPrintAccumulator printAccum;
989       printAccum.parts.append(DerivedTy::pyClassName);
990       printAccum.parts.append("(");
991       mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
992       printAccum.parts.append(")");
993       return printAccum.join();
994     });
995 
996     if (DerivedTy::getTypeIdFunction) {
997       PyGlobals::get().registerTypeCaster(
998           DerivedTy::getTypeIdFunction(),
999           nanobind::cast<nanobind::callable>(nanobind::cpp_function(
1000               [](PyType pyType) -> DerivedTy { return pyType; })));
1001     }
1002 
1003     DerivedTy::bindDerived(cls);
1004   }
1005 
1006   /// Implemented by derived classes to add methods to the Python subclass.
1007   static void bindDerived(ClassTy &m) {}
1008 };
1009 
1010 /// Wrapper around the generic MlirAttribute.
1011 /// The lifetime of a type is bound by the PyContext that created it.
1012 class PyAttribute : public BaseContextObject {
1013 public:
1014   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1015       : BaseContextObject(std::move(contextRef)), attr(attr) {}
1016   bool operator==(const PyAttribute &other) const;
1017   operator MlirAttribute() const { return attr; }
1018   MlirAttribute get() const { return attr; }
1019 
1020   /// Gets a capsule wrapping the void* within the MlirAttribute.
1021   nanobind::object getCapsule();
1022 
1023   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
1024   /// Note that PyAttribute instances are uniqued, so the returned object
1025   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
1026   /// is taken by calling this function.
1027   static PyAttribute createFromCapsule(nanobind::object capsule);
1028 
1029 private:
1030   MlirAttribute attr;
1031 };
1032 
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
class PyNamedAttribute {
public:
  /// Constructs a PyNamedAttr that retains an owned name. This should be
  /// used in any code that originates an MlirNamedAttribute from a python
  /// string.
  /// The lifetime of the PyNamedAttr must extend to the lifetime of the
  /// passed attribute.
  PyNamedAttribute(MlirAttribute attr, std::string ownedName);

  /// The C API named attribute. Its name points into the heap string held by
  /// `ownedName` (see the comment on that member), so it is only valid while
  /// this object is alive.
  MlirNamedAttribute namedAttr;

private:
  // Since the MlirNamedAttr contains an internal pointer to the actual
  // memory of the owned string, it must be heap allocated to remain valid.
  // Otherwise, strings that fit within the small string optimization threshold
  // will have their memory address change as the containing object is moved,
  // resulting in an invalid aliased pointer.
  std::unique_ptr<std::string> ownedName;
};
1055 
1056 /// CRTP base classes for Python attributes that subclass Attribute and should
1057 /// be castable from it (i.e. via something like StringAttr(attr)).
1058 /// By default, attribute class hierarchies are one level deep (i.e. a
1059 /// concrete attribute class extends PyAttribute); however, intermediate
1060 /// python-visible base classes can be modeled by specifying a BaseTy.
1061 template <typename DerivedTy, typename BaseTy = PyAttribute>
1062 class PyConcreteAttribute : public BaseTy {
1063 public:
1064   // Derived classes must define statics for:
1065   //   IsAFunctionTy isaFunction
1066   //   const char *pyClassName
1067   using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
1068   using IsAFunctionTy = bool (*)(MlirAttribute);
1069   using GetTypeIDFunctionTy = MlirTypeID (*)();
1070   static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
1071 
1072   PyConcreteAttribute() = default;
1073   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1074       : BaseTy(std::move(contextRef), attr) {}
1075   PyConcreteAttribute(PyAttribute &orig)
1076       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
1077 
1078   static MlirAttribute castFrom(PyAttribute &orig) {
1079     if (!DerivedTy::isaFunction(orig)) {
1080       auto origRepr =
1081           nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
1082       throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") +
1083                                    DerivedTy::pyClassName + " (from " +
1084                                    origRepr + ")")
1085                                       .str()
1086                                       .c_str());
1087     }
1088     return orig;
1089   }
1090 
1091   static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
1092     ClassTy cls;
1093     if (slots) {
1094       cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
1095     } else {
1096       cls = ClassTy(m, DerivedTy::pyClassName);
1097     }
1098     cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
1099             nanobind::arg("cast_from_attr"));
1100     cls.def_static(
1101         "isinstance",
1102         [](PyAttribute &otherAttr) -> bool {
1103           return DerivedTy::isaFunction(otherAttr);
1104         },
1105         nanobind::arg("other"));
1106     cls.def_prop_ro(
1107         "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
1108     cls.def_prop_ro_static(
1109         "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
1110           if (DerivedTy::getTypeIdFunction)
1111             return DerivedTy::getTypeIdFunction();
1112           throw nanobind::attribute_error(
1113               (DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
1114                   .str()
1115                   .c_str());
1116         });
1117     cls.def_prop_ro("typeid", [](PyAttribute &self) {
1118       return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
1119     });
1120     cls.def("__repr__", [](DerivedTy &self) {
1121       PyPrintAccumulator printAccum;
1122       printAccum.parts.append(DerivedTy::pyClassName);
1123       printAccum.parts.append("(");
1124       mlirAttributePrint(self, printAccum.getCallback(),
1125                          printAccum.getUserData());
1126       printAccum.parts.append(")");
1127       return printAccum.join();
1128     });
1129 
1130     if (DerivedTy::getTypeIdFunction) {
1131       PyGlobals::get().registerTypeCaster(
1132           DerivedTy::getTypeIdFunction(),
1133           nanobind::cast<nanobind::callable>(
1134               nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
1135                 return pyAttribute;
1136               })));
1137     }
1138 
1139     DerivedTy::bindDerived(cls);
1140   }
1141 
1142   /// Implemented by derived classes to add methods to the Python subclass.
1143   static void bindDerived(ClassTy &m) {}
1144 };
1145 
1146 /// Wrapper around the generic MlirValue.
1147 /// Values are managed completely by the operation that resulted in their
1148 /// definition. For op result value, this is the operation that defines the
1149 /// value. For block argument values, this is the operation that contains the
1150 /// block to which the value is an argument (blocks cannot be detached in Python
1151 /// bindings so such operation always exists).
1152 class PyValue {
1153 public:
1154   // The virtual here is "load bearing" in that it enables RTTI
1155   // for PyConcreteValue CRTP classes that support maybeDownCast.
1156   // See PyValue::maybeDownCast.
1157   virtual ~PyValue() = default;
1158   PyValue(PyOperationRef parentOperation, MlirValue value)
1159       : parentOperation(std::move(parentOperation)), value(value) {}
1160   operator MlirValue() const { return value; }
1161 
1162   MlirValue get() { return value; }
1163   PyOperationRef &getParentOperation() { return parentOperation; }
1164 
1165   void checkValid() { return parentOperation->checkValid(); }
1166 
1167   /// Gets a capsule wrapping the void* within the MlirValue.
1168   nanobind::object getCapsule();
1169 
1170   nanobind::object maybeDownCast();
1171 
1172   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
1173   /// the underlying MlirValue is still tied to the owning operation.
1174   static PyValue createFromCapsule(nanobind::object capsule);
1175 
1176 private:
1177   PyOperationRef parentOperation;
1178   MlirValue value;
1179 };
1180 
1181 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
1182 class PyAffineExpr : public BaseContextObject {
1183 public:
1184   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
1185       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
1186   bool operator==(const PyAffineExpr &other) const;
1187   operator MlirAffineExpr() const { return affineExpr; }
1188   MlirAffineExpr get() const { return affineExpr; }
1189 
1190   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
1191   nanobind::object getCapsule();
1192 
1193   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
1194   /// Note that PyAffineExpr instances are uniqued, so the returned object
1195   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
1196   /// is taken by calling this function.
1197   static PyAffineExpr createFromCapsule(nanobind::object capsule);
1198 
1199   PyAffineExpr add(const PyAffineExpr &other) const;
1200   PyAffineExpr mul(const PyAffineExpr &other) const;
1201   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
1202   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
1203   PyAffineExpr mod(const PyAffineExpr &other) const;
1204 
1205 private:
1206   MlirAffineExpr affineExpr;
1207 };
1208 
1209 class PyAffineMap : public BaseContextObject {
1210 public:
1211   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
1212       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
1213   bool operator==(const PyAffineMap &other) const;
1214   operator MlirAffineMap() const { return affineMap; }
1215   MlirAffineMap get() const { return affineMap; }
1216 
1217   /// Gets a capsule wrapping the void* within the MlirAffineMap.
1218   nanobind::object getCapsule();
1219 
1220   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
1221   /// Note that PyAffineMap instances are uniqued, so the returned object
1222   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
1223   /// is taken by calling this function.
1224   static PyAffineMap createFromCapsule(nanobind::object capsule);
1225 
1226 private:
1227   MlirAffineMap affineMap;
1228 };
1229 
1230 class PyIntegerSet : public BaseContextObject {
1231 public:
1232   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
1233       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
1234   bool operator==(const PyIntegerSet &other) const;
1235   operator MlirIntegerSet() const { return integerSet; }
1236   MlirIntegerSet get() const { return integerSet; }
1237 
1238   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
1239   nanobind::object getCapsule();
1240 
1241   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
1242   /// Note that PyIntegerSet instances may be uniqued, so the returned object
1243   /// may be a pre-existing object. Integer sets are owned by the context.
1244   static PyIntegerSet createFromCapsule(nanobind::object capsule);
1245 
1246 private:
1247   MlirIntegerSet integerSet;
1248 };
1249 
1250 /// Bindings for MLIR symbol tables.
1251 class PySymbolTable {
1252 public:
1253   /// Constructs a symbol table for the given operation.
1254   explicit PySymbolTable(PyOperationBase &operation);
1255 
1256   /// Destroys the symbol table.
1257   ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
1258 
1259   /// Returns the symbol (opview) with the given name, throws if there is no
1260   /// such symbol in the table.
1261   nanobind::object dunderGetItem(const std::string &name);
1262 
1263   /// Removes the given operation from the symbol table and erases it.
1264   void erase(PyOperationBase &symbol);
1265 
1266   /// Removes the operation with the given name from the symbol table and erases
1267   /// it, throws if there is no such symbol in the table.
1268   void dunderDel(const std::string &name);
1269 
1270   /// Inserts the given operation into the symbol table. The operation must have
1271   /// the symbol trait.
1272   MlirAttribute insert(PyOperationBase &symbol);
1273 
1274   /// Gets and sets the name of a symbol op.
1275   static MlirAttribute getSymbolName(PyOperationBase &symbol);
1276   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
1277 
1278   /// Gets and sets the visibility of a symbol op.
1279   static MlirAttribute getVisibility(PyOperationBase &symbol);
1280   static void setVisibility(PyOperationBase &symbol,
1281                             const std::string &visibility);
1282 
1283   /// Replaces all symbol uses within an operation. See the API
1284   /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
1285   static void replaceAllSymbolUses(const std::string &oldSymbol,
1286                                    const std::string &newSymbol,
1287                                    PyOperationBase &from);
1288 
1289   /// Walks all symbol tables under and including 'from'.
1290   static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
1291                                nanobind::object callback);
1292 
1293   /// Casts the bindings class into the C API structure.
1294   operator MlirSymbolTable() { return symbolTable; }
1295 
1296 private:
1297   PyOperationRef operation;
1298   MlirSymbolTable symbolTable;
1299 };
1300 
1301 /// Custom exception that allows access to error diagnostic information. This is
1302 /// converted to the `ir.MLIRError` python exception when thrown.
1303 struct MLIRError {
1304   MLIRError(llvm::Twine message,
1305             std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
1306       : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
1307   std::string message;
1308   std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
1309 };
1310 
// Module-population entry points, defined out of line. NOTE(review): from
// their names, each presumably adds one area of the IR bindings (affine
// constructs, attributes, core IR, interfaces, types) to the module `m`.
void populateIRAffine(nanobind::module_ &m);
void populateIRAttributes(nanobind::module_ &m);
void populateIRCore(nanobind::module_ &m);
void populateIRInterfaces(nanobind::module_ &m);
void populateIRTypes(nanobind::module_ &m);
1316 
1317 } // namespace python
1318 } // namespace mlir
1319 
namespace nanobind {
namespace detail {

// nanobind type casters for the "defaulting" wrapper types, so that bound
// functions can take DefaultingPyMlirContext / DefaultingPyLocation
// parameters. The conversion logic is inherited from MlirDefaultingCaster,
// which is defined elsewhere.
template <>
struct type_caster<mlir::python::DefaultingPyMlirContext>
    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
template <>
struct type_caster<mlir::python::DefaultingPyLocation>
    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};

} // namespace detail
} // namespace nanobind
1332 
1333 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
1334