1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8 //===----------------------------------------------------------------------===// 9 10 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 11 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 12 13 #include <optional> 14 #include <utility> 15 #include <vector> 16 17 #include "Globals.h" 18 #include "NanobindUtils.h" 19 #include "mlir-c/AffineExpr.h" 20 #include "mlir-c/AffineMap.h" 21 #include "mlir-c/Diagnostics.h" 22 #include "mlir-c/IR.h" 23 #include "mlir-c/IntegerSet.h" 24 #include "mlir-c/Transforms.h" 25 #include "mlir/Bindings/Python/NanobindAdaptors.h" 26 #include "mlir/Bindings/Python/Nanobind.h" 27 #include "llvm/ADT/DenseMap.h" 28 29 namespace mlir { 30 namespace python { 31 32 class PyBlock; 33 class PyDiagnostic; 34 class PyDiagnosticHandler; 35 class PyInsertionPoint; 36 class PyLocation; 37 class DefaultingPyLocation; 38 class PyMlirContext; 39 class DefaultingPyMlirContext; 40 class PyModule; 41 class PyOperation; 42 class PyOperationBase; 43 class PyType; 44 class PySymbolTable; 45 class PyValue; 46 47 /// Template for a reference to a concrete type which captures a python 48 /// reference to its underlying python object. 49 template <typename T> 50 class PyObjectRef { 51 public: 52 PyObjectRef(T *referrent, nanobind::object object) 53 : referrent(referrent), object(std::move(object)) { 54 assert(this->referrent && 55 "cannot construct PyObjectRef with null referrent"); 56 assert(this->object && "cannot construct PyObjectRef with null object"); 57 } 58 PyObjectRef(PyObjectRef &&other) noexcept 59 : referrent(other.referrent), object(std::move(other.object)) { 60 other.referrent = nullptr; 61 assert(!other.object); 62 } 63 PyObjectRef(const PyObjectRef &other) 64 : referrent(other.referrent), object(other.object /* copies */) {} 65 ~PyObjectRef() = default; 66 67 int getRefCount() { 68 if (!object) 69 return 0; 70 return Py_REFCNT(object.ptr()); 71 } 72 73 /// Releases the object held by this instance, returning it. 74 /// This is the proper thing to return from a function that wants to return 75 /// the reference. Note that this does not work from initializers. 76 nanobind::object releaseObject() { 77 assert(referrent && object); 78 referrent = nullptr; 79 auto stolen = std::move(object); 80 return stolen; 81 } 82 83 T *get() { return referrent; } 84 T *operator->() { 85 assert(referrent && object); 86 return referrent; 87 } 88 nanobind::object getObject() { 89 assert(referrent && object); 90 return object; 91 } 92 operator bool() const { return referrent && object; } 93 94 private: 95 T *referrent; 96 nanobind::object object; 97 }; 98 99 /// Tracks an entry in the thread context stack. New entries are pushed onto 100 /// here for each with block that activates a new InsertionPoint, Context or 101 /// Location. 102 /// 103 /// Pushing either a Location or InsertionPoint also pushes its associated 104 /// Context. Pushing a Context will not modify the Location or InsertionPoint 105 /// unless if they are from a different context, in which case, they are 106 /// cleared. 107 class PyThreadContextEntry { 108 public: 109 enum class FrameKind { 110 Context, 111 InsertionPoint, 112 Location, 113 }; 114 115 PyThreadContextEntry(FrameKind frameKind, nanobind::object context, 116 nanobind::object insertionPoint, 117 nanobind::object location) 118 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 119 location(std::move(location)), frameKind(frameKind) {} 120 121 /// Gets the top of stack context and return nullptr if not defined. 122 static PyMlirContext *getDefaultContext(); 123 124 /// Gets the top of stack insertion point and return nullptr if not defined. 125 static PyInsertionPoint *getDefaultInsertionPoint(); 126 127 /// Gets the top of stack location and returns nullptr if not defined. 128 static PyLocation *getDefaultLocation(); 129 130 PyMlirContext *getContext(); 131 PyInsertionPoint *getInsertionPoint(); 132 PyLocation *getLocation(); 133 FrameKind getFrameKind() { return frameKind; } 134 135 /// Stack management. 136 static PyThreadContextEntry *getTopOfStack(); 137 static nanobind::object pushContext(nanobind::object context); 138 static void popContext(PyMlirContext &context); 139 static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); 140 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 141 static nanobind::object pushLocation(nanobind::object location); 142 static void popLocation(PyLocation &location); 143 144 /// Gets the thread local stack. 145 static std::vector<PyThreadContextEntry> &getStack(); 146 147 private: 148 static void push(FrameKind frameKind, nanobind::object context, 149 nanobind::object insertionPoint, nanobind::object location); 150 151 /// An object reference to the PyContext. 152 nanobind::object context; 153 /// An object reference to the current insertion point. 154 nanobind::object insertionPoint; 155 /// An object reference to the current location. 156 nanobind::object location; 157 // The kind of push that was performed. 158 FrameKind frameKind; 159 }; 160 161 /// Wrapper around MlirContext. 162 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 163 class PyMlirContext { 164 public: 165 PyMlirContext() = delete; 166 PyMlirContext(MlirContext context); 167 PyMlirContext(const PyMlirContext &) = delete; 168 PyMlirContext(PyMlirContext &&) = delete; 169 170 /// For the case of a python __init__ (nanobind::init) method, pybind11 is 171 /// quite strict about needing to return a pointer that is not yet associated 172 /// to an nanobind::object. Since the forContext() method acts like a pool, 173 /// possibly returning a recycled context, it does not satisfy this need. The 174 /// usual way in python to accomplish such a thing is to override __new__, but 175 /// that is also not supported by pybind11. Instead, we use this entry 176 /// point which always constructs a fresh context (which cannot alias an 177 /// existing one because it is fresh). 178 static PyMlirContext *createNewContextForInit(); 179 180 /// Returns a context reference for the singleton PyMlirContext wrapper for 181 /// the given context. 182 static PyMlirContextRef forContext(MlirContext context); 183 ~PyMlirContext(); 184 185 /// Accesses the underlying MlirContext. 186 MlirContext get() { return context; } 187 188 /// Gets a strong reference to this context, which will ensure it is kept 189 /// alive for the life of the reference. 190 PyMlirContextRef getRef() { 191 return PyMlirContextRef(this, nanobind::cast(this)); 192 } 193 194 /// Gets a capsule wrapping the void* within the MlirContext. 195 nanobind::object getCapsule(); 196 197 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 198 /// Note that PyMlirContext instances are uniqued, so the returned object 199 /// may be a pre-existing object. Ownership of the underlying MlirContext 200 /// is taken by calling this function. 201 static nanobind::object createFromCapsule(nanobind::object capsule); 202 203 /// Gets the count of live context objects. Used for testing. 204 static size_t getLiveCount(); 205 206 /// Get a list of Python objects which are still in the live context map. 207 std::vector<PyOperation *> getLiveOperationObjects(); 208 209 /// Gets the count of live operations associated with this context. 210 /// Used for testing. 211 size_t getLiveOperationCount(); 212 213 /// Clears the live operations map, returning the number of entries which were 214 /// invalidated. To be used as a safety mechanism so that API end-users can't 215 /// corrupt by holding references they shouldn't have accessed in the first 216 /// place. 217 size_t clearLiveOperations(); 218 219 /// Removes an operation from the live operations map and sets it invalid. 220 /// This is useful for when some non-bindings code destroys the operation and 221 /// the bindings need to made aware. For example, in the case when pass 222 /// manager is run. 223 /// 224 /// Note that this does *NOT* clear the nested operations. 225 void clearOperation(MlirOperation op); 226 227 /// Clears all operations nested inside the given op using 228 /// `clearOperation(MlirOperation)`. 229 void clearOperationsInside(PyOperationBase &op); 230 void clearOperationsInside(MlirOperation op); 231 232 /// Clears the operaiton _and_ all operations inside using 233 /// `clearOperation(MlirOperation)`. 234 void clearOperationAndInside(PyOperationBase &op); 235 236 /// Gets the count of live modules associated with this context. 237 /// Used for testing. 238 size_t getLiveModuleCount(); 239 240 /// Enter and exit the context manager. 241 static nanobind::object contextEnter(nanobind::object context); 242 void contextExit(const nanobind::object &excType, 243 const nanobind::object &excVal, 244 const nanobind::object &excTb); 245 246 /// Attaches a Python callback as a diagnostic handler, returning a 247 /// registration object (internally a PyDiagnosticHandler). 248 nanobind::object attachDiagnosticHandler(nanobind::object callback); 249 250 /// Controls whether error diagnostics should be propagated to diagnostic 251 /// handlers, instead of being captured by `ErrorCapture`. 252 void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } 253 struct ErrorCapture; 254 255 private: 256 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 257 // preserving the relationship that an MlirContext maps to a single 258 // PyMlirContext wrapper. This could be replaced in the future with an 259 // extension mechanism on the MlirContext for stashing user pointers. 260 // Note that this holds a handle, which does not imply ownership. 261 // Mappings will be removed when the context is destructed. 262 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 263 static nanobind::ft_mutex live_contexts_mutex; 264 static LiveContextMap &getLiveContexts(); 265 266 // Interns all live modules associated with this context. Modules tracked 267 // in this map are valid. When a module is invalidated, it is removed 268 // from this map, and while it still exists as an instance, any 269 // attempt to access it will raise an error. 270 using LiveModuleMap = 271 llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>; 272 LiveModuleMap liveModules; 273 274 // Interns all live operations associated with this context. Operations 275 // tracked in this map are valid. When an operation is invalidated, it is 276 // removed from this map, and while it still exists as an instance, any 277 // attempt to access it will raise an error. 278 using LiveOperationMap = 279 llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>; 280 nanobind::ft_mutex liveOperationsMutex; 281 282 // Guarded by liveOperationsMutex in free-threading mode. 283 LiveOperationMap liveOperations; 284 285 bool emitErrorDiagnostics = false; 286 287 MlirContext context; 288 friend class PyModule; 289 friend class PyOperation; 290 }; 291 292 /// Used in function arguments when None should resolve to the current context 293 /// manager set instance. 294 class DefaultingPyMlirContext 295 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 296 public: 297 using Defaulting::Defaulting; 298 static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 299 static PyMlirContext &resolve(); 300 }; 301 302 /// Base class for all objects that directly or indirectly depend on an 303 /// MlirContext. The lifetime of the context will extend at least to the 304 /// lifetime of these instances. 305 /// Immutable objects that depend on a context extend this directly. 306 class BaseContextObject { 307 public: 308 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 309 assert(this->contextRef && 310 "context object constructed with null context ref"); 311 } 312 313 /// Accesses the context reference. 314 PyMlirContextRef &getContext() { return contextRef; } 315 316 private: 317 PyMlirContextRef contextRef; 318 }; 319 320 /// Wrapper around an MlirLocation. 321 class PyLocation : public BaseContextObject { 322 public: 323 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 324 : BaseContextObject(std::move(contextRef)), loc(loc) {} 325 326 operator MlirLocation() const { return loc; } 327 MlirLocation get() const { return loc; } 328 329 /// Enter and exit the context manager. 330 static nanobind::object contextEnter(nanobind::object location); 331 void contextExit(const nanobind::object &excType, 332 const nanobind::object &excVal, 333 const nanobind::object &excTb); 334 335 /// Gets a capsule wrapping the void* within the MlirLocation. 336 nanobind::object getCapsule(); 337 338 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 339 /// Note that PyLocation instances are uniqued, so the returned object 340 /// may be a pre-existing object. Ownership of the underlying MlirLocation 341 /// is taken by calling this function. 342 static PyLocation createFromCapsule(nanobind::object capsule); 343 344 private: 345 MlirLocation loc; 346 }; 347 348 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 349 /// are only valid for the duration of a diagnostic callback and attempting 350 /// to access them outside of that will raise an exception. This applies to 351 /// nested diagnostics (in the notes) as well. 352 class PyDiagnostic { 353 public: 354 PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 355 void invalidate(); 356 bool isValid() { return valid; } 357 MlirDiagnosticSeverity getSeverity(); 358 PyLocation getLocation(); 359 nanobind::str getMessage(); 360 nanobind::tuple getNotes(); 361 362 /// Materialized diagnostic information. This is safe to access outside the 363 /// diagnostic callback. 364 struct DiagnosticInfo { 365 MlirDiagnosticSeverity severity; 366 PyLocation location; 367 std::string message; 368 std::vector<DiagnosticInfo> notes; 369 }; 370 DiagnosticInfo getInfo(); 371 372 private: 373 MlirDiagnostic diagnostic; 374 375 void checkValid(); 376 /// If notes have been materialized from the diagnostic, then this will 377 /// be populated with the corresponding objects (all castable to 378 /// PyDiagnostic). 379 std::optional<nanobind::tuple> materializedNotes; 380 bool valid = true; 381 }; 382 383 /// Represents a diagnostic handler attached to the context. The handler's 384 /// callback will be invoked with PyDiagnostic instances until the detach() 385 /// method is called or the context is destroyed. A diagnostic handler can be 386 /// the subject of a `with` block, which will detach it when the block exits. 387 /// 388 /// Since diagnostic handlers can call back into Python code which can do 389 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 390 /// etc), this is generally not deemed to be a great user-level API. Users 391 /// should generally use some form of DiagnosticCollector. If the handler raises 392 /// any exceptions, they will just be emitted to stderr and dropped. 393 /// 394 /// The unique usage of this class means that its lifetime management is 395 /// different from most other parts of the API. Instances are always created 396 /// in an attached state and can transition to a detached state by either: 397 /// a) The context being destroyed and unregistering all handlers. 398 /// b) An explicit call to detach(). 399 /// The object may remain live from a Python perspective for an arbitrary time 400 /// after detachment, but there is nothing the user can do with it (since there 401 /// is no way to attach an existing handler object). 402 class PyDiagnosticHandler { 403 public: 404 PyDiagnosticHandler(MlirContext context, nanobind::object callback); 405 ~PyDiagnosticHandler(); 406 407 bool isAttached() { return registeredID.has_value(); } 408 bool getHadError() { return hadError; } 409 410 /// Detaches the handler. Does nothing if not attached. 411 void detach(); 412 413 nanobind::object contextEnter() { return nanobind::cast(this); } 414 void contextExit(const nanobind::object &excType, 415 const nanobind::object &excVal, 416 const nanobind::object &excTb) { 417 detach(); 418 } 419 420 private: 421 MlirContext context; 422 nanobind::object callback; 423 std::optional<MlirDiagnosticHandlerID> registeredID; 424 bool hadError = false; 425 friend class PyMlirContext; 426 }; 427 428 /// RAII object that captures any error diagnostics emitted to the provided 429 /// context. 430 struct PyMlirContext::ErrorCapture { 431 ErrorCapture(PyMlirContextRef ctx) 432 : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( 433 ctx->get(), handler, /*userData=*/this, 434 /*deleteUserData=*/nullptr)) {} 435 ~ErrorCapture() { 436 mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); 437 assert(errors.empty() && "unhandled captured errors"); 438 } 439 440 std::vector<PyDiagnostic::DiagnosticInfo> take() { 441 return std::move(errors); 442 }; 443 444 private: 445 PyMlirContextRef ctx; 446 MlirDiagnosticHandlerID handlerID; 447 std::vector<PyDiagnostic::DiagnosticInfo> errors; 448 449 static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); 450 }; 451 452 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 453 /// order to differentiate it from the `Dialect` base class which is extended by 454 /// plugins which extend dialect functionality through extension python code. 455 /// This should be seen as the "low-level" object and `Dialect` as the 456 /// high-level, user facing object. 457 class PyDialectDescriptor : public BaseContextObject { 458 public: 459 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 460 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 461 462 MlirDialect get() { return dialect; } 463 464 private: 465 MlirDialect dialect; 466 }; 467 468 /// User-level object for accessing dialects with dotted syntax such as: 469 /// ctx.dialect.std 470 class PyDialects : public BaseContextObject { 471 public: 472 PyDialects(PyMlirContextRef contextRef) 473 : BaseContextObject(std::move(contextRef)) {} 474 475 MlirDialect getDialectForKey(const std::string &key, bool attrError); 476 }; 477 478 /// User-level dialect object. For dialects that have a registered extension, 479 /// this will be the base class of the extension dialect type. For un-extended, 480 /// objects of this type will be returned directly. 481 class PyDialect { 482 public: 483 PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} 484 485 nanobind::object getDescriptor() { return descriptor; } 486 487 private: 488 nanobind::object descriptor; 489 }; 490 491 /// Wrapper around an MlirDialectRegistry. 492 /// Upon construction, the Python wrapper takes ownership of the 493 /// underlying MlirDialectRegistry. 494 class PyDialectRegistry { 495 public: 496 PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} 497 PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} 498 ~PyDialectRegistry() { 499 if (!mlirDialectRegistryIsNull(registry)) 500 mlirDialectRegistryDestroy(registry); 501 } 502 PyDialectRegistry(PyDialectRegistry &) = delete; 503 PyDialectRegistry(PyDialectRegistry &&other) noexcept 504 : registry(other.registry) { 505 other.registry = {nullptr}; 506 } 507 508 operator MlirDialectRegistry() const { return registry; } 509 MlirDialectRegistry get() const { return registry; } 510 511 nanobind::object getCapsule(); 512 static PyDialectRegistry createFromCapsule(nanobind::object capsule); 513 514 private: 515 MlirDialectRegistry registry; 516 }; 517 518 /// Used in function arguments when None should resolve to the current context 519 /// manager set instance. 520 class DefaultingPyLocation 521 : public Defaulting<DefaultingPyLocation, PyLocation> { 522 public: 523 using Defaulting::Defaulting; 524 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 525 static PyLocation &resolve(); 526 527 operator MlirLocation() const { return *get(); } 528 }; 529 530 /// Wrapper around MlirModule. 531 /// This is the top-level, user-owned object that contains regions/ops/blocks. 532 class PyModule; 533 using PyModuleRef = PyObjectRef<PyModule>; 534 class PyModule : public BaseContextObject { 535 public: 536 /// Returns a PyModule reference for the given MlirModule. This may return 537 /// a pre-existing or new object. 538 static PyModuleRef forModule(MlirModule module); 539 PyModule(PyModule &) = delete; 540 PyModule(PyMlirContext &&) = delete; 541 ~PyModule(); 542 543 /// Gets the backing MlirModule. 544 MlirModule get() { return module; } 545 546 /// Gets a strong reference to this module. 547 PyModuleRef getRef() { 548 return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle)); 549 } 550 551 /// Gets a capsule wrapping the void* within the MlirModule. 552 /// Note that the module does not (yet) provide a corresponding factory for 553 /// constructing from a capsule as that would require uniquing PyModule 554 /// instances, which is not currently done. 555 nanobind::object getCapsule(); 556 557 /// Creates a PyModule from the MlirModule wrapped by a capsule. 558 /// Note that PyModule instances are uniqued, so the returned object 559 /// may be a pre-existing object. Ownership of the underlying MlirModule 560 /// is taken by calling this function. 561 static nanobind::object createFromCapsule(nanobind::object capsule); 562 563 private: 564 PyModule(PyMlirContextRef contextRef, MlirModule module); 565 MlirModule module; 566 nanobind::handle handle; 567 }; 568 569 class PyAsmState; 570 571 /// Base class for PyOperation and PyOpView which exposes the primary, user 572 /// visible methods for manipulating it. 573 class PyOperationBase { 574 public: 575 virtual ~PyOperationBase() = default; 576 /// Implements the bound 'print' method and helps with others. 577 void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo, 578 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 579 bool assumeVerified, nanobind::object fileObject, bool binary, 580 bool skipRegions); 581 void print(PyAsmState &state, nanobind::object fileObject, bool binary); 582 583 nanobind::object getAsm(bool binary, 584 std::optional<int64_t> largeElementsLimit, 585 bool enableDebugInfo, bool prettyDebugInfo, 586 bool printGenericOpForm, bool useLocalScope, 587 bool assumeVerified, bool skipRegions); 588 589 // Implement the bound 'writeBytecode' method. 590 void writeBytecode(const nanobind::object &fileObject, 591 std::optional<int64_t> bytecodeVersion); 592 593 // Implement the walk method. 594 void walk(std::function<MlirWalkResult(MlirOperation)> callback, 595 MlirWalkOrder walkOrder); 596 597 /// Moves the operation before or after the other operation. 598 void moveAfter(PyOperationBase &other); 599 void moveBefore(PyOperationBase &other); 600 601 /// Verify the operation. Throws `MLIRError` if verification fails, and 602 /// returns `true` otherwise. 603 bool verify(); 604 605 /// Each must provide access to the raw Operation. 606 virtual PyOperation &getOperation() = 0; 607 }; 608 609 /// Wrapper around PyOperation. 610 /// Operations exist in either an attached (dependent) or detached (top-level) 611 /// state. In the detached state (as on creation), an operation is owned by 612 /// the creator and its lifetime extends either until its reference count 613 /// drops to zero or it is attached to a parent, at which point its lifetime 614 /// is bounded by its top-level parent reference. 615 class PyOperation; 616 using PyOperationRef = PyObjectRef<PyOperation>; 617 class PyOperation : public PyOperationBase, public BaseContextObject { 618 public: 619 ~PyOperation() override; 620 PyOperation &getOperation() override { return *this; } 621 622 /// Returns a PyOperation for the given MlirOperation, optionally associating 623 /// it with a parentKeepAlive. 624 static PyOperationRef 625 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 626 nanobind::object parentKeepAlive = nanobind::object()); 627 628 /// Creates a detached operation. The operation must not be associated with 629 /// any existing live operation. 630 static PyOperationRef 631 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 632 nanobind::object parentKeepAlive = nanobind::object()); 633 634 /// Parses a source string (either text assembly or bytecode), creating a 635 /// detached operation. 636 static PyOperationRef parse(PyMlirContextRef contextRef, 637 const std::string &sourceStr, 638 const std::string &sourceName); 639 640 /// Detaches the operation from its parent block and updates its state 641 /// accordingly. 642 void detachFromParent() { 643 mlirOperationRemoveFromParent(getOperation()); 644 setDetached(); 645 parentKeepAlive = nanobind::object(); 646 } 647 648 /// Gets the backing operation. 649 operator MlirOperation() const { return get(); } 650 MlirOperation get() const { 651 checkValid(); 652 return operation; 653 } 654 655 PyOperationRef getRef() { 656 return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle)); 657 } 658 659 bool isAttached() { return attached; } 660 void setAttached(const nanobind::object &parent = nanobind::object()) { 661 assert(!attached && "operation already attached"); 662 attached = true; 663 } 664 void setDetached() { 665 assert(attached && "operation already detached"); 666 attached = false; 667 } 668 void checkValid() const; 669 670 /// Gets the owning block or raises an exception if the operation has no 671 /// owning block. 672 PyBlock getBlock(); 673 674 /// Gets the parent operation or raises an exception if the operation has 675 /// no parent. 676 std::optional<PyOperationRef> getParentOperation(); 677 678 /// Gets a capsule wrapping the void* within the MlirOperation. 679 nanobind::object getCapsule(); 680 681 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 682 /// Ownership of the underlying MlirOperation is taken by calling this 683 /// function. 684 static nanobind::object createFromCapsule(nanobind::object capsule); 685 686 /// Creates an operation. See corresponding python docstring. 687 static nanobind::object 688 create(std::string_view name, std::optional<std::vector<PyType *>> results, 689 llvm::ArrayRef<MlirValue> operands, 690 std::optional<nanobind::dict> attributes, 691 std::optional<std::vector<PyBlock *>> successors, int regions, 692 DefaultingPyLocation location, const nanobind::object &ip, 693 bool inferType); 694 695 /// Creates an OpView suitable for this operation. 696 nanobind::object createOpView(); 697 698 /// Erases the underlying MlirOperation, removes its pointer from the 699 /// parent context's live operations map, and sets the valid bit false. 700 void erase(); 701 702 /// Invalidate the operation. 703 void setInvalid() { valid = false; } 704 705 /// Clones this operation. 706 nanobind::object clone(const nanobind::object &ip); 707 708 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 709 710 private: 711 static PyOperationRef createInstance(PyMlirContextRef contextRef, 712 MlirOperation operation, 713 nanobind::object parentKeepAlive); 714 715 MlirOperation operation; 716 nanobind::handle handle; 717 // Keeps the parent alive, regardless of whether it is an Operation or 718 // Module. 719 // TODO: As implemented, this facility is only sufficient for modeling the 720 // trivial module parent back-reference. Generalize this to also account for 721 // transitions from detached to attached and address TODOs in the 722 // ir_operation.py regarding testing corresponding lifetime guarantees. 723 nanobind::object parentKeepAlive; 724 bool attached = true; 725 bool valid = true; 726 727 friend class PyOperationBase; 728 friend class PySymbolTable; 729 }; 730 731 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 732 /// providing more instance-specific accessors and serve as the base class for 733 /// custom ODS-style operation classes. Since this class is subclass on the 734 /// python side, it must present an __init__ method that operates in pure 735 /// python types. 736 class PyOpView : public PyOperationBase { 737 public: 738 PyOpView(const nanobind::object &operationObject); 739 PyOperation &getOperation() override { return operation; } 740 741 nanobind::object getOperationObject() { return operationObject; } 742 743 static nanobind::object 744 buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec, 745 nanobind::object operandSegmentSpecObj, 746 nanobind::object resultSegmentSpecObj, 747 std::optional<nanobind::list> resultTypeList, 748 nanobind::list operandList, 749 std::optional<nanobind::dict> attributes, 750 std::optional<std::vector<PyBlock *>> successors, 751 std::optional<int> regions, DefaultingPyLocation location, 752 const nanobind::object &maybeIp); 753 754 /// Construct an instance of a class deriving from OpView, bypassing its 755 /// `__init__` method. The derived class will typically define a constructor 756 /// that provides a convenient builder, but we need to side-step this when 757 /// constructing an `OpView` for an already-built operation. 758 /// 759 /// The caller is responsible for verifying that `operation` is a valid 760 /// operation to construct `cls` with. 761 static nanobind::object constructDerived(const nanobind::object &cls, 762 const nanobind::object &operation); 763 764 private: 765 PyOperation &operation; // For efficient, cast-free access from C++ 766 nanobind::object operationObject; // Holds the reference. 767 }; 768 769 /// Wrapper around an MlirRegion. 770 /// Regions are managed completely by their containing operation. Unlike the 771 /// C++ API, the python API does not support detached regions. 772 class PyRegion { 773 public: 774 PyRegion(PyOperationRef parentOperation, MlirRegion region) 775 : parentOperation(std::move(parentOperation)), region(region) { 776 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 777 } 778 operator MlirRegion() const { return region; } 779 780 MlirRegion get() { return region; } 781 PyOperationRef &getParentOperation() { return parentOperation; } 782 783 void checkValid() { return parentOperation->checkValid(); } 784 785 private: 786 PyOperationRef parentOperation; 787 MlirRegion region; 788 }; 789 790 /// Wrapper around an MlirAsmState. 791 class PyAsmState { 792 public: 793 PyAsmState(MlirValue value, bool useLocalScope) { 794 flags = mlirOpPrintingFlagsCreate(); 795 // The OpPrintingFlags are not exposed Python side, create locally and 796 // associate lifetime with the state. 797 if (useLocalScope) 798 mlirOpPrintingFlagsUseLocalScope(flags); 799 state = mlirAsmStateCreateForValue(value, flags); 800 } 801 802 PyAsmState(PyOperationBase &operation, bool useLocalScope) { 803 flags = mlirOpPrintingFlagsCreate(); 804 // The OpPrintingFlags are not exposed Python side, create locally and 805 // associate lifetime with the state. 806 if (useLocalScope) 807 mlirOpPrintingFlagsUseLocalScope(flags); 808 state = 809 mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); 810 } 811 ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } 812 // Delete copy constructors. 813 PyAsmState(PyAsmState &other) = delete; 814 PyAsmState(const PyAsmState &other) = delete; 815 816 MlirAsmState get() { return state; } 817 818 private: 819 MlirAsmState state; 820 MlirOpPrintingFlags flags; 821 }; 822 823 /// Wrapper around an MlirBlock. 824 /// Blocks are managed completely by their containing operation. Unlike the 825 /// C++ API, the python API does not support detached blocks. 826 class PyBlock { 827 public: 828 PyBlock(PyOperationRef parentOperation, MlirBlock block) 829 : parentOperation(std::move(parentOperation)), block(block) { 830 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 831 } 832 833 MlirBlock get() { return block; } 834 PyOperationRef &getParentOperation() { return parentOperation; } 835 836 void checkValid() { return parentOperation->checkValid(); } 837 838 /// Gets a capsule wrapping the void* within the MlirBlock. 839 nanobind::object getCapsule(); 840 841 private: 842 PyOperationRef parentOperation; 843 MlirBlock block; 844 }; 845 846 /// An insertion point maintains a pointer to a Block and a reference operation. 847 /// Calls to insert() will insert a new operation before the 848 /// reference operation. If the reference operation is null, then appends to 849 /// the end of the block. 850 class PyInsertionPoint { 851 public: 852 /// Creates an insertion point positioned after the last operation in the 853 /// block, but still inside the block. 854 PyInsertionPoint(PyBlock &block); 855 /// Creates an insertion point positioned before a reference operation. 856 PyInsertionPoint(PyOperationBase &beforeOperationBase); 857 858 /// Shortcut to create an insertion point at the beginning of the block. 859 static PyInsertionPoint atBlockBegin(PyBlock &block); 860 /// Shortcut to create an insertion point before the block terminator. 861 static PyInsertionPoint atBlockTerminator(PyBlock &block); 862 863 /// Inserts an operation. 864 void insert(PyOperationBase &operationBase); 865 866 /// Enter and exit the context manager. 867 static nanobind::object contextEnter(nanobind::object insertionPoint); 868 void contextExit(const nanobind::object &excType, 869 const nanobind::object &excVal, 870 const nanobind::object &excTb); 871 872 PyBlock &getBlock() { return block; } 873 std::optional<PyOperationRef> &getRefOperation() { return refOperation; } 874 875 private: 876 // Trampoline constructor that avoids null initializing members while 877 // looking up parents. 878 PyInsertionPoint(PyBlock block, std::optional<PyOperationRef> refOperation) 879 : refOperation(std::move(refOperation)), block(std::move(block)) {} 880 881 std::optional<PyOperationRef> refOperation; 882 PyBlock block; 883 }; 884 /// Wrapper around the generic MlirType. 885 /// The lifetime of a type is bound by the PyContext that created it. 886 class PyType : public BaseContextObject { 887 public: 888 PyType(PyMlirContextRef contextRef, MlirType type) 889 : BaseContextObject(std::move(contextRef)), type(type) {} 890 bool operator==(const PyType &other) const; 891 operator MlirType() const { return type; } 892 MlirType get() const { return type; } 893 894 /// Gets a capsule wrapping the void* within the MlirType. 895 nanobind::object getCapsule(); 896 897 /// Creates a PyType from the MlirType wrapped by a capsule. 898 /// Note that PyType instances are uniqued, so the returned object 899 /// may be a pre-existing object. Ownership of the underlying MlirType 900 /// is taken by calling this function. 901 static PyType createFromCapsule(nanobind::object capsule); 902 903 private: 904 MlirType type; 905 }; 906 907 /// A TypeID provides an efficient and unique identifier for a specific C++ 908 /// type. This allows for a C++ type to be compared, hashed, and stored in an 909 /// opaque context. This class wraps around the generic MlirTypeID. 910 class PyTypeID { 911 public: 912 PyTypeID(MlirTypeID typeID) : typeID(typeID) {} 913 // Note, this tests whether the underlying TypeIDs are the same, 914 // not whether the wrapper MlirTypeIDs are the same, nor whether 915 // the PyTypeID objects are the same (i.e., PyTypeID is a value type). 916 bool operator==(const PyTypeID &other) const; 917 operator MlirTypeID() const { return typeID; } 918 MlirTypeID get() { return typeID; } 919 920 /// Gets a capsule wrapping the void* within the MlirTypeID. 921 nanobind::object getCapsule(); 922 923 /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 924 static PyTypeID createFromCapsule(nanobind::object capsule); 925 926 private: 927 MlirTypeID typeID; 928 }; 929 930 /// CRTP base classes for Python types that subclass Type and should be 931 /// castable from it (i.e. via something like IntegerType(t)). 932 /// By default, type class hierarchies are one level deep (i.e. a 933 /// concrete type class extends PyType); however, intermediate python-visible 934 /// base classes can be modeled by specifying a BaseTy. 935 template <typename DerivedTy, typename BaseTy = PyType> 936 class PyConcreteType : public BaseTy { 937 public: 938 // Derived classes must define statics for: 939 // IsAFunctionTy isaFunction 940 // const char *pyClassName 941 using ClassTy = nanobind::class_<DerivedTy, BaseTy>; 942 using IsAFunctionTy = bool (*)(MlirType); 943 using GetTypeIDFunctionTy = MlirTypeID (*)(); 944 static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; 945 946 PyConcreteType() = default; 947 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 948 : BaseTy(std::move(contextRef), t) {} 949 PyConcreteType(PyType &orig) 950 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 951 952 static MlirType castFrom(PyType &orig) { 953 if (!DerivedTy::isaFunction(orig)) { 954 auto origRepr = 955 nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); 956 throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + 957 DerivedTy::pyClassName + " (from " + 958 origRepr + ")") 959 .str() 960 .c_str()); 961 } 962 return orig; 963 } 964 965 static void bind(nanobind::module_ &m) { 966 auto cls = ClassTy(m, DerivedTy::pyClassName); 967 cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(), 968 nanobind::arg("cast_from_type")); 969 cls.def_static( 970 "isinstance", 971 [](PyType &otherType) -> bool { 972 return DerivedTy::isaFunction(otherType); 973 }, 974 nanobind::arg("other")); 975 cls.def_prop_ro_static( 976 "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { 977 if (DerivedTy::getTypeIdFunction) 978 return DerivedTy::getTypeIdFunction(); 979 throw nanobind::attribute_error( 980 (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) 981 .str() 982 .c_str()); 983 }); 984 cls.def_prop_ro("typeid", [](PyType &self) { 985 return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); 986 }); 987 cls.def("__repr__", [](DerivedTy &self) { 988 PyPrintAccumulator printAccum; 989 printAccum.parts.append(DerivedTy::pyClassName); 990 printAccum.parts.append("("); 991 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); 992 printAccum.parts.append(")"); 993 return printAccum.join(); 994 }); 995 996 if (DerivedTy::getTypeIdFunction) { 997 PyGlobals::get().registerTypeCaster( 998 DerivedTy::getTypeIdFunction(), 999 nanobind::cast<nanobind::callable>(nanobind::cpp_function( 1000 [](PyType pyType) -> DerivedTy { return pyType; }))); 1001 } 1002 1003 DerivedTy::bindDerived(cls); 1004 } 1005 1006 /// Implemented by derived classes to add methods to the Python subclass. 1007 static void bindDerived(ClassTy &m) {} 1008 }; 1009 1010 /// Wrapper around the generic MlirAttribute. 1011 /// The lifetime of a type is bound by the PyContext that created it. 1012 class PyAttribute : public BaseContextObject { 1013 public: 1014 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 1015 : BaseContextObject(std::move(contextRef)), attr(attr) {} 1016 bool operator==(const PyAttribute &other) const; 1017 operator MlirAttribute() const { return attr; } 1018 MlirAttribute get() const { return attr; } 1019 1020 /// Gets a capsule wrapping the void* within the MlirAttribute. 1021 nanobind::object getCapsule(); 1022 1023 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 1024 /// Note that PyAttribute instances are uniqued, so the returned object 1025 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 1026 /// is taken by calling this function. 1027 static PyAttribute createFromCapsule(nanobind::object capsule); 1028 1029 private: 1030 MlirAttribute attr; 1031 }; 1032 1033 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 1034 /// TODO: Refactor this and the C-API to be based on an Identifier owned 1035 /// by the context so as to avoid ownership issues here. 1036 class PyNamedAttribute { 1037 public: 1038 /// Constructs a PyNamedAttr that retains an owned name. This should be 1039 /// used in any code that originates an MlirNamedAttribute from a python 1040 /// string. 1041 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 1042 /// passed attribute. 1043 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 1044 1045 MlirNamedAttribute namedAttr; 1046 1047 private: 1048 // Since the MlirNamedAttr contains an internal pointer to the actual 1049 // memory of the owned string, it must be heap allocated to remain valid. 1050 // Otherwise, strings that fit within the small object optimization threshold 1051 // will have their memory address change as the containing object is moved, 1052 // resulting in an invalid aliased pointer. 1053 std::unique_ptr<std::string> ownedName; 1054 }; 1055 1056 /// CRTP base classes for Python attributes that subclass Attribute and should 1057 /// be castable from it (i.e. via something like StringAttr(attr)). 1058 /// By default, attribute class hierarchies are one level deep (i.e. a 1059 /// concrete attribute class extends PyAttribute); however, intermediate 1060 /// python-visible base classes can be modeled by specifying a BaseTy. 1061 template <typename DerivedTy, typename BaseTy = PyAttribute> 1062 class PyConcreteAttribute : public BaseTy { 1063 public: 1064 // Derived classes must define statics for: 1065 // IsAFunctionTy isaFunction 1066 // const char *pyClassName 1067 using ClassTy = nanobind::class_<DerivedTy, BaseTy>; 1068 using IsAFunctionTy = bool (*)(MlirAttribute); 1069 using GetTypeIDFunctionTy = MlirTypeID (*)(); 1070 static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; 1071 1072 PyConcreteAttribute() = default; 1073 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 1074 : BaseTy(std::move(contextRef), attr) {} 1075 PyConcreteAttribute(PyAttribute &orig) 1076 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 1077 1078 static MlirAttribute castFrom(PyAttribute &orig) { 1079 if (!DerivedTy::isaFunction(orig)) { 1080 auto origRepr = 1081 nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig))); 1082 throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + 1083 DerivedTy::pyClassName + " (from " + 1084 origRepr + ")") 1085 .str() 1086 .c_str()); 1087 } 1088 return orig; 1089 } 1090 1091 static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { 1092 ClassTy cls; 1093 if (slots) { 1094 cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); 1095 } else { 1096 cls = ClassTy(m, DerivedTy::pyClassName); 1097 } 1098 cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(), 1099 nanobind::arg("cast_from_attr")); 1100 cls.def_static( 1101 "isinstance", 1102 [](PyAttribute &otherAttr) -> bool { 1103 return DerivedTy::isaFunction(otherAttr); 1104 }, 1105 nanobind::arg("other")); 1106 cls.def_prop_ro( 1107 "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); 1108 cls.def_prop_ro_static( 1109 "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { 1110 if (DerivedTy::getTypeIdFunction) 1111 return DerivedTy::getTypeIdFunction(); 1112 throw nanobind::attribute_error( 1113 (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) 1114 .str() 1115 .c_str()); 1116 }); 1117 cls.def_prop_ro("typeid", [](PyAttribute &self) { 1118 return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid")); 1119 }); 1120 cls.def("__repr__", [](DerivedTy &self) { 1121 PyPrintAccumulator printAccum; 1122 printAccum.parts.append(DerivedTy::pyClassName); 1123 printAccum.parts.append("("); 1124 mlirAttributePrint(self, printAccum.getCallback(), 1125 printAccum.getUserData()); 1126 printAccum.parts.append(")"); 1127 return printAccum.join(); 1128 }); 1129 1130 if (DerivedTy::getTypeIdFunction) { 1131 PyGlobals::get().registerTypeCaster( 1132 DerivedTy::getTypeIdFunction(), 1133 nanobind::cast<nanobind::callable>( 1134 nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { 1135 return pyAttribute; 1136 }))); 1137 } 1138 1139 DerivedTy::bindDerived(cls); 1140 } 1141 1142 /// Implemented by derived classes to add methods to the Python subclass. 1143 static void bindDerived(ClassTy &m) {} 1144 }; 1145 1146 /// Wrapper around the generic MlirValue. 1147 /// Values are managed completely by the operation that resulted in their 1148 /// definition. For op result value, this is the operation that defines the 1149 /// value. For block argument values, this is the operation that contains the 1150 /// block to which the value is an argument (blocks cannot be detached in Python 1151 /// bindings so such operation always exists). 1152 class PyValue { 1153 public: 1154 // The virtual here is "load bearing" in that it enables RTTI 1155 // for PyConcreteValue CRTP classes that support maybeDownCast. 1156 // See PyValue::maybeDownCast. 1157 virtual ~PyValue() = default; 1158 PyValue(PyOperationRef parentOperation, MlirValue value) 1159 : parentOperation(std::move(parentOperation)), value(value) {} 1160 operator MlirValue() const { return value; } 1161 1162 MlirValue get() { return value; } 1163 PyOperationRef &getParentOperation() { return parentOperation; } 1164 1165 void checkValid() { return parentOperation->checkValid(); } 1166 1167 /// Gets a capsule wrapping the void* within the MlirValue. 1168 nanobind::object getCapsule(); 1169 1170 nanobind::object maybeDownCast(); 1171 1172 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 1173 /// the underlying MlirValue is still tied to the owning operation. 1174 static PyValue createFromCapsule(nanobind::object capsule); 1175 1176 private: 1177 PyOperationRef parentOperation; 1178 MlirValue value; 1179 }; 1180 1181 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 1182 class PyAffineExpr : public BaseContextObject { 1183 public: 1184 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 1185 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 1186 bool operator==(const PyAffineExpr &other) const; 1187 operator MlirAffineExpr() const { return affineExpr; } 1188 MlirAffineExpr get() const { return affineExpr; } 1189 1190 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 1191 nanobind::object getCapsule(); 1192 1193 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 1194 /// Note that PyAffineExpr instances are uniqued, so the returned object 1195 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 1196 /// is taken by calling this function. 1197 static PyAffineExpr createFromCapsule(nanobind::object capsule); 1198 1199 PyAffineExpr add(const PyAffineExpr &other) const; 1200 PyAffineExpr mul(const PyAffineExpr &other) const; 1201 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 1202 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 1203 PyAffineExpr mod(const PyAffineExpr &other) const; 1204 1205 private: 1206 MlirAffineExpr affineExpr; 1207 }; 1208 1209 class PyAffineMap : public BaseContextObject { 1210 public: 1211 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 1212 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 1213 bool operator==(const PyAffineMap &other) const; 1214 operator MlirAffineMap() const { return affineMap; } 1215 MlirAffineMap get() const { return affineMap; } 1216 1217 /// Gets a capsule wrapping the void* within the MlirAffineMap. 1218 nanobind::object getCapsule(); 1219 1220 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 1221 /// Note that PyAffineMap instances are uniqued, so the returned object 1222 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 1223 /// is taken by calling this function. 1224 static PyAffineMap createFromCapsule(nanobind::object capsule); 1225 1226 private: 1227 MlirAffineMap affineMap; 1228 }; 1229 1230 class PyIntegerSet : public BaseContextObject { 1231 public: 1232 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 1233 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 1234 bool operator==(const PyIntegerSet &other) const; 1235 operator MlirIntegerSet() const { return integerSet; } 1236 MlirIntegerSet get() const { return integerSet; } 1237 1238 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 1239 nanobind::object getCapsule(); 1240 1241 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 1242 /// Note that PyIntegerSet instances may be uniqued, so the returned object 1243 /// may be a pre-existing object. Integer sets are owned by the context. 1244 static PyIntegerSet createFromCapsule(nanobind::object capsule); 1245 1246 private: 1247 MlirIntegerSet integerSet; 1248 }; 1249 1250 /// Bindings for MLIR symbol tables. 1251 class PySymbolTable { 1252 public: 1253 /// Constructs a symbol table for the given operation. 1254 explicit PySymbolTable(PyOperationBase &operation); 1255 1256 /// Destroys the symbol table. 1257 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 1258 1259 /// Returns the symbol (opview) with the given name, throws if there is no 1260 /// such symbol in the table. 1261 nanobind::object dunderGetItem(const std::string &name); 1262 1263 /// Removes the given operation from the symbol table and erases it. 1264 void erase(PyOperationBase &symbol); 1265 1266 /// Removes the operation with the given name from the symbol table and erases 1267 /// it, throws if there is no such symbol in the table. 1268 void dunderDel(const std::string &name); 1269 1270 /// Inserts the given operation into the symbol table. The operation must have 1271 /// the symbol trait. 1272 MlirAttribute insert(PyOperationBase &symbol); 1273 1274 /// Gets and sets the name of a symbol op. 1275 static MlirAttribute getSymbolName(PyOperationBase &symbol); 1276 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 1277 1278 /// Gets and sets the visibility of a symbol op. 1279 static MlirAttribute getVisibility(PyOperationBase &symbol); 1280 static void setVisibility(PyOperationBase &symbol, 1281 const std::string &visibility); 1282 1283 /// Replaces all symbol uses within an operation. See the API 1284 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1285 static void replaceAllSymbolUses(const std::string &oldSymbol, 1286 const std::string &newSymbol, 1287 PyOperationBase &from); 1288 1289 /// Walks all symbol tables under and including 'from'. 1290 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1291 nanobind::object callback); 1292 1293 /// Casts the bindings class into the C API structure. 1294 operator MlirSymbolTable() { return symbolTable; } 1295 1296 private: 1297 PyOperationRef operation; 1298 MlirSymbolTable symbolTable; 1299 }; 1300 1301 /// Custom exception that allows access to error diagnostic information. This is 1302 /// converted to the `ir.MLIRError` python exception when thrown. 1303 struct MLIRError { 1304 MLIRError(llvm::Twine message, 1305 std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) 1306 : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} 1307 std::string message; 1308 std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; 1309 }; 1310 1311 void populateIRAffine(nanobind::module_ &m); 1312 void populateIRAttributes(nanobind::module_ &m); 1313 void populateIRCore(nanobind::module_ &m); 1314 void populateIRInterfaces(nanobind::module_ &m); 1315 void populateIRTypes(nanobind::module_ &m); 1316 1317 } // namespace python 1318 } // namespace mlir 1319 1320 namespace nanobind { 1321 namespace detail { 1322 1323 template <> 1324 struct type_caster<mlir::python::DefaultingPyMlirContext> 1325 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1326 template <> 1327 struct type_caster<mlir::python::DefaultingPyLocation> 1328 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1329 1330 } // namespace detail 1331 } // namespace nanobind 1332 1333 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1334