1436c6c9cSStella Laurenzo //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7ea2e83afSAdrian Kuegel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 9436c6c9cSStella Laurenzo 10436c6c9cSStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 11436c6c9cSStella Laurenzo #define MLIR_BINDINGS_PYTHON_IRMODULES_H 12436c6c9cSStella Laurenzo 1337107e17Srkayaith #include <optional> 14e8d07395SMehdi Amini #include <utility> 15436c6c9cSStella Laurenzo #include <vector> 16436c6c9cSStella Laurenzo 17bfb1ba75Smax #include "Globals.h" 18b56d1ec6SPeter Hawkins #include "NanobindUtils.h" 19436c6c9cSStella Laurenzo #include "mlir-c/AffineExpr.h" 20436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h" 217ee25bc5SStella Laurenzo #include "mlir-c/Diagnostics.h" 22436c6c9cSStella Laurenzo #include "mlir-c/IR.h" 23436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h" 2418cf1cd9SJacques Pienaar #include "mlir-c/Transforms.h" 25b56d1ec6SPeter Hawkins #include "mlir/Bindings/Python/NanobindAdaptors.h" 265cd42747SPeter Hawkins #include "mlir/Bindings/Python/Nanobind.h" 27436c6c9cSStella Laurenzo #include "llvm/ADT/DenseMap.h" 28436c6c9cSStella Laurenzo 29436c6c9cSStella Laurenzo namespace mlir { 30436c6c9cSStella Laurenzo namespace python { 31436c6c9cSStella Laurenzo 32436c6c9cSStella Laurenzo class PyBlock; 337ee25bc5SStella Laurenzo class PyDiagnostic; 347ee25bc5SStella Laurenzo class PyDiagnosticHandler; 35436c6c9cSStella Laurenzo class PyInsertionPoint; 36436c6c9cSStella Laurenzo class PyLocation; 37436c6c9cSStella Laurenzo class DefaultingPyLocation; 38436c6c9cSStella Laurenzo class PyMlirContext; 39436c6c9cSStella Laurenzo class DefaultingPyMlirContext; 40436c6c9cSStella Laurenzo class PyModule; 41436c6c9cSStella Laurenzo class PyOperation; 42fa19ef7aSIngo Müller class PyOperationBase; 43436c6c9cSStella Laurenzo class PyType; 4430d61893SAlex Zinenko class PySymbolTable; 45436c6c9cSStella Laurenzo class PyValue; 46436c6c9cSStella Laurenzo 47436c6c9cSStella Laurenzo /// Template for a reference to a concrete type which captures a python 48436c6c9cSStella Laurenzo /// reference to its underlying python object. 49436c6c9cSStella Laurenzo template <typename T> 50436c6c9cSStella Laurenzo class PyObjectRef { 51436c6c9cSStella Laurenzo public: 52b56d1ec6SPeter Hawkins PyObjectRef(T *referrent, nanobind::object object) 53436c6c9cSStella Laurenzo : referrent(referrent), object(std::move(object)) { 54436c6c9cSStella Laurenzo assert(this->referrent && 55436c6c9cSStella Laurenzo "cannot construct PyObjectRef with null referrent"); 56436c6c9cSStella Laurenzo assert(this->object && "cannot construct PyObjectRef with null object"); 57436c6c9cSStella Laurenzo } 58ea2e83afSAdrian Kuegel PyObjectRef(PyObjectRef &&other) noexcept 59436c6c9cSStella Laurenzo : referrent(other.referrent), object(std::move(other.object)) { 60436c6c9cSStella Laurenzo other.referrent = nullptr; 61436c6c9cSStella Laurenzo assert(!other.object); 62436c6c9cSStella Laurenzo } 63436c6c9cSStella Laurenzo PyObjectRef(const PyObjectRef &other) 64436c6c9cSStella Laurenzo : referrent(other.referrent), object(other.object /* copies */) {} 659940dcfaSMehdi Amini ~PyObjectRef() = default; 66436c6c9cSStella Laurenzo 67436c6c9cSStella Laurenzo int getRefCount() { 68436c6c9cSStella Laurenzo if (!object) 69436c6c9cSStella Laurenzo return 0; 70b56d1ec6SPeter Hawkins return Py_REFCNT(object.ptr()); 71436c6c9cSStella Laurenzo } 72436c6c9cSStella Laurenzo 73436c6c9cSStella Laurenzo /// Releases the object held by this instance, returning it. 74436c6c9cSStella Laurenzo /// This is the proper thing to return from a function that wants to return 75436c6c9cSStella Laurenzo /// the reference. Note that this does not work from initializers. 76b56d1ec6SPeter Hawkins nanobind::object releaseObject() { 77436c6c9cSStella Laurenzo assert(referrent && object); 78436c6c9cSStella Laurenzo referrent = nullptr; 79436c6c9cSStella Laurenzo auto stolen = std::move(object); 80436c6c9cSStella Laurenzo return stolen; 81436c6c9cSStella Laurenzo } 82436c6c9cSStella Laurenzo 83436c6c9cSStella Laurenzo T *get() { return referrent; } 84436c6c9cSStella Laurenzo T *operator->() { 85436c6c9cSStella Laurenzo assert(referrent && object); 86436c6c9cSStella Laurenzo return referrent; 87436c6c9cSStella Laurenzo } 88b56d1ec6SPeter Hawkins nanobind::object getObject() { 89436c6c9cSStella Laurenzo assert(referrent && object); 90436c6c9cSStella Laurenzo return object; 91436c6c9cSStella Laurenzo } 92436c6c9cSStella Laurenzo operator bool() const { return referrent && object; } 93436c6c9cSStella Laurenzo 94436c6c9cSStella Laurenzo private: 95436c6c9cSStella Laurenzo T *referrent; 96b56d1ec6SPeter Hawkins nanobind::object object; 97436c6c9cSStella Laurenzo }; 98436c6c9cSStella Laurenzo 99436c6c9cSStella Laurenzo /// Tracks an entry in the thread context stack. New entries are pushed onto 100436c6c9cSStella Laurenzo /// here for each with block that activates a new InsertionPoint, Context or 101436c6c9cSStella Laurenzo /// Location. 102436c6c9cSStella Laurenzo /// 103436c6c9cSStella Laurenzo /// Pushing either a Location or InsertionPoint also pushes its associated 104436c6c9cSStella Laurenzo /// Context. Pushing a Context will not modify the Location or InsertionPoint 105436c6c9cSStella Laurenzo /// unless if they are from a different context, in which case, they are 106436c6c9cSStella Laurenzo /// cleared. 107436c6c9cSStella Laurenzo class PyThreadContextEntry { 108436c6c9cSStella Laurenzo public: 109436c6c9cSStella Laurenzo enum class FrameKind { 110436c6c9cSStella Laurenzo Context, 111436c6c9cSStella Laurenzo InsertionPoint, 112436c6c9cSStella Laurenzo Location, 113436c6c9cSStella Laurenzo }; 114436c6c9cSStella Laurenzo 115b56d1ec6SPeter Hawkins PyThreadContextEntry(FrameKind frameKind, nanobind::object context, 116b56d1ec6SPeter Hawkins nanobind::object insertionPoint, 117b56d1ec6SPeter Hawkins nanobind::object location) 118436c6c9cSStella Laurenzo : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 119436c6c9cSStella Laurenzo location(std::move(location)), frameKind(frameKind) {} 120436c6c9cSStella Laurenzo 121436c6c9cSStella Laurenzo /// Gets the top of stack context and return nullptr if not defined. 122436c6c9cSStella Laurenzo static PyMlirContext *getDefaultContext(); 123436c6c9cSStella Laurenzo 124436c6c9cSStella Laurenzo /// Gets the top of stack insertion point and return nullptr if not defined. 125436c6c9cSStella Laurenzo static PyInsertionPoint *getDefaultInsertionPoint(); 126436c6c9cSStella Laurenzo 127436c6c9cSStella Laurenzo /// Gets the top of stack location and returns nullptr if not defined. 128436c6c9cSStella Laurenzo static PyLocation *getDefaultLocation(); 129436c6c9cSStella Laurenzo 130436c6c9cSStella Laurenzo PyMlirContext *getContext(); 131436c6c9cSStella Laurenzo PyInsertionPoint *getInsertionPoint(); 132436c6c9cSStella Laurenzo PyLocation *getLocation(); 133436c6c9cSStella Laurenzo FrameKind getFrameKind() { return frameKind; } 134436c6c9cSStella Laurenzo 135436c6c9cSStella Laurenzo /// Stack management. 136436c6c9cSStella Laurenzo static PyThreadContextEntry *getTopOfStack(); 137b56d1ec6SPeter Hawkins static nanobind::object pushContext(nanobind::object context); 138436c6c9cSStella Laurenzo static void popContext(PyMlirContext &context); 139b56d1ec6SPeter Hawkins static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); 140436c6c9cSStella Laurenzo static void popInsertionPoint(PyInsertionPoint &insertionPoint); 141b56d1ec6SPeter Hawkins static nanobind::object pushLocation(nanobind::object location); 142436c6c9cSStella Laurenzo static void popLocation(PyLocation &location); 143436c6c9cSStella Laurenzo 144436c6c9cSStella Laurenzo /// Gets the thread local stack. 145436c6c9cSStella Laurenzo static std::vector<PyThreadContextEntry> &getStack(); 146436c6c9cSStella Laurenzo 147436c6c9cSStella Laurenzo private: 148b56d1ec6SPeter Hawkins static void push(FrameKind frameKind, nanobind::object context, 149b56d1ec6SPeter Hawkins nanobind::object insertionPoint, nanobind::object location); 150436c6c9cSStella Laurenzo 151436c6c9cSStella Laurenzo /// An object reference to the PyContext. 152b56d1ec6SPeter Hawkins nanobind::object context; 153436c6c9cSStella Laurenzo /// An object reference to the current insertion point. 154b56d1ec6SPeter Hawkins nanobind::object insertionPoint; 155436c6c9cSStella Laurenzo /// An object reference to the current location. 156b56d1ec6SPeter Hawkins nanobind::object location; 157436c6c9cSStella Laurenzo // The kind of push that was performed. 158436c6c9cSStella Laurenzo FrameKind frameKind; 159436c6c9cSStella Laurenzo }; 160436c6c9cSStella Laurenzo 161436c6c9cSStella Laurenzo /// Wrapper around MlirContext. 162436c6c9cSStella Laurenzo using PyMlirContextRef = PyObjectRef<PyMlirContext>; 163436c6c9cSStella Laurenzo class PyMlirContext { 164436c6c9cSStella Laurenzo public: 165436c6c9cSStella Laurenzo PyMlirContext() = delete; 166b56d1ec6SPeter Hawkins PyMlirContext(MlirContext context); 167436c6c9cSStella Laurenzo PyMlirContext(const PyMlirContext &) = delete; 168436c6c9cSStella Laurenzo PyMlirContext(PyMlirContext &&) = delete; 169436c6c9cSStella Laurenzo 170b56d1ec6SPeter Hawkins /// For the case of a python __init__ (nanobind::init) method, pybind11 is 171b56d1ec6SPeter Hawkins /// quite strict about needing to return a pointer that is not yet associated 172b56d1ec6SPeter Hawkins /// to an nanobind::object. Since the forContext() method acts like a pool, 173b56d1ec6SPeter Hawkins /// possibly returning a recycled context, it does not satisfy this need. The 174b56d1ec6SPeter Hawkins /// usual way in python to accomplish such a thing is to override __new__, but 175436c6c9cSStella Laurenzo /// that is also not supported by pybind11. Instead, we use this entry 176436c6c9cSStella Laurenzo /// point which always constructs a fresh context (which cannot alias an 177436c6c9cSStella Laurenzo /// existing one because it is fresh). 178436c6c9cSStella Laurenzo static PyMlirContext *createNewContextForInit(); 179436c6c9cSStella Laurenzo 180436c6c9cSStella Laurenzo /// Returns a context reference for the singleton PyMlirContext wrapper for 18178bd1246SAlex Zinenko /// the given context. 182436c6c9cSStella Laurenzo static PyMlirContextRef forContext(MlirContext context); 183436c6c9cSStella Laurenzo ~PyMlirContext(); 184436c6c9cSStella Laurenzo 185436c6c9cSStella Laurenzo /// Accesses the underlying MlirContext. 186436c6c9cSStella Laurenzo MlirContext get() { return context; } 187436c6c9cSStella Laurenzo 188436c6c9cSStella Laurenzo /// Gets a strong reference to this context, which will ensure it is kept 189436c6c9cSStella Laurenzo /// alive for the life of the reference. 190436c6c9cSStella Laurenzo PyMlirContextRef getRef() { 191b56d1ec6SPeter Hawkins return PyMlirContextRef(this, nanobind::cast(this)); 192436c6c9cSStella Laurenzo } 193436c6c9cSStella Laurenzo 194436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirContext. 195b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 196436c6c9cSStella Laurenzo 197436c6c9cSStella Laurenzo /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 198436c6c9cSStella Laurenzo /// Note that PyMlirContext instances are uniqued, so the returned object 199436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirContext 200436c6c9cSStella Laurenzo /// is taken by calling this function. 201b56d1ec6SPeter Hawkins static nanobind::object createFromCapsule(nanobind::object capsule); 202436c6c9cSStella Laurenzo 203436c6c9cSStella Laurenzo /// Gets the count of live context objects. Used for testing. 204436c6c9cSStella Laurenzo static size_t getLiveCount(); 205436c6c9cSStella Laurenzo 206d1fdb416SJohn Demme /// Get a list of Python objects which are still in the live context map. 207d1fdb416SJohn Demme std::vector<PyOperation *> getLiveOperationObjects(); 208d1fdb416SJohn Demme 209436c6c9cSStella Laurenzo /// Gets the count of live operations associated with this context. 210436c6c9cSStella Laurenzo /// Used for testing. 211436c6c9cSStella Laurenzo size_t getLiveOperationCount(); 212436c6c9cSStella Laurenzo 2136b0bed7eSJohn Demme /// Clears the live operations map, returning the number of entries which were 2146b0bed7eSJohn Demme /// invalidated. To be used as a safety mechanism so that API end-users can't 2156b0bed7eSJohn Demme /// corrupt by holding references they shouldn't have accessed in the first 2166b0bed7eSJohn Demme /// place. 2176b0bed7eSJohn Demme size_t clearLiveOperations(); 2186b0bed7eSJohn Demme 219fa19ef7aSIngo Müller /// Removes an operation from the live operations map and sets it invalid. 220fa19ef7aSIngo Müller /// This is useful for when some non-bindings code destroys the operation and 221fa19ef7aSIngo Müller /// the bindings need to made aware. For example, in the case when pass 222fa19ef7aSIngo Müller /// manager is run. 22367897d77SOleksandr "Alex" Zinenko /// 22467897d77SOleksandr "Alex" Zinenko /// Note that this does *NOT* clear the nested operations. 225fa19ef7aSIngo Müller void clearOperation(MlirOperation op); 226fa19ef7aSIngo Müller 227fa19ef7aSIngo Müller /// Clears all operations nested inside the given op using 228fa19ef7aSIngo Müller /// `clearOperation(MlirOperation)`. 229fa19ef7aSIngo Müller void clearOperationsInside(PyOperationBase &op); 23091f11611SOleksandr "Alex" Zinenko void clearOperationsInside(MlirOperation op); 231bdc3e6cbSMaksim Levental 23267897d77SOleksandr "Alex" Zinenko /// Clears the operaiton _and_ all operations inside using 23367897d77SOleksandr "Alex" Zinenko /// `clearOperation(MlirOperation)`. 23467897d77SOleksandr "Alex" Zinenko void clearOperationAndInside(PyOperationBase &op); 23567897d77SOleksandr "Alex" Zinenko 236436c6c9cSStella Laurenzo /// Gets the count of live modules associated with this context. 237436c6c9cSStella Laurenzo /// Used for testing. 238436c6c9cSStella Laurenzo size_t getLiveModuleCount(); 239436c6c9cSStella Laurenzo 240436c6c9cSStella Laurenzo /// Enter and exit the context manager. 241b56d1ec6SPeter Hawkins static nanobind::object contextEnter(nanobind::object context); 242b56d1ec6SPeter Hawkins void contextExit(const nanobind::object &excType, 243b56d1ec6SPeter Hawkins const nanobind::object &excVal, 244b56d1ec6SPeter Hawkins const nanobind::object &excTb); 245436c6c9cSStella Laurenzo 2467ee25bc5SStella Laurenzo /// Attaches a Python callback as a diagnostic handler, returning a 2477ee25bc5SStella Laurenzo /// registration object (internally a PyDiagnosticHandler). 248b56d1ec6SPeter Hawkins nanobind::object attachDiagnosticHandler(nanobind::object callback); 2497ee25bc5SStella Laurenzo 2503ea4c501SRahul Kayaith /// Controls whether error diagnostics should be propagated to diagnostic 2513ea4c501SRahul Kayaith /// handlers, instead of being captured by `ErrorCapture`. 2523ea4c501SRahul Kayaith void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } 2533ea4c501SRahul Kayaith struct ErrorCapture; 2543ea4c501SRahul Kayaith 255436c6c9cSStella Laurenzo private: 256436c6c9cSStella Laurenzo // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 257436c6c9cSStella Laurenzo // preserving the relationship that an MlirContext maps to a single 258436c6c9cSStella Laurenzo // PyMlirContext wrapper. This could be replaced in the future with an 259436c6c9cSStella Laurenzo // extension mechanism on the MlirContext for stashing user pointers. 260436c6c9cSStella Laurenzo // Note that this holds a handle, which does not imply ownership. 261436c6c9cSStella Laurenzo // Mappings will be removed when the context is destructed. 262436c6c9cSStella Laurenzo using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 263f136c800Svfdev static nanobind::ft_mutex live_contexts_mutex; 264436c6c9cSStella Laurenzo static LiveContextMap &getLiveContexts(); 265436c6c9cSStella Laurenzo 266436c6c9cSStella Laurenzo // Interns all live modules associated with this context. Modules tracked 267436c6c9cSStella Laurenzo // in this map are valid. When a module is invalidated, it is removed 268436c6c9cSStella Laurenzo // from this map, and while it still exists as an instance, any 269436c6c9cSStella Laurenzo // attempt to access it will raise an error. 270436c6c9cSStella Laurenzo using LiveModuleMap = 271b56d1ec6SPeter Hawkins llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>; 272436c6c9cSStella Laurenzo LiveModuleMap liveModules; 273436c6c9cSStella Laurenzo 274436c6c9cSStella Laurenzo // Interns all live operations associated with this context. Operations 275436c6c9cSStella Laurenzo // tracked in this map are valid. When an operation is invalidated, it is 276436c6c9cSStella Laurenzo // removed from this map, and while it still exists as an instance, any 277436c6c9cSStella Laurenzo // attempt to access it will raise an error. 278436c6c9cSStella Laurenzo using LiveOperationMap = 279b56d1ec6SPeter Hawkins llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>; 280e2c49a45SPeter Hawkins nanobind::ft_mutex liveOperationsMutex; 281e2c49a45SPeter Hawkins 282e2c49a45SPeter Hawkins // Guarded by liveOperationsMutex in free-threading mode. 283436c6c9cSStella Laurenzo LiveOperationMap liveOperations; 284436c6c9cSStella Laurenzo 2853ea4c501SRahul Kayaith bool emitErrorDiagnostics = false; 2863ea4c501SRahul Kayaith 287436c6c9cSStella Laurenzo MlirContext context; 288436c6c9cSStella Laurenzo friend class PyModule; 289436c6c9cSStella Laurenzo friend class PyOperation; 290436c6c9cSStella Laurenzo }; 291436c6c9cSStella Laurenzo 292436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context 293436c6c9cSStella Laurenzo /// manager set instance. 294436c6c9cSStella Laurenzo class DefaultingPyMlirContext 295436c6c9cSStella Laurenzo : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 296436c6c9cSStella Laurenzo public: 297436c6c9cSStella Laurenzo using Defaulting::Defaulting; 298a6e7d024SStella Laurenzo static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 299436c6c9cSStella Laurenzo static PyMlirContext &resolve(); 300436c6c9cSStella Laurenzo }; 301436c6c9cSStella Laurenzo 302436c6c9cSStella Laurenzo /// Base class for all objects that directly or indirectly depend on an 303436c6c9cSStella Laurenzo /// MlirContext. The lifetime of the context will extend at least to the 304436c6c9cSStella Laurenzo /// lifetime of these instances. 305436c6c9cSStella Laurenzo /// Immutable objects that depend on a context extend this directly. 306436c6c9cSStella Laurenzo class BaseContextObject { 307436c6c9cSStella Laurenzo public: 308436c6c9cSStella Laurenzo BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 309436c6c9cSStella Laurenzo assert(this->contextRef && 310436c6c9cSStella Laurenzo "context object constructed with null context ref"); 311436c6c9cSStella Laurenzo } 312436c6c9cSStella Laurenzo 313436c6c9cSStella Laurenzo /// Accesses the context reference. 314436c6c9cSStella Laurenzo PyMlirContextRef &getContext() { return contextRef; } 315436c6c9cSStella Laurenzo 316436c6c9cSStella Laurenzo private: 317436c6c9cSStella Laurenzo PyMlirContextRef contextRef; 318436c6c9cSStella Laurenzo }; 319436c6c9cSStella Laurenzo 3203ea4c501SRahul Kayaith /// Wrapper around an MlirLocation. 3213ea4c501SRahul Kayaith class PyLocation : public BaseContextObject { 3223ea4c501SRahul Kayaith public: 3233ea4c501SRahul Kayaith PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 3243ea4c501SRahul Kayaith : BaseContextObject(std::move(contextRef)), loc(loc) {} 3253ea4c501SRahul Kayaith 3263ea4c501SRahul Kayaith operator MlirLocation() const { return loc; } 3273ea4c501SRahul Kayaith MlirLocation get() const { return loc; } 3283ea4c501SRahul Kayaith 3293ea4c501SRahul Kayaith /// Enter and exit the context manager. 330b56d1ec6SPeter Hawkins static nanobind::object contextEnter(nanobind::object location); 331b56d1ec6SPeter Hawkins void contextExit(const nanobind::object &excType, 332b56d1ec6SPeter Hawkins const nanobind::object &excVal, 333b56d1ec6SPeter Hawkins const nanobind::object &excTb); 3343ea4c501SRahul Kayaith 3353ea4c501SRahul Kayaith /// Gets a capsule wrapping the void* within the MlirLocation. 336b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 3373ea4c501SRahul Kayaith 3383ea4c501SRahul Kayaith /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 3393ea4c501SRahul Kayaith /// Note that PyLocation instances are uniqued, so the returned object 3403ea4c501SRahul Kayaith /// may be a pre-existing object. Ownership of the underlying MlirLocation 3413ea4c501SRahul Kayaith /// is taken by calling this function. 342b56d1ec6SPeter Hawkins static PyLocation createFromCapsule(nanobind::object capsule); 3433ea4c501SRahul Kayaith 3443ea4c501SRahul Kayaith private: 3453ea4c501SRahul Kayaith MlirLocation loc; 3463ea4c501SRahul Kayaith }; 3473ea4c501SRahul Kayaith 3487ee25bc5SStella Laurenzo /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 3497ee25bc5SStella Laurenzo /// are only valid for the duration of a diagnostic callback and attempting 3507ee25bc5SStella Laurenzo /// to access them outside of that will raise an exception. This applies to 3517ee25bc5SStella Laurenzo /// nested diagnostics (in the notes) as well. 3527ee25bc5SStella Laurenzo class PyDiagnostic { 3537ee25bc5SStella Laurenzo public: 3547ee25bc5SStella Laurenzo PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 3557ee25bc5SStella Laurenzo void invalidate(); 3567ee25bc5SStella Laurenzo bool isValid() { return valid; } 3577ee25bc5SStella Laurenzo MlirDiagnosticSeverity getSeverity(); 3587ee25bc5SStella Laurenzo PyLocation getLocation(); 359b56d1ec6SPeter Hawkins nanobind::str getMessage(); 360b56d1ec6SPeter Hawkins nanobind::tuple getNotes(); 3617ee25bc5SStella Laurenzo 3623ea4c501SRahul Kayaith /// Materialized diagnostic information. This is safe to access outside the 3633ea4c501SRahul Kayaith /// diagnostic callback. 3643ea4c501SRahul Kayaith struct DiagnosticInfo { 3653ea4c501SRahul Kayaith MlirDiagnosticSeverity severity; 3663ea4c501SRahul Kayaith PyLocation location; 3673ea4c501SRahul Kayaith std::string message; 3683ea4c501SRahul Kayaith std::vector<DiagnosticInfo> notes; 3693ea4c501SRahul Kayaith }; 3703ea4c501SRahul Kayaith DiagnosticInfo getInfo(); 3713ea4c501SRahul Kayaith 3727ee25bc5SStella Laurenzo private: 3737ee25bc5SStella Laurenzo MlirDiagnostic diagnostic; 3747ee25bc5SStella Laurenzo 3757ee25bc5SStella Laurenzo void checkValid(); 3767ee25bc5SStella Laurenzo /// If notes have been materialized from the diagnostic, then this will 3777ee25bc5SStella Laurenzo /// be populated with the corresponding objects (all castable to 3787ee25bc5SStella Laurenzo /// PyDiagnostic). 379b56d1ec6SPeter Hawkins std::optional<nanobind::tuple> materializedNotes; 3807ee25bc5SStella Laurenzo bool valid = true; 3817ee25bc5SStella Laurenzo }; 3827ee25bc5SStella Laurenzo 3837ee25bc5SStella Laurenzo /// Represents a diagnostic handler attached to the context. The handler's 3847ee25bc5SStella Laurenzo /// callback will be invoked with PyDiagnostic instances until the detach() 3857ee25bc5SStella Laurenzo /// method is called or the context is destroyed. A diagnostic handler can be 3867ee25bc5SStella Laurenzo /// the subject of a `with` block, which will detach it when the block exits. 3877ee25bc5SStella Laurenzo /// 3887ee25bc5SStella Laurenzo /// Since diagnostic handlers can call back into Python code which can do 3897ee25bc5SStella Laurenzo /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 3907ee25bc5SStella Laurenzo /// etc), this is generally not deemed to be a great user-level API. Users 3917ee25bc5SStella Laurenzo /// should generally use some form of DiagnosticCollector. If the handler raises 3927ee25bc5SStella Laurenzo /// any exceptions, they will just be emitted to stderr and dropped. 3937ee25bc5SStella Laurenzo /// 3947ee25bc5SStella Laurenzo /// The unique usage of this class means that its lifetime management is 3957ee25bc5SStella Laurenzo /// different from most other parts of the API. Instances are always created 3967ee25bc5SStella Laurenzo /// in an attached state and can transition to a detached state by either: 3977ee25bc5SStella Laurenzo /// a) The context being destroyed and unregistering all handlers. 3987ee25bc5SStella Laurenzo /// b) An explicit call to detach(). 3997ee25bc5SStella Laurenzo /// The object may remain live from a Python perspective for an arbitrary time 4007ee25bc5SStella Laurenzo /// after detachment, but there is nothing the user can do with it (since there 4017ee25bc5SStella Laurenzo /// is no way to attach an existing handler object). 4027ee25bc5SStella Laurenzo class PyDiagnosticHandler { 4037ee25bc5SStella Laurenzo public: 404b56d1ec6SPeter Hawkins PyDiagnosticHandler(MlirContext context, nanobind::object callback); 4057ee25bc5SStella Laurenzo ~PyDiagnosticHandler(); 4067ee25bc5SStella Laurenzo 40710de5512SJacques Pienaar bool isAttached() { return registeredID.has_value(); } 4087ee25bc5SStella Laurenzo bool getHadError() { return hadError; } 4097ee25bc5SStella Laurenzo 4107ee25bc5SStella Laurenzo /// Detaches the handler. Does nothing if not attached. 4117ee25bc5SStella Laurenzo void detach(); 4127ee25bc5SStella Laurenzo 413b56d1ec6SPeter Hawkins nanobind::object contextEnter() { return nanobind::cast(this); } 414b56d1ec6SPeter Hawkins void contextExit(const nanobind::object &excType, 415b56d1ec6SPeter Hawkins const nanobind::object &excVal, 416b56d1ec6SPeter Hawkins const nanobind::object &excTb) { 4177ee25bc5SStella Laurenzo detach(); 4187ee25bc5SStella Laurenzo } 4197ee25bc5SStella Laurenzo 4207ee25bc5SStella Laurenzo private: 4217ee25bc5SStella Laurenzo MlirContext context; 422b56d1ec6SPeter Hawkins nanobind::object callback; 4230a81ace0SKazu Hirata std::optional<MlirDiagnosticHandlerID> registeredID; 4247ee25bc5SStella Laurenzo bool hadError = false; 4257ee25bc5SStella Laurenzo friend class PyMlirContext; 4267ee25bc5SStella Laurenzo }; 4277ee25bc5SStella Laurenzo 4283ea4c501SRahul Kayaith /// RAII object that captures any error diagnostics emitted to the provided 4293ea4c501SRahul Kayaith /// context. 4303ea4c501SRahul Kayaith struct PyMlirContext::ErrorCapture { 4313ea4c501SRahul Kayaith ErrorCapture(PyMlirContextRef ctx) 4323ea4c501SRahul Kayaith : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( 4333ea4c501SRahul Kayaith ctx->get(), handler, /*userData=*/this, 4343ea4c501SRahul Kayaith /*deleteUserData=*/nullptr)) {} 4353ea4c501SRahul Kayaith ~ErrorCapture() { 4363ea4c501SRahul Kayaith mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); 4373ea4c501SRahul Kayaith assert(errors.empty() && "unhandled captured errors"); 4383ea4c501SRahul Kayaith } 4393ea4c501SRahul Kayaith 4403ea4c501SRahul Kayaith std::vector<PyDiagnostic::DiagnosticInfo> take() { 4413ea4c501SRahul Kayaith return std::move(errors); 4423ea4c501SRahul Kayaith }; 4433ea4c501SRahul Kayaith 4443ea4c501SRahul Kayaith private: 4453ea4c501SRahul Kayaith PyMlirContextRef ctx; 4463ea4c501SRahul Kayaith MlirDiagnosticHandlerID handlerID; 4473ea4c501SRahul Kayaith std::vector<PyDiagnostic::DiagnosticInfo> errors; 4483ea4c501SRahul Kayaith 4493ea4c501SRahul Kayaith static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); 4503ea4c501SRahul Kayaith }; 4513ea4c501SRahul Kayaith 452436c6c9cSStella Laurenzo /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 453436c6c9cSStella Laurenzo /// order to differentiate it from the `Dialect` base class which is extended by 454436c6c9cSStella Laurenzo /// plugins which extend dialect functionality through extension python code. 455436c6c9cSStella Laurenzo /// This should be seen as the "low-level" object and `Dialect` as the 456436c6c9cSStella Laurenzo /// high-level, user facing object. 457436c6c9cSStella Laurenzo class PyDialectDescriptor : public BaseContextObject { 458436c6c9cSStella Laurenzo public: 459436c6c9cSStella Laurenzo PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 460436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 461436c6c9cSStella Laurenzo 462436c6c9cSStella Laurenzo MlirDialect get() { return dialect; } 463436c6c9cSStella Laurenzo 464436c6c9cSStella Laurenzo private: 465436c6c9cSStella Laurenzo MlirDialect dialect; 466436c6c9cSStella Laurenzo }; 467436c6c9cSStella Laurenzo 468436c6c9cSStella Laurenzo /// User-level object for accessing dialects with dotted syntax such as: 469436c6c9cSStella Laurenzo /// ctx.dialect.std 470436c6c9cSStella Laurenzo class PyDialects : public BaseContextObject { 471436c6c9cSStella Laurenzo public: 472436c6c9cSStella Laurenzo PyDialects(PyMlirContextRef contextRef) 473436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)) {} 474436c6c9cSStella Laurenzo 475436c6c9cSStella Laurenzo MlirDialect getDialectForKey(const std::string &key, bool attrError); 476436c6c9cSStella Laurenzo }; 477436c6c9cSStella Laurenzo 478436c6c9cSStella Laurenzo /// User-level dialect object. For dialects that have a registered extension, 479436c6c9cSStella Laurenzo /// this will be the base class of the extension dialect type. For un-extended, 480436c6c9cSStella Laurenzo /// objects of this type will be returned directly. 481436c6c9cSStella Laurenzo class PyDialect { 482436c6c9cSStella Laurenzo public: 483b56d1ec6SPeter Hawkins PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} 484436c6c9cSStella Laurenzo 485b56d1ec6SPeter Hawkins nanobind::object getDescriptor() { return descriptor; } 486436c6c9cSStella Laurenzo 487436c6c9cSStella Laurenzo private: 488b56d1ec6SPeter Hawkins nanobind::object descriptor; 489436c6c9cSStella Laurenzo }; 490436c6c9cSStella Laurenzo 4915e83a5b4SStella Laurenzo /// Wrapper around an MlirDialectRegistry. 4925e83a5b4SStella Laurenzo /// Upon construction, the Python wrapper takes ownership of the 4935e83a5b4SStella Laurenzo /// underlying MlirDialectRegistry. 4945e83a5b4SStella Laurenzo class PyDialectRegistry { 4955e83a5b4SStella Laurenzo public: 4965e83a5b4SStella Laurenzo PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} 4975e83a5b4SStella Laurenzo PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} 4985e83a5b4SStella Laurenzo ~PyDialectRegistry() { 4995e83a5b4SStella Laurenzo if (!mlirDialectRegistryIsNull(registry)) 5005e83a5b4SStella Laurenzo mlirDialectRegistryDestroy(registry); 5015e83a5b4SStella Laurenzo } 5025e83a5b4SStella Laurenzo PyDialectRegistry(PyDialectRegistry &) = delete; 503ea2e83afSAdrian Kuegel PyDialectRegistry(PyDialectRegistry &&other) noexcept 504ea2e83afSAdrian Kuegel : registry(other.registry) { 5055e83a5b4SStella Laurenzo other.registry = {nullptr}; 5065e83a5b4SStella Laurenzo } 5075e83a5b4SStella Laurenzo 5085e83a5b4SStella Laurenzo operator MlirDialectRegistry() const { return registry; } 5095e83a5b4SStella Laurenzo MlirDialectRegistry get() const { return registry; } 5105e83a5b4SStella Laurenzo 511b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 512b56d1ec6SPeter Hawkins static PyDialectRegistry createFromCapsule(nanobind::object capsule); 5135e83a5b4SStella Laurenzo 5145e83a5b4SStella Laurenzo private: 5155e83a5b4SStella Laurenzo MlirDialectRegistry registry; 5165e83a5b4SStella Laurenzo }; 5175e83a5b4SStella Laurenzo 518436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context 519436c6c9cSStella Laurenzo /// manager set instance. 520436c6c9cSStella Laurenzo class DefaultingPyLocation 521436c6c9cSStella Laurenzo : public Defaulting<DefaultingPyLocation, PyLocation> { 522436c6c9cSStella Laurenzo public: 523436c6c9cSStella Laurenzo using Defaulting::Defaulting; 524a6e7d024SStella Laurenzo static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 525436c6c9cSStella Laurenzo static PyLocation &resolve(); 526436c6c9cSStella Laurenzo 527436c6c9cSStella Laurenzo operator MlirLocation() const { return *get(); } 528436c6c9cSStella Laurenzo }; 529436c6c9cSStella Laurenzo 530436c6c9cSStella Laurenzo /// Wrapper around MlirModule. 531436c6c9cSStella Laurenzo /// This is the top-level, user-owned object that contains regions/ops/blocks. 532436c6c9cSStella Laurenzo class PyModule; 533436c6c9cSStella Laurenzo using PyModuleRef = PyObjectRef<PyModule>; 534436c6c9cSStella Laurenzo class PyModule : public BaseContextObject { 535436c6c9cSStella Laurenzo public: 536436c6c9cSStella Laurenzo /// Returns a PyModule reference for the given MlirModule. This may return 537436c6c9cSStella Laurenzo /// a pre-existing or new object. 538436c6c9cSStella Laurenzo static PyModuleRef forModule(MlirModule module); 539436c6c9cSStella Laurenzo PyModule(PyModule &) = delete; 540436c6c9cSStella Laurenzo PyModule(PyMlirContext &&) = delete; 541436c6c9cSStella Laurenzo ~PyModule(); 542436c6c9cSStella Laurenzo 543436c6c9cSStella Laurenzo /// Gets the backing MlirModule. 544436c6c9cSStella Laurenzo MlirModule get() { return module; } 545436c6c9cSStella Laurenzo 546436c6c9cSStella Laurenzo /// Gets a strong reference to this module. 547436c6c9cSStella Laurenzo PyModuleRef getRef() { 548b56d1ec6SPeter Hawkins return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle)); 549436c6c9cSStella Laurenzo } 550436c6c9cSStella Laurenzo 551436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirModule. 552436c6c9cSStella Laurenzo /// Note that the module does not (yet) provide a corresponding factory for 553436c6c9cSStella Laurenzo /// constructing from a capsule as that would require uniquing PyModule 554436c6c9cSStella Laurenzo /// instances, which is not currently done. 555b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 556436c6c9cSStella Laurenzo 557436c6c9cSStella Laurenzo /// Creates a PyModule from the MlirModule wrapped by a capsule. 558436c6c9cSStella Laurenzo /// Note that PyModule instances are uniqued, so the returned object 559436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirModule 560436c6c9cSStella Laurenzo /// is taken by calling this function. 561b56d1ec6SPeter Hawkins static nanobind::object createFromCapsule(nanobind::object capsule); 562436c6c9cSStella Laurenzo 563436c6c9cSStella Laurenzo private: 564436c6c9cSStella Laurenzo PyModule(PyMlirContextRef contextRef, MlirModule module); 565436c6c9cSStella Laurenzo MlirModule module; 566b56d1ec6SPeter Hawkins nanobind::handle handle; 567436c6c9cSStella Laurenzo }; 568436c6c9cSStella Laurenzo 569204acc5cSJacques Pienaar class PyAsmState; 570204acc5cSJacques Pienaar 571436c6c9cSStella Laurenzo /// Base class for PyOperation and PyOpView which exposes the primary, user 572436c6c9cSStella Laurenzo /// visible methods for manipulating it. 573436c6c9cSStella Laurenzo class PyOperationBase { 574436c6c9cSStella Laurenzo public: 575436c6c9cSStella Laurenzo virtual ~PyOperationBase() = default; 576436c6c9cSStella Laurenzo /// Implements the bound 'print' method and helps with others. 577204acc5cSJacques Pienaar void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, 578ace1d0adSStella Laurenzo bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 579b56d1ec6SPeter Hawkins bool assumeVerified, nanobind::object fileObject, bool binary, 580abad8455SJonas Rickert bool skipRegions); 581b56d1ec6SPeter Hawkins void print(PyAsmState &state, nanobind::object fileObject, bool binary); 582204acc5cSJacques Pienaar 583b56d1ec6SPeter Hawkins nanobind::object getAsm(bool binary, 5840a81ace0SKazu Hirata std::optional<int64_t> largeElementsLimit, 585436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 586ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 587abad8455SJonas Rickert bool assumeVerified, bool skipRegions); 588436c6c9cSStella Laurenzo 58989418ddcSMehdi Amini // Implement the bound 'writeBytecode' method. 590b56d1ec6SPeter Hawkins void writeBytecode(const nanobind::object &fileObject, 5910610e2f6SJacques Pienaar std::optional<int64_t> bytecodeVersion); 59289418ddcSMehdi Amini 59347148832SHideto Ueno // Implement the walk method. 59447148832SHideto Ueno void walk(std::function<MlirWalkResult(MlirOperation)> callback, 59547148832SHideto Ueno MlirWalkOrder walkOrder); 59647148832SHideto Ueno 59724685aaeSAlex Zinenko /// Moves the operation before or after the other operation. 59824685aaeSAlex Zinenko void moveAfter(PyOperationBase &other); 59924685aaeSAlex Zinenko void moveBefore(PyOperationBase &other); 60024685aaeSAlex Zinenko 6013ea4c501SRahul Kayaith /// Verify the operation. Throws `MLIRError` if verification fails, and 6023ea4c501SRahul Kayaith /// returns `true` otherwise. 6033ea4c501SRahul Kayaith bool verify(); 6043ea4c501SRahul Kayaith 605436c6c9cSStella Laurenzo /// Each must provide access to the raw Operation. 606436c6c9cSStella Laurenzo virtual PyOperation &getOperation() = 0; 607436c6c9cSStella Laurenzo }; 608436c6c9cSStella Laurenzo 609436c6c9cSStella Laurenzo /// Wrapper around PyOperation. 610436c6c9cSStella Laurenzo /// Operations exist in either an attached (dependent) or detached (top-level) 611436c6c9cSStella Laurenzo /// state. In the detached state (as on creation), an operation is owned by 612436c6c9cSStella Laurenzo /// the creator and its lifetime extends either until its reference count 613436c6c9cSStella Laurenzo /// drops to zero or it is attached to a parent, at which point its lifetime 614436c6c9cSStella Laurenzo /// is bounded by its top-level parent reference. 615436c6c9cSStella Laurenzo class PyOperation; 616436c6c9cSStella Laurenzo using PyOperationRef = PyObjectRef<PyOperation>; 617436c6c9cSStella Laurenzo class PyOperation : public PyOperationBase, public BaseContextObject { 618436c6c9cSStella Laurenzo public: 619bd87241cSMehdi Amini ~PyOperation() override; 620436c6c9cSStella Laurenzo PyOperation &getOperation() override { return *this; } 621436c6c9cSStella Laurenzo 622436c6c9cSStella Laurenzo /// Returns a PyOperation for the given MlirOperation, optionally associating 623436c6c9cSStella Laurenzo /// it with a parentKeepAlive. 624436c6c9cSStella Laurenzo static PyOperationRef 625436c6c9cSStella Laurenzo forOperation(PyMlirContextRef contextRef, MlirOperation operation, 626b56d1ec6SPeter Hawkins nanobind::object parentKeepAlive = nanobind::object()); 627436c6c9cSStella Laurenzo 628436c6c9cSStella Laurenzo /// Creates a detached operation. The operation must not be associated with 629436c6c9cSStella Laurenzo /// any existing live operation. 630436c6c9cSStella Laurenzo static PyOperationRef 631436c6c9cSStella Laurenzo createDetached(PyMlirContextRef contextRef, MlirOperation operation, 632b56d1ec6SPeter Hawkins nanobind::object parentKeepAlive = nanobind::object()); 633436c6c9cSStella Laurenzo 63437107e17Srkayaith /// Parses a source string (either text assembly or bytecode), creating a 63537107e17Srkayaith /// detached operation. 63637107e17Srkayaith static PyOperationRef parse(PyMlirContextRef contextRef, 63737107e17Srkayaith const std::string &sourceStr, 63837107e17Srkayaith const std::string &sourceName); 63937107e17Srkayaith 64024685aaeSAlex Zinenko /// Detaches the operation from its parent block and updates its state 64124685aaeSAlex Zinenko /// accordingly. 64224685aaeSAlex Zinenko void detachFromParent() { 64324685aaeSAlex Zinenko mlirOperationRemoveFromParent(getOperation()); 64424685aaeSAlex Zinenko setDetached(); 645b56d1ec6SPeter Hawkins parentKeepAlive = nanobind::object(); 64624685aaeSAlex Zinenko } 64724685aaeSAlex Zinenko 648436c6c9cSStella Laurenzo /// Gets the backing operation. 649436c6c9cSStella Laurenzo operator MlirOperation() const { return get(); } 650436c6c9cSStella Laurenzo MlirOperation get() const { 651436c6c9cSStella Laurenzo checkValid(); 652436c6c9cSStella Laurenzo return operation; 653436c6c9cSStella Laurenzo } 654436c6c9cSStella Laurenzo 655436c6c9cSStella Laurenzo PyOperationRef getRef() { 656b56d1ec6SPeter Hawkins return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle)); 657436c6c9cSStella Laurenzo } 658436c6c9cSStella Laurenzo 659436c6c9cSStella Laurenzo bool isAttached() { return attached; } 660b56d1ec6SPeter Hawkins void setAttached(const nanobind::object &parent = nanobind::object()) { 661436c6c9cSStella Laurenzo assert(!attached && "operation already attached"); 662436c6c9cSStella Laurenzo attached = true; 663436c6c9cSStella Laurenzo } 66424685aaeSAlex Zinenko void setDetached() { 66524685aaeSAlex Zinenko assert(attached && "operation already detached"); 66624685aaeSAlex Zinenko attached = false; 66724685aaeSAlex Zinenko } 668436c6c9cSStella Laurenzo void checkValid() const; 669436c6c9cSStella Laurenzo 670436c6c9cSStella Laurenzo /// Gets the owning block or raises an exception if the operation has no 671436c6c9cSStella Laurenzo /// owning block. 672436c6c9cSStella Laurenzo PyBlock getBlock(); 673436c6c9cSStella Laurenzo 674436c6c9cSStella Laurenzo /// Gets the parent operation or raises an exception if the operation has 675436c6c9cSStella Laurenzo /// no parent. 6760a81ace0SKazu Hirata std::optional<PyOperationRef> getParentOperation(); 677436c6c9cSStella Laurenzo 6780126e906SJohn Demme /// Gets a capsule wrapping the void* within the MlirOperation. 679b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 6800126e906SJohn Demme 6810126e906SJohn Demme /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 6820126e906SJohn Demme /// Ownership of the underlying MlirOperation is taken by calling this 6830126e906SJohn Demme /// function. 684b56d1ec6SPeter Hawkins static nanobind::object createFromCapsule(nanobind::object capsule); 6850126e906SJohn Demme 686436c6c9cSStella Laurenzo /// Creates an operation. See corresponding python docstring. 687b56d1ec6SPeter Hawkins static nanobind::object 688f4125e02SPeter Hawkins create(std::string_view name, std::optional<std::vector<PyType *>> results, 689*acde3f72SPeter Hawkins llvm::ArrayRef<MlirValue> operands, 690b56d1ec6SPeter Hawkins std::optional<nanobind::dict> attributes, 6910a81ace0SKazu Hirata std::optional<std::vector<PyBlock *>> successors, int regions, 692b56d1ec6SPeter Hawkins DefaultingPyLocation location, const nanobind::object &ip, 693f573bc24SJacques Pienaar bool inferType); 694436c6c9cSStella Laurenzo 695436c6c9cSStella Laurenzo /// Creates an OpView suitable for this operation. 696b56d1ec6SPeter Hawkins nanobind::object createOpView(); 697436c6c9cSStella Laurenzo 69849745f87SMike Urbach /// Erases the underlying MlirOperation, removes its pointer from the 69949745f87SMike Urbach /// parent context's live operations map, and sets the valid bit false. 70049745f87SMike Urbach void erase(); 70149745f87SMike Urbach 7026b0bed7eSJohn Demme /// Invalidate the operation. 7036b0bed7eSJohn Demme void setInvalid() { valid = false; } 7046b0bed7eSJohn Demme 705774818c0SDominik Grewe /// Clones this operation. 706b56d1ec6SPeter Hawkins nanobind::object clone(const nanobind::object &ip); 707774818c0SDominik Grewe 708436c6c9cSStella Laurenzo PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 709e30b7030SPeter Hawkins 710e30b7030SPeter Hawkins private: 711436c6c9cSStella Laurenzo static PyOperationRef createInstance(PyMlirContextRef contextRef, 712436c6c9cSStella Laurenzo MlirOperation operation, 713b56d1ec6SPeter Hawkins nanobind::object parentKeepAlive); 714436c6c9cSStella Laurenzo 715436c6c9cSStella Laurenzo MlirOperation operation; 716b56d1ec6SPeter Hawkins nanobind::handle handle; 717436c6c9cSStella Laurenzo // Keeps the parent alive, regardless of whether it is an Operation or 718436c6c9cSStella Laurenzo // Module. 719436c6c9cSStella Laurenzo // TODO: As implemented, this facility is only sufficient for modeling the 720436c6c9cSStella Laurenzo // trivial module parent back-reference. Generalize this to also account for 721436c6c9cSStella Laurenzo // transitions from detached to attached and address TODOs in the 722436c6c9cSStella Laurenzo // ir_operation.py regarding testing corresponding lifetime guarantees. 723b56d1ec6SPeter Hawkins nanobind::object parentKeepAlive; 724436c6c9cSStella Laurenzo bool attached = true; 725436c6c9cSStella Laurenzo bool valid = true; 72624685aaeSAlex Zinenko 72724685aaeSAlex Zinenko friend class PyOperationBase; 72830d61893SAlex Zinenko friend class PySymbolTable; 729436c6c9cSStella Laurenzo }; 730436c6c9cSStella Laurenzo 731436c6c9cSStella Laurenzo /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 732436c6c9cSStella Laurenzo /// providing more instance-specific accessors and serve as the base class for 733436c6c9cSStella Laurenzo /// custom ODS-style operation classes. Since this class is subclass on the 734436c6c9cSStella Laurenzo /// python side, it must present an __init__ method that operates in pure 735436c6c9cSStella Laurenzo /// python types. 736436c6c9cSStella Laurenzo class PyOpView : public PyOperationBase { 737436c6c9cSStella Laurenzo public: 738b56d1ec6SPeter Hawkins PyOpView(const nanobind::object &operationObject); 739436c6c9cSStella Laurenzo PyOperation &getOperation() override { return operation; } 740436c6c9cSStella Laurenzo 741b56d1ec6SPeter Hawkins nanobind::object getOperationObject() { return operationObject; } 742436c6c9cSStella Laurenzo 743f4125e02SPeter Hawkins static nanobind::object 744f4125e02SPeter Hawkins buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec, 745f4125e02SPeter Hawkins nanobind::object operandSegmentSpecObj, 746f4125e02SPeter Hawkins nanobind::object resultSegmentSpecObj, 747f4125e02SPeter Hawkins std::optional<nanobind::list> resultTypeList, 748f4125e02SPeter Hawkins nanobind::list operandList, 749f4125e02SPeter Hawkins std::optional<nanobind::dict> attributes, 7500a81ace0SKazu Hirata std::optional<std::vector<PyBlock *>> successors, 7510a81ace0SKazu Hirata std::optional<int> regions, DefaultingPyLocation location, 752b56d1ec6SPeter Hawkins const nanobind::object &maybeIp); 753436c6c9cSStella Laurenzo 754a7f8b7cdSRahul Kayaith /// Construct an instance of a class deriving from OpView, bypassing its 755a7f8b7cdSRahul Kayaith /// `__init__` method. The derived class will typically define a constructor 756a7f8b7cdSRahul Kayaith /// that provides a convenient builder, but we need to side-step this when 757a7f8b7cdSRahul Kayaith /// constructing an `OpView` for an already-built operation. 758a7f8b7cdSRahul Kayaith /// 759a7f8b7cdSRahul Kayaith /// The caller is responsible for verifying that `operation` is a valid 760a7f8b7cdSRahul Kayaith /// operation to construct `cls` with. 761b56d1ec6SPeter Hawkins static nanobind::object constructDerived(const nanobind::object &cls, 762b56d1ec6SPeter Hawkins const nanobind::object &operation); 763a7f8b7cdSRahul Kayaith 764436c6c9cSStella Laurenzo private: 765436c6c9cSStella Laurenzo PyOperation &operation; // For efficient, cast-free access from C++ 766b56d1ec6SPeter Hawkins nanobind::object operationObject; // Holds the reference. 767436c6c9cSStella Laurenzo }; 768436c6c9cSStella Laurenzo 769436c6c9cSStella Laurenzo /// Wrapper around an MlirRegion. 770436c6c9cSStella Laurenzo /// Regions are managed completely by their containing operation. Unlike the 771436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached regions. 772436c6c9cSStella Laurenzo class PyRegion { 773436c6c9cSStella Laurenzo public: 774436c6c9cSStella Laurenzo PyRegion(PyOperationRef parentOperation, MlirRegion region) 775436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), region(region) { 776436c6c9cSStella Laurenzo assert(!mlirRegionIsNull(region) && "python region cannot be null"); 777436c6c9cSStella Laurenzo } 77878f2dae0SAlex Zinenko operator MlirRegion() const { return region; } 779436c6c9cSStella Laurenzo 780436c6c9cSStella Laurenzo MlirRegion get() { return region; } 781436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 782436c6c9cSStella Laurenzo 783436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 784436c6c9cSStella Laurenzo 785436c6c9cSStella Laurenzo private: 786436c6c9cSStella Laurenzo PyOperationRef parentOperation; 787436c6c9cSStella Laurenzo MlirRegion region; 788436c6c9cSStella Laurenzo }; 789436c6c9cSStella Laurenzo 79075453714SJacques Pienaar /// Wrapper around an MlirAsmState. 79175453714SJacques Pienaar class PyAsmState { 79275453714SJacques Pienaar public: 79375453714SJacques Pienaar PyAsmState(MlirValue value, bool useLocalScope) { 79475453714SJacques Pienaar flags = mlirOpPrintingFlagsCreate(); 79575453714SJacques Pienaar // The OpPrintingFlags are not exposed Python side, create locally and 79675453714SJacques Pienaar // associate lifetime with the state. 79775453714SJacques Pienaar if (useLocalScope) 79875453714SJacques Pienaar mlirOpPrintingFlagsUseLocalScope(flags); 79975453714SJacques Pienaar state = mlirAsmStateCreateForValue(value, flags); 80075453714SJacques Pienaar } 801a677a173SJacques Pienaar 802a677a173SJacques Pienaar PyAsmState(PyOperationBase &operation, bool useLocalScope) { 803a677a173SJacques Pienaar flags = mlirOpPrintingFlagsCreate(); 804a677a173SJacques Pienaar // The OpPrintingFlags are not exposed Python side, create locally and 805a677a173SJacques Pienaar // associate lifetime with the state. 806a677a173SJacques Pienaar if (useLocalScope) 807a677a173SJacques Pienaar mlirOpPrintingFlagsUseLocalScope(flags); 808a677a173SJacques Pienaar state = 809a677a173SJacques Pienaar mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); 810a677a173SJacques Pienaar } 8117c850867SMaksim Levental ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } 81275453714SJacques Pienaar // Delete copy constructors. 81375453714SJacques Pienaar PyAsmState(PyAsmState &other) = delete; 81475453714SJacques Pienaar PyAsmState(const PyAsmState &other) = delete; 81575453714SJacques Pienaar 81675453714SJacques Pienaar MlirAsmState get() { return state; } 81775453714SJacques Pienaar 81875453714SJacques Pienaar private: 81975453714SJacques Pienaar MlirAsmState state; 82075453714SJacques Pienaar MlirOpPrintingFlags flags; 82175453714SJacques Pienaar }; 82275453714SJacques Pienaar 823436c6c9cSStella Laurenzo /// Wrapper around an MlirBlock. 824436c6c9cSStella Laurenzo /// Blocks are managed completely by their containing operation. Unlike the 825436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached blocks. 826436c6c9cSStella Laurenzo class PyBlock { 827436c6c9cSStella Laurenzo public: 828436c6c9cSStella Laurenzo PyBlock(PyOperationRef parentOperation, MlirBlock block) 829436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), block(block) { 830436c6c9cSStella Laurenzo assert(!mlirBlockIsNull(block) && "python block cannot be null"); 831436c6c9cSStella Laurenzo } 832436c6c9cSStella Laurenzo 833436c6c9cSStella Laurenzo MlirBlock get() { return block; } 834436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 835436c6c9cSStella Laurenzo 836436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 837436c6c9cSStella Laurenzo 838c83318e3SAdam Paszke /// Gets a capsule wrapping the void* within the MlirBlock. 839b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 840c83318e3SAdam Paszke 841436c6c9cSStella Laurenzo private: 842436c6c9cSStella Laurenzo PyOperationRef parentOperation; 843436c6c9cSStella Laurenzo MlirBlock block; 844436c6c9cSStella Laurenzo }; 845436c6c9cSStella Laurenzo 846436c6c9cSStella Laurenzo /// An insertion point maintains a pointer to a Block and a reference operation. 847436c6c9cSStella Laurenzo /// Calls to insert() will insert a new operation before the 848436c6c9cSStella Laurenzo /// reference operation. If the reference operation is null, then appends to 849436c6c9cSStella Laurenzo /// the end of the block. 850436c6c9cSStella Laurenzo class PyInsertionPoint { 851436c6c9cSStella Laurenzo public: 852436c6c9cSStella Laurenzo /// Creates an insertion point positioned after the last operation in the 853436c6c9cSStella Laurenzo /// block, but still inside the block. 854436c6c9cSStella Laurenzo PyInsertionPoint(PyBlock &block); 855436c6c9cSStella Laurenzo /// Creates an insertion point positioned before a reference operation. 856436c6c9cSStella Laurenzo PyInsertionPoint(PyOperationBase &beforeOperationBase); 857436c6c9cSStella Laurenzo 858436c6c9cSStella Laurenzo /// Shortcut to create an insertion point at the beginning of the block. 859436c6c9cSStella Laurenzo static PyInsertionPoint atBlockBegin(PyBlock &block); 860436c6c9cSStella Laurenzo /// Shortcut to create an insertion point before the block terminator. 861436c6c9cSStella Laurenzo static PyInsertionPoint atBlockTerminator(PyBlock &block); 862436c6c9cSStella Laurenzo 863436c6c9cSStella Laurenzo /// Inserts an operation. 864436c6c9cSStella Laurenzo void insert(PyOperationBase &operationBase); 865436c6c9cSStella Laurenzo 866436c6c9cSStella Laurenzo /// Enter and exit the context manager. 867b56d1ec6SPeter Hawkins static nanobind::object contextEnter(nanobind::object insertionPoint); 868b56d1ec6SPeter Hawkins void contextExit(const nanobind::object &excType, 869b56d1ec6SPeter Hawkins const nanobind::object &excVal, 870b56d1ec6SPeter Hawkins const nanobind::object &excTb); 871436c6c9cSStella Laurenzo 872436c6c9cSStella Laurenzo PyBlock &getBlock() { return block; } 8735a600c23STomás Longeri std::optional<PyOperationRef> &getRefOperation() { return refOperation; } 874436c6c9cSStella Laurenzo 875436c6c9cSStella Laurenzo private: 876436c6c9cSStella Laurenzo // Trampoline constructor that avoids null initializing members while 877436c6c9cSStella Laurenzo // looking up parents. 8780a81ace0SKazu Hirata PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation) 879436c6c9cSStella Laurenzo : refOperation(std::move(refOperation)), block(std::move(block)) {} 880436c6c9cSStella Laurenzo 8810a81ace0SKazu Hirata std::optional<PyOperationRef> refOperation; 882436c6c9cSStella Laurenzo PyBlock block; 883436c6c9cSStella Laurenzo }; 8842995d29bSAlex Zinenko /// Wrapper around the generic MlirType. 8852995d29bSAlex Zinenko /// The lifetime of a type is bound by the PyContext that created it. 8862995d29bSAlex Zinenko class PyType : public BaseContextObject { 8872995d29bSAlex Zinenko public: 8882995d29bSAlex Zinenko PyType(PyMlirContextRef contextRef, MlirType type) 8892995d29bSAlex Zinenko : BaseContextObject(std::move(contextRef)), type(type) {} 890e6d738e0SRahul Kayaith bool operator==(const PyType &other) const; 8912995d29bSAlex Zinenko operator MlirType() const { return type; } 8922995d29bSAlex Zinenko MlirType get() const { return type; } 8932995d29bSAlex Zinenko 8942995d29bSAlex Zinenko /// Gets a capsule wrapping the void* within the MlirType. 895b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 8962995d29bSAlex Zinenko 8972995d29bSAlex Zinenko /// Creates a PyType from the MlirType wrapped by a capsule. 8982995d29bSAlex Zinenko /// Note that PyType instances are uniqued, so the returned object 8992995d29bSAlex Zinenko /// may be a pre-existing object. Ownership of the underlying MlirType 9002995d29bSAlex Zinenko /// is taken by calling this function. 901b56d1ec6SPeter Hawkins static PyType createFromCapsule(nanobind::object capsule); 9022995d29bSAlex Zinenko 9032995d29bSAlex Zinenko private: 9042995d29bSAlex Zinenko MlirType type; 9052995d29bSAlex Zinenko }; 9062995d29bSAlex Zinenko 907d39a7844Smax /// A TypeID provides an efficient and unique identifier for a specific C++ 908d39a7844Smax /// type. This allows for a C++ type to be compared, hashed, and stored in an 909d39a7844Smax /// opaque context. This class wraps around the generic MlirTypeID. 910d39a7844Smax class PyTypeID { 911d39a7844Smax public: 912d39a7844Smax PyTypeID(MlirTypeID typeID) : typeID(typeID) {} 913d39a7844Smax // Note, this tests whether the underlying TypeIDs are the same, 914d39a7844Smax // not whether the wrapper MlirTypeIDs are the same, nor whether 915d39a7844Smax // the PyTypeID objects are the same (i.e., PyTypeID is a value type). 916d39a7844Smax bool operator==(const PyTypeID &other) const; 917d39a7844Smax operator MlirTypeID() const { return typeID; } 918d39a7844Smax MlirTypeID get() { return typeID; } 919d39a7844Smax 920d39a7844Smax /// Gets a capsule wrapping the void* within the MlirTypeID. 921b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 922d39a7844Smax 923d39a7844Smax /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 924b56d1ec6SPeter Hawkins static PyTypeID createFromCapsule(nanobind::object capsule); 925d39a7844Smax 926d39a7844Smax private: 927d39a7844Smax MlirTypeID typeID; 928d39a7844Smax }; 929d39a7844Smax 9302995d29bSAlex Zinenko /// CRTP base classes for Python types that subclass Type and should be 9312995d29bSAlex Zinenko /// castable from it (i.e. via something like IntegerType(t)). 9322995d29bSAlex Zinenko /// By default, type class hierarchies are one level deep (i.e. a 9332995d29bSAlex Zinenko /// concrete type class extends PyType); however, intermediate python-visible 9342995d29bSAlex Zinenko /// base classes can be modeled by specifying a BaseTy. 9352995d29bSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyType> 9362995d29bSAlex Zinenko class PyConcreteType : public BaseTy { 9372995d29bSAlex Zinenko public: 9382995d29bSAlex Zinenko // Derived classes must define statics for: 9392995d29bSAlex Zinenko // IsAFunctionTy isaFunction 9402995d29bSAlex Zinenko // const char *pyClassName 941b56d1ec6SPeter Hawkins using ClassTy = nanobind::class_<DerivedTy, BaseTy>; 9422995d29bSAlex Zinenko using IsAFunctionTy = bool (*)(MlirType); 943d39a7844Smax using GetTypeIDFunctionTy = MlirTypeID (*)(); 944d39a7844Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; 9452995d29bSAlex Zinenko 9462995d29bSAlex Zinenko PyConcreteType() = default; 9472995d29bSAlex Zinenko PyConcreteType(PyMlirContextRef contextRef, MlirType t) 948bfb1ba75Smax : BaseTy(std::move(contextRef), t) {} 9492995d29bSAlex Zinenko PyConcreteType(PyType &orig) 9502995d29bSAlex Zinenko : PyConcreteType(orig.getContext(), castFrom(orig)) {} 9512995d29bSAlex Zinenko 9522995d29bSAlex Zinenko static MlirType castFrom(PyType &orig) { 9532995d29bSAlex Zinenko if (!DerivedTy::isaFunction(orig)) { 954b56d1ec6SPeter Hawkins auto origRepr = 955b56d1ec6SPeter Hawkins nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); 956b56d1ec6SPeter Hawkins throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + 957b56d1ec6SPeter Hawkins DerivedTy::pyClassName + " (from " + 958b56d1ec6SPeter Hawkins origRepr + ")") 959b56d1ec6SPeter Hawkins .str() 960b56d1ec6SPeter Hawkins .c_str()); 9612995d29bSAlex Zinenko } 9622995d29bSAlex Zinenko return orig; 9632995d29bSAlex Zinenko } 9642995d29bSAlex Zinenko 965b56d1ec6SPeter Hawkins static void bind(nanobind::module_ &m) { 966b56d1ec6SPeter Hawkins auto cls = ClassTy(m, DerivedTy::pyClassName); 967b56d1ec6SPeter Hawkins cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(), 968b56d1ec6SPeter Hawkins nanobind::arg("cast_from_type")); 969a6e7d024SStella Laurenzo cls.def_static( 970a6e7d024SStella Laurenzo "isinstance", 971a6e7d024SStella Laurenzo [](PyType &otherType) -> bool { 9722995d29bSAlex Zinenko return DerivedTy::isaFunction(otherType); 973a6e7d024SStella Laurenzo }, 974b56d1ec6SPeter Hawkins nanobind::arg("other")); 975b56d1ec6SPeter Hawkins cls.def_prop_ro_static( 976b56d1ec6SPeter Hawkins "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { 977d39a7844Smax if (DerivedTy::getTypeIdFunction) 978d39a7844Smax return DerivedTy::getTypeIdFunction(); 979b56d1ec6SPeter Hawkins throw nanobind::attribute_error( 980b56d1ec6SPeter Hawkins (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) 981b56d1ec6SPeter Hawkins .str() 982b56d1ec6SPeter Hawkins .c_str()); 983d39a7844Smax }); 984b56d1ec6SPeter Hawkins cls.def_prop_ro("typeid", [](PyType &self) { 985b56d1ec6SPeter Hawkins return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); 986d39a7844Smax }); 987d39a7844Smax cls.def("__repr__", [](DerivedTy &self) { 988d39a7844Smax PyPrintAccumulator printAccum; 989d39a7844Smax printAccum.parts.append(DerivedTy::pyClassName); 990d39a7844Smax printAccum.parts.append("("); 991d39a7844Smax mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 992d39a7844Smax printAccum.parts.append(")"); 993d39a7844Smax return printAccum.join(); 994d39a7844Smax }); 995d39a7844Smax 996bfb1ba75Smax if (DerivedTy::getTypeIdFunction) { 997bfb1ba75Smax PyGlobals::get().registerTypeCaster( 998bfb1ba75Smax DerivedTy::getTypeIdFunction(), 999b56d1ec6SPeter Hawkins nanobind::cast<nanobind::callable>(nanobind::cpp_function( 1000b56d1ec6SPeter Hawkins [](PyType pyType) -> DerivedTy { return pyType; }))); 1001bfb1ba75Smax } 1002bfb1ba75Smax 10032995d29bSAlex Zinenko DerivedTy::bindDerived(cls); 10042995d29bSAlex Zinenko } 10052995d29bSAlex Zinenko 10062995d29bSAlex Zinenko /// Implemented by derived classes to add methods to the Python subclass. 10072995d29bSAlex Zinenko static void bindDerived(ClassTy &m) {} 10082995d29bSAlex Zinenko }; 1009436c6c9cSStella Laurenzo 1010436c6c9cSStella Laurenzo /// Wrapper around the generic MlirAttribute. 1011436c6c9cSStella Laurenzo /// The lifetime of a type is bound by the PyContext that created it. 1012436c6c9cSStella Laurenzo class PyAttribute : public BaseContextObject { 1013436c6c9cSStella Laurenzo public: 1014436c6c9cSStella Laurenzo PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 1015436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), attr(attr) {} 1016e6d738e0SRahul Kayaith bool operator==(const PyAttribute &other) const; 1017436c6c9cSStella Laurenzo operator MlirAttribute() const { return attr; } 1018436c6c9cSStella Laurenzo MlirAttribute get() const { return attr; } 1019436c6c9cSStella Laurenzo 1020436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAttribute. 1021b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 1022436c6c9cSStella Laurenzo 1023436c6c9cSStella Laurenzo /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 1024436c6c9cSStella Laurenzo /// Note that PyAttribute instances are uniqued, so the returned object 1025436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAttribute 1026436c6c9cSStella Laurenzo /// is taken by calling this function. 1027b56d1ec6SPeter Hawkins static PyAttribute createFromCapsule(nanobind::object capsule); 1028436c6c9cSStella Laurenzo 1029436c6c9cSStella Laurenzo private: 1030436c6c9cSStella Laurenzo MlirAttribute attr; 1031436c6c9cSStella Laurenzo }; 1032436c6c9cSStella Laurenzo 1033436c6c9cSStella Laurenzo /// Represents a Python MlirNamedAttr, carrying an optional owned name. 1034436c6c9cSStella Laurenzo /// TODO: Refactor this and the C-API to be based on an Identifier owned 1035436c6c9cSStella Laurenzo /// by the context so as to avoid ownership issues here. 1036436c6c9cSStella Laurenzo class PyNamedAttribute { 1037436c6c9cSStella Laurenzo public: 1038436c6c9cSStella Laurenzo /// Constructs a PyNamedAttr that retains an owned name. This should be 1039436c6c9cSStella Laurenzo /// used in any code that originates an MlirNamedAttribute from a python 1040436c6c9cSStella Laurenzo /// string. 1041436c6c9cSStella Laurenzo /// The lifetime of the PyNamedAttr must extend to the lifetime of the 1042436c6c9cSStella Laurenzo /// passed attribute. 1043436c6c9cSStella Laurenzo PyNamedAttribute(MlirAttribute attr, std::string ownedName); 1044436c6c9cSStella Laurenzo 1045436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr; 1046436c6c9cSStella Laurenzo 1047436c6c9cSStella Laurenzo private: 1048436c6c9cSStella Laurenzo // Since the MlirNamedAttr contains an internal pointer to the actual 1049436c6c9cSStella Laurenzo // memory of the owned string, it must be heap allocated to remain valid. 1050436c6c9cSStella Laurenzo // Otherwise, strings that fit within the small object optimization threshold 1051436c6c9cSStella Laurenzo // will have their memory address change as the containing object is moved, 1052436c6c9cSStella Laurenzo // resulting in an invalid aliased pointer. 1053436c6c9cSStella Laurenzo std::unique_ptr<std::string> ownedName; 1054436c6c9cSStella Laurenzo }; 1055436c6c9cSStella Laurenzo 10560b10fdedSAlex Zinenko /// CRTP base classes for Python attributes that subclass Attribute and should 10570b10fdedSAlex Zinenko /// be castable from it (i.e. via something like StringAttr(attr)). 10580b10fdedSAlex Zinenko /// By default, attribute class hierarchies are one level deep (i.e. a 10590b10fdedSAlex Zinenko /// concrete attribute class extends PyAttribute); however, intermediate 10600b10fdedSAlex Zinenko /// python-visible base classes can be modeled by specifying a BaseTy. 10610b10fdedSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyAttribute> 10620b10fdedSAlex Zinenko class PyConcreteAttribute : public BaseTy { 10630b10fdedSAlex Zinenko public: 10640b10fdedSAlex Zinenko // Derived classes must define statics for: 10650b10fdedSAlex Zinenko // IsAFunctionTy isaFunction 10660b10fdedSAlex Zinenko // const char *pyClassName 1067b56d1ec6SPeter Hawkins using ClassTy = nanobind::class_<DerivedTy, BaseTy>; 10680b10fdedSAlex Zinenko using IsAFunctionTy = bool (*)(MlirAttribute); 10699566ee28Smax using GetTypeIDFunctionTy = MlirTypeID (*)(); 10709566ee28Smax static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; 10710b10fdedSAlex Zinenko 10720b10fdedSAlex Zinenko PyConcreteAttribute() = default; 10730b10fdedSAlex Zinenko PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 10740b10fdedSAlex Zinenko : BaseTy(std::move(contextRef), attr) {} 10750b10fdedSAlex Zinenko PyConcreteAttribute(PyAttribute &orig) 10760b10fdedSAlex Zinenko : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 10770b10fdedSAlex Zinenko 10780b10fdedSAlex Zinenko static MlirAttribute castFrom(PyAttribute &orig) { 10790b10fdedSAlex Zinenko if (!DerivedTy::isaFunction(orig)) { 1080b56d1ec6SPeter Hawkins auto origRepr = 1081b56d1ec6SPeter Hawkins nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); 1082b56d1ec6SPeter Hawkins throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + 1083b56d1ec6SPeter Hawkins DerivedTy::pyClassName + " (from " + 1084b56d1ec6SPeter Hawkins origRepr + ")") 1085b56d1ec6SPeter Hawkins .str() 1086b56d1ec6SPeter Hawkins .c_str()); 10870b10fdedSAlex Zinenko } 10880b10fdedSAlex Zinenko return orig; 10890b10fdedSAlex Zinenko } 10900b10fdedSAlex Zinenko 1091b56d1ec6SPeter Hawkins static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { 1092b56d1ec6SPeter Hawkins ClassTy cls; 1093b56d1ec6SPeter Hawkins if (slots) { 1094b56d1ec6SPeter Hawkins cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); 1095b56d1ec6SPeter Hawkins } else { 1096b56d1ec6SPeter Hawkins cls = ClassTy(m, DerivedTy::pyClassName); 1097b56d1ec6SPeter Hawkins } 1098b56d1ec6SPeter Hawkins cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(), 1099b56d1ec6SPeter Hawkins nanobind::arg("cast_from_attr")); 1100a6e7d024SStella Laurenzo cls.def_static( 1101a6e7d024SStella Laurenzo "isinstance", 1102a6e7d024SStella Laurenzo [](PyAttribute &otherAttr) -> bool { 110378f2dae0SAlex Zinenko return DerivedTy::isaFunction(otherAttr); 1104a6e7d024SStella Laurenzo }, 1105b56d1ec6SPeter Hawkins nanobind::arg("other")); 1106b56d1ec6SPeter Hawkins cls.def_prop_ro( 1107bfb1ba75Smax "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); 1108b56d1ec6SPeter Hawkins cls.def_prop_ro_static( 1109b56d1ec6SPeter Hawkins "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { 11109566ee28Smax if (DerivedTy::getTypeIdFunction) 11119566ee28Smax return DerivedTy::getTypeIdFunction(); 1112b56d1ec6SPeter Hawkins throw nanobind::attribute_error( 1113b56d1ec6SPeter Hawkins (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) 1114b56d1ec6SPeter Hawkins .str() 1115b56d1ec6SPeter Hawkins .c_str()); 11169566ee28Smax }); 1117b56d1ec6SPeter Hawkins cls.def_prop_ro("typeid", [](PyAttribute &self) { 1118b56d1ec6SPeter Hawkins return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); 11199566ee28Smax }); 11209566ee28Smax cls.def("__repr__", [](DerivedTy &self) { 11219566ee28Smax PyPrintAccumulator printAccum; 11229566ee28Smax printAccum.parts.append(DerivedTy::pyClassName); 11239566ee28Smax printAccum.parts.append("("); 11249566ee28Smax mlirAttributePrint(self, printAccum.getCallback(), 11259566ee28Smax printAccum.getUserData()); 11269566ee28Smax printAccum.parts.append(")"); 11279566ee28Smax return printAccum.join(); 11289566ee28Smax }); 11299566ee28Smax 11309566ee28Smax if (DerivedTy::getTypeIdFunction) { 11319566ee28Smax PyGlobals::get().registerTypeCaster( 11329566ee28Smax DerivedTy::getTypeIdFunction(), 1133b56d1ec6SPeter Hawkins nanobind::cast<nanobind::callable>( 1134b56d1ec6SPeter Hawkins nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { 11359566ee28Smax return pyAttribute; 1136b56d1ec6SPeter Hawkins }))); 11379566ee28Smax } 11389566ee28Smax 113932e2fec7SJohn Demme DerivedTy::bindDerived(cls); 114032e2fec7SJohn Demme } 114132e2fec7SJohn Demme 114232e2fec7SJohn Demme /// Implemented by derived classes to add methods to the Python subclass. 114332e2fec7SJohn Demme static void bindDerived(ClassTy &m) {} 114432e2fec7SJohn Demme }; 114532e2fec7SJohn Demme 1146436c6c9cSStella Laurenzo /// Wrapper around the generic MlirValue. 1147436c6c9cSStella Laurenzo /// Values are managed completely by the operation that resulted in their 1148436c6c9cSStella Laurenzo /// definition. For op result value, this is the operation that defines the 1149436c6c9cSStella Laurenzo /// value. For block argument values, this is the operation that contains the 1150436c6c9cSStella Laurenzo /// block to which the value is an argument (blocks cannot be detached in Python 1151436c6c9cSStella Laurenzo /// bindings so such operation always exists). 1152436c6c9cSStella Laurenzo class PyValue { 1153436c6c9cSStella Laurenzo public: 11547c850867SMaksim Levental // The virtual here is "load bearing" in that it enables RTTI 11557c850867SMaksim Levental // for PyConcreteValue CRTP classes that support maybeDownCast. 11567c850867SMaksim Levental // See PyValue::maybeDownCast. 11577c850867SMaksim Levental virtual ~PyValue() = default; 1158436c6c9cSStella Laurenzo PyValue(PyOperationRef parentOperation, MlirValue value) 1159e8d07395SMehdi Amini : parentOperation(std::move(parentOperation)), value(value) {} 116078f2dae0SAlex Zinenko operator MlirValue() const { return value; } 1161436c6c9cSStella Laurenzo 1162436c6c9cSStella Laurenzo MlirValue get() { return value; } 1163436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 1164436c6c9cSStella Laurenzo 1165436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 1166436c6c9cSStella Laurenzo 11673f3d1c90SMike Urbach /// Gets a capsule wrapping the void* within the MlirValue. 1168b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 11693f3d1c90SMike Urbach 1170b56d1ec6SPeter Hawkins nanobind::object maybeDownCast(); 11717c850867SMaksim Levental 11723f3d1c90SMike Urbach /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 11733f3d1c90SMike Urbach /// the underlying MlirValue is still tied to the owning operation. 1174b56d1ec6SPeter Hawkins static PyValue createFromCapsule(nanobind::object capsule); 11753f3d1c90SMike Urbach 1176436c6c9cSStella Laurenzo private: 1177436c6c9cSStella Laurenzo PyOperationRef parentOperation; 1178436c6c9cSStella Laurenzo MlirValue value; 1179436c6c9cSStella Laurenzo }; 1180436c6c9cSStella Laurenzo 1181436c6c9cSStella Laurenzo /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 1182436c6c9cSStella Laurenzo class PyAffineExpr : public BaseContextObject { 1183436c6c9cSStella Laurenzo public: 1184436c6c9cSStella Laurenzo PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 1185436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 1186e6d738e0SRahul Kayaith bool operator==(const PyAffineExpr &other) const; 1187436c6c9cSStella Laurenzo operator MlirAffineExpr() const { return affineExpr; } 1188436c6c9cSStella Laurenzo MlirAffineExpr get() const { return affineExpr; } 1189436c6c9cSStella Laurenzo 1190436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAffineExpr. 1191b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 1192436c6c9cSStella Laurenzo 1193436c6c9cSStella Laurenzo /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 1194436c6c9cSStella Laurenzo /// Note that PyAffineExpr instances are uniqued, so the returned object 1195436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 1196436c6c9cSStella Laurenzo /// is taken by calling this function. 1197b56d1ec6SPeter Hawkins static PyAffineExpr createFromCapsule(nanobind::object capsule); 1198436c6c9cSStella Laurenzo 1199436c6c9cSStella Laurenzo PyAffineExpr add(const PyAffineExpr &other) const; 1200436c6c9cSStella Laurenzo PyAffineExpr mul(const PyAffineExpr &other) const; 1201436c6c9cSStella Laurenzo PyAffineExpr floorDiv(const PyAffineExpr &other) const; 1202436c6c9cSStella Laurenzo PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 1203436c6c9cSStella Laurenzo PyAffineExpr mod(const PyAffineExpr &other) const; 1204436c6c9cSStella Laurenzo 1205436c6c9cSStella Laurenzo private: 1206436c6c9cSStella Laurenzo MlirAffineExpr affineExpr; 1207436c6c9cSStella Laurenzo }; 1208436c6c9cSStella Laurenzo 1209436c6c9cSStella Laurenzo class PyAffineMap : public BaseContextObject { 1210436c6c9cSStella Laurenzo public: 1211436c6c9cSStella Laurenzo PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 1212436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 1213e6d738e0SRahul Kayaith bool operator==(const PyAffineMap &other) const; 1214436c6c9cSStella Laurenzo operator MlirAffineMap() const { return affineMap; } 1215436c6c9cSStella Laurenzo MlirAffineMap get() const { return affineMap; } 1216436c6c9cSStella Laurenzo 1217436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAffineMap. 1218b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 1219436c6c9cSStella Laurenzo 1220436c6c9cSStella Laurenzo /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 1221436c6c9cSStella Laurenzo /// Note that PyAffineMap instances are uniqued, so the returned object 1222436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 1223436c6c9cSStella Laurenzo /// is taken by calling this function. 1224b56d1ec6SPeter Hawkins static PyAffineMap createFromCapsule(nanobind::object capsule); 1225436c6c9cSStella Laurenzo 1226436c6c9cSStella Laurenzo private: 1227436c6c9cSStella Laurenzo MlirAffineMap affineMap; 1228436c6c9cSStella Laurenzo }; 1229436c6c9cSStella Laurenzo 1230436c6c9cSStella Laurenzo class PyIntegerSet : public BaseContextObject { 1231436c6c9cSStella Laurenzo public: 1232436c6c9cSStella Laurenzo PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 1233436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 1234e6d738e0SRahul Kayaith bool operator==(const PyIntegerSet &other) const; 1235436c6c9cSStella Laurenzo operator MlirIntegerSet() const { return integerSet; } 1236436c6c9cSStella Laurenzo MlirIntegerSet get() const { return integerSet; } 1237436c6c9cSStella Laurenzo 1238436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirIntegerSet. 1239b56d1ec6SPeter Hawkins nanobind::object getCapsule(); 1240436c6c9cSStella Laurenzo 1241436c6c9cSStella Laurenzo /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 1242436c6c9cSStella Laurenzo /// Note that PyIntegerSet instances may be uniqued, so the returned object 1243436c6c9cSStella Laurenzo /// may be a pre-existing object. Integer sets are owned by the context. 1244b56d1ec6SPeter Hawkins static PyIntegerSet createFromCapsule(nanobind::object capsule); 1245436c6c9cSStella Laurenzo 1246436c6c9cSStella Laurenzo private: 1247436c6c9cSStella Laurenzo MlirIntegerSet integerSet; 1248436c6c9cSStella Laurenzo }; 1249436c6c9cSStella Laurenzo 125030d61893SAlex Zinenko /// Bindings for MLIR symbol tables. 125130d61893SAlex Zinenko class PySymbolTable { 125230d61893SAlex Zinenko public: 125330d61893SAlex Zinenko /// Constructs a symbol table for the given operation. 125430d61893SAlex Zinenko explicit PySymbolTable(PyOperationBase &operation); 125530d61893SAlex Zinenko 125630d61893SAlex Zinenko /// Destroys the symbol table. 125730d61893SAlex Zinenko ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 125830d61893SAlex Zinenko 125930d61893SAlex Zinenko /// Returns the symbol (opview) with the given name, throws if there is no 126030d61893SAlex Zinenko /// such symbol in the table. 1261b56d1ec6SPeter Hawkins nanobind::object dunderGetItem(const std::string &name); 126230d61893SAlex Zinenko 126330d61893SAlex Zinenko /// Removes the given operation from the symbol table and erases it. 126430d61893SAlex Zinenko void erase(PyOperationBase &symbol); 126530d61893SAlex Zinenko 126630d61893SAlex Zinenko /// Removes the operation with the given name from the symbol table and erases 126730d61893SAlex Zinenko /// it, throws if there is no such symbol in the table. 126830d61893SAlex Zinenko void dunderDel(const std::string &name); 126930d61893SAlex Zinenko 127030d61893SAlex Zinenko /// Inserts the given operation into the symbol table. The operation must have 127130d61893SAlex Zinenko /// the symbol trait. 1272974c1596SRahul Kayaith MlirAttribute insert(PyOperationBase &symbol); 127330d61893SAlex Zinenko 1274bdc31837SStella Laurenzo /// Gets and sets the name of a symbol op. 1275974c1596SRahul Kayaith static MlirAttribute getSymbolName(PyOperationBase &symbol); 1276bdc31837SStella Laurenzo static void setSymbolName(PyOperationBase &symbol, const std::string &name); 1277bdc31837SStella Laurenzo 1278bdc31837SStella Laurenzo /// Gets and sets the visibility of a symbol op. 1279974c1596SRahul Kayaith static MlirAttribute getVisibility(PyOperationBase &symbol); 1280bdc31837SStella Laurenzo static void setVisibility(PyOperationBase &symbol, 1281bdc31837SStella Laurenzo const std::string &visibility); 1282bdc31837SStella Laurenzo 1283bdc31837SStella Laurenzo /// Replaces all symbol uses within an operation. See the API 1284bdc31837SStella Laurenzo /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1285bdc31837SStella Laurenzo static void replaceAllSymbolUses(const std::string &oldSymbol, 1286bdc31837SStella Laurenzo const std::string &newSymbol, 1287bdc31837SStella Laurenzo PyOperationBase &from); 1288bdc31837SStella Laurenzo 1289bdc31837SStella Laurenzo /// Walks all symbol tables under and including 'from'. 1290bdc31837SStella Laurenzo static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1291b56d1ec6SPeter Hawkins nanobind::object callback); 1292bdc31837SStella Laurenzo 129330d61893SAlex Zinenko /// Casts the bindings class into the C API structure. 129430d61893SAlex Zinenko operator MlirSymbolTable() { return symbolTable; } 129530d61893SAlex Zinenko 129630d61893SAlex Zinenko private: 129730d61893SAlex Zinenko PyOperationRef operation; 129830d61893SAlex Zinenko MlirSymbolTable symbolTable; 129930d61893SAlex Zinenko }; 130030d61893SAlex Zinenko 13013ea4c501SRahul Kayaith /// Custom exception that allows access to error diagnostic information. This is 13023ea4c501SRahul Kayaith /// converted to the `ir.MLIRError` python exception when thrown. 13033ea4c501SRahul Kayaith struct MLIRError { 13043ea4c501SRahul Kayaith MLIRError(llvm::Twine message, 13053ea4c501SRahul Kayaith std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) 13063ea4c501SRahul Kayaith : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} 13073ea4c501SRahul Kayaith std::string message; 13083ea4c501SRahul Kayaith std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; 13093ea4c501SRahul Kayaith }; 13103ea4c501SRahul Kayaith 1311b56d1ec6SPeter Hawkins void populateIRAffine(nanobind::module_ &m); 1312b56d1ec6SPeter Hawkins void populateIRAttributes(nanobind::module_ &m); 1313b56d1ec6SPeter Hawkins void populateIRCore(nanobind::module_ &m); 1314b56d1ec6SPeter Hawkins void populateIRInterfaces(nanobind::module_ &m); 1315b56d1ec6SPeter Hawkins void populateIRTypes(nanobind::module_ &m); 1316436c6c9cSStella Laurenzo 1317436c6c9cSStella Laurenzo } // namespace python 1318436c6c9cSStella Laurenzo } // namespace mlir 1319436c6c9cSStella Laurenzo 1320b56d1ec6SPeter Hawkins namespace nanobind { 1321436c6c9cSStella Laurenzo namespace detail { 1322436c6c9cSStella Laurenzo 1323436c6c9cSStella Laurenzo template <> 1324436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyMlirContext> 1325436c6c9cSStella Laurenzo : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1326436c6c9cSStella Laurenzo template <> 1327436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyLocation> 1328436c6c9cSStella Laurenzo : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1329436c6c9cSStella Laurenzo 1330436c6c9cSStella Laurenzo } // namespace detail 1331b56d1ec6SPeter Hawkins } // namespace nanobind 1332436c6c9cSStella Laurenzo 1333436c6c9cSStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1334