1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <optional> 10 #include <utility> 11 12 #include "Globals.h" 13 #include "IRModule.h" 14 #include "NanobindUtils.h" 15 #include "mlir-c/BuiltinAttributes.h" 16 #include "mlir-c/Debug.h" 17 #include "mlir-c/Diagnostics.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/Support.h" 20 #include "mlir/Bindings/Python/Nanobind.h" 21 #include "mlir/Bindings/Python/NanobindAdaptors.h" 22 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 23 #include "llvm/ADT/ArrayRef.h" 24 #include "llvm/ADT/SmallVector.h" 25 26 namespace nb = nanobind; 27 using namespace nb::literals; 28 using namespace mlir; 29 using namespace mlir::python; 30 31 using llvm::SmallVector; 32 using llvm::StringRef; 33 using llvm::Twine; 34 35 //------------------------------------------------------------------------------ 36 // Docstrings (trivial, non-duplicated docstrings are included inline). 37 //------------------------------------------------------------------------------ 38 39 static const char kContextParseTypeDocstring[] = 40 R"(Parses the assembly form of a type. 41 42 Returns a Type object or raises an MLIRError if the type cannot be parsed. 
43 44 See also: https://mlir.llvm.org/docs/LangRef/#type-system 45 )"; 46 47 static const char kContextGetCallSiteLocationDocstring[] = 48 R"(Gets a Location representing a caller and callsite)"; 49 50 static const char kContextGetFileLocationDocstring[] = 51 R"(Gets a Location representing a file, line and column)"; 52 53 static const char kContextGetFileRangeDocstring[] = 54 R"(Gets a Location representing a file, line and column range)"; 55 56 static const char kContextGetFusedLocationDocstring[] = 57 R"(Gets a Location representing a fused location with optional metadata)"; 58 59 static const char kContextGetNameLocationDocString[] = 60 R"(Gets a Location representing a named location with optional child location)"; 61 62 static const char kModuleParseDocstring[] = 63 R"(Parses a module's assembly format from a string. 64 65 Returns a new MlirModule or raises an MLIRError if the parsing fails. 66 67 See also: https://mlir.llvm.org/docs/LangRef/ 68 )"; 69 70 static const char kOperationCreateDocstring[] = 71 R"(Creates a new operation. 72 73 Args: 74 name: Operation name (e.g. "dialect.operation"). 75 results: Sequence of Type representing op result types. 76 attributes: Dict of str:Attribute. 77 successors: List of Block for the operation's successors. 78 regions: Number of regions to create. 79 location: A Location object (defaults to resolve from context manager). 80 ip: An InsertionPoint (defaults to resolve from context manager or set to 81 False to disable insertion, even with an insertion point set in the 82 context manager). 83 infer_type: Whether to infer result types. 84 Returns: 85 A new "detached" Operation object. Detached operations can be added 86 to blocks, which causes them to become "attached." 87 )"; 88 89 static const char kOperationPrintDocstring[] = 90 R"(Prints the assembly form of the operation to a file like object. 91 92 Args: 93 file: The file like object to write to. Defaults to sys.stdout. 
94 binary: Whether to write bytes (True) or str (False). Defaults to False. 95 large_elements_limit: Whether to elide elements attributes above this 96 number of elements. Defaults to None (no limit). 97 enable_debug_info: Whether to print debug/location information. Defaults 98 to False. 99 pretty_debug_info: Whether to format debug information for easier reading 100 by a human (warning: the result is unparseable). 101 print_generic_op_form: Whether to print the generic assembly forms of all 102 ops. Defaults to False. 103 use_local_Scope: Whether to print in a way that is more optimized for 104 multi-threaded access but may not be consistent with how the overall 105 module prints. 106 assume_verified: By default, if not printing generic form, the verifier 107 will be run and if it fails, generic form will be printed with a comment 108 about failed verification. While a reasonable default for interactive use, 109 for systematic use, it is often better for the caller to verify explicitly 110 and report failures in a more robust fashion. Set this to True if doing this 111 in order to avoid running a redundant verification. If the IR is actually 112 invalid, behavior is undefined. 113 skip_regions: Whether to skip printing regions. Defaults to False. 114 )"; 115 116 static const char kOperationPrintStateDocstring[] = 117 R"(Prints the assembly form of the operation to a file like object. 118 119 Args: 120 file: The file like object to write to. Defaults to sys.stdout. 121 binary: Whether to write bytes (True) or str (False). Defaults to False. 122 state: AsmState capturing the operation numbering and flags. 123 )"; 124 125 static const char kOperationGetAsmDocstring[] = 126 R"(Gets the assembly form of the operation with all options available. 127 128 Args: 129 binary: Whether to return a bytes (True) or str (False) object. Defaults to 130 False. 131 ... others ...: See the print() method for common keyword arguments for 132 configuring the printout. 
133 Returns: 134 Either a bytes or str object, depending on the setting of the 'binary' 135 argument. 136 )"; 137 138 static const char kOperationPrintBytecodeDocstring[] = 139 R"(Write the bytecode form of the operation to a file like object. 140 141 Args: 142 file: The file like object to write to. 143 desired_version: The version of bytecode to emit. 144 Returns: 145 The bytecode writer status. 146 )"; 147 148 static const char kOperationStrDunderDocstring[] = 149 R"(Gets the assembly form of the operation with default options. 150 151 If more advanced control over the assembly formatting or I/O options is needed, 152 use the dedicated print or get_asm method, which supports keyword arguments to 153 customize behavior. 154 )"; 155 156 static const char kDumpDocstring[] = 157 R"(Dumps a debug representation of the object to stderr.)"; 158 159 static const char kAppendBlockDocstring[] = 160 R"(Appends a new block, with argument types as positional args. 161 162 Returns: 163 The created block. 164 )"; 165 166 static const char kValueDunderStrDocstring[] = 167 R"(Returns the string form of the value. 168 169 If the value is a block argument, this is the assembly form of its type and the 170 position in the argument list. If the value is an operation result, this is 171 equivalent to printing the operation that produced it. 172 )"; 173 174 static const char kGetNameAsOperand[] = 175 R"(Returns the string form of value as an operand (i.e., the ValueID). 176 )"; 177 178 static const char kValueReplaceAllUsesWithDocstring[] = 179 R"(Replace all uses of value with the new value, updating anything in 180 the IR that uses 'self' to use the other value instead. 181 )"; 182 183 static const char kValueReplaceAllUsesExceptDocstring[] = 184 R"("Replace all uses of this value with the 'with' value, except for those 185 in 'exceptions'. 'exceptions' can be either a single operation or a list of 186 operations. 
187 )"; 188 189 //------------------------------------------------------------------------------ 190 // Utilities. 191 //------------------------------------------------------------------------------ 192 193 /// Helper for creating an @classmethod. 194 template <class Func, typename... Args> 195 nb::object classmethod(Func f, Args... args) { 196 nb::object cf = nb::cpp_function(f, args...); 197 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr()))); 198 } 199 200 static nb::object 201 createCustomDialectWrapper(const std::string &dialectNamespace, 202 nb::object dialectDescriptor) { 203 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); 204 if (!dialectClass) { 205 // Use the base class. 206 return nb::cast(PyDialect(std::move(dialectDescriptor))); 207 } 208 209 // Create the custom implementation. 210 return (*dialectClass)(std::move(dialectDescriptor)); 211 } 212 213 static MlirStringRef toMlirStringRef(const std::string &s) { 214 return mlirStringRefCreate(s.data(), s.size()); 215 } 216 217 static MlirStringRef toMlirStringRef(std::string_view s) { 218 return mlirStringRefCreate(s.data(), s.size()); 219 } 220 221 static MlirStringRef toMlirStringRef(const nb::bytes &s) { 222 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); 223 } 224 225 /// Create a block, using the current location context if no locations are 226 /// specified. 
227 static MlirBlock createBlock(const nb::sequence &pyArgTypes, 228 const std::optional<nb::sequence> &pyArgLocs) { 229 SmallVector<MlirType> argTypes; 230 argTypes.reserve(nb::len(pyArgTypes)); 231 for (const auto &pyType : pyArgTypes) 232 argTypes.push_back(nb::cast<PyType &>(pyType)); 233 234 SmallVector<MlirLocation> argLocs; 235 if (pyArgLocs) { 236 argLocs.reserve(nb::len(*pyArgLocs)); 237 for (const auto &pyLoc : *pyArgLocs) 238 argLocs.push_back(nb::cast<PyLocation &>(pyLoc)); 239 } else if (!argTypes.empty()) { 240 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); 241 } 242 243 if (argTypes.size() != argLocs.size()) 244 throw nb::value_error(("Expected " + Twine(argTypes.size()) + 245 " locations, got: " + Twine(argLocs.size())) 246 .str() 247 .c_str()); 248 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); 249 } 250 251 /// Wrapper for the global LLVM debugging flag. 252 struct PyGlobalDebugFlag { 253 static void set(nb::object &o, bool enable) { 254 nb::ft_lock_guard lock(mutex); 255 mlirEnableGlobalDebug(enable); 256 } 257 258 static bool get(const nb::object &) { 259 nb::ft_lock_guard lock(mutex); 260 return mlirIsGlobalDebugEnabled(); 261 } 262 263 static void bind(nb::module_ &m) { 264 // Debug flags. 
265 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug") 266 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, 267 &PyGlobalDebugFlag::set, "LLVM-wide debug flag") 268 .def_static( 269 "set_types", 270 [](const std::string &type) { 271 nb::ft_lock_guard lock(mutex); 272 mlirSetGlobalDebugType(type.c_str()); 273 }, 274 "types"_a, "Sets specific debug types to be produced by LLVM") 275 .def_static("set_types", [](const std::vector<std::string> &types) { 276 std::vector<const char *> pointers; 277 pointers.reserve(types.size()); 278 for (const std::string &str : types) 279 pointers.push_back(str.c_str()); 280 nb::ft_lock_guard lock(mutex); 281 mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); 282 }); 283 } 284 285 private: 286 static nb::ft_mutex mutex; 287 }; 288 289 nb::ft_mutex PyGlobalDebugFlag::mutex; 290 291 struct PyAttrBuilderMap { 292 static bool dunderContains(const std::string &attributeKind) { 293 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); 294 } 295 static nb::callable dunderGetItemNamed(const std::string &attributeKind) { 296 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); 297 if (!builder) 298 throw nb::key_error(attributeKind.c_str()); 299 return *builder; 300 } 301 static void dunderSetItemNamed(const std::string &attributeKind, 302 nb::callable func, bool replace) { 303 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), 304 replace); 305 } 306 307 static void bind(nb::module_ &m) { 308 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder") 309 .def_static("contains", &PyAttrBuilderMap::dunderContains) 310 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) 311 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, 312 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, 313 "Register an attribute builder for building MLIR " 314 "attributes from python values."); 315 } 316 }; 317 318 
//------------------------------------------------------------------------------ 319 // PyBlock 320 //------------------------------------------------------------------------------ 321 322 nb::object PyBlock::getCapsule() { 323 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get())); 324 } 325 326 //------------------------------------------------------------------------------ 327 // Collections. 328 //------------------------------------------------------------------------------ 329 330 namespace { 331 332 class PyRegionIterator { 333 public: 334 PyRegionIterator(PyOperationRef operation) 335 : operation(std::move(operation)) {} 336 337 PyRegionIterator &dunderIter() { return *this; } 338 339 PyRegion dunderNext() { 340 operation->checkValid(); 341 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { 342 throw nb::stop_iteration(); 343 } 344 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); 345 return PyRegion(operation, region); 346 } 347 348 static void bind(nb::module_ &m) { 349 nb::class_<PyRegionIterator>(m, "RegionIterator") 350 .def("__iter__", &PyRegionIterator::dunderIter) 351 .def("__next__", &PyRegionIterator::dunderNext); 352 } 353 354 private: 355 PyOperationRef operation; 356 int nextIndex = 0; 357 }; 358 359 /// Regions of an op are fixed length and indexed numerically so are represented 360 /// with a sequence-like container. 361 class PyRegionList { 362 public: 363 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} 364 365 PyRegionIterator dunderIter() { 366 operation->checkValid(); 367 return PyRegionIterator(operation); 368 } 369 370 intptr_t dunderLen() { 371 operation->checkValid(); 372 return mlirOperationGetNumRegions(operation->get()); 373 } 374 375 PyRegion dunderGetItem(intptr_t index) { 376 // dunderLen checks validity. 
377 if (index < 0 || index >= dunderLen()) { 378 throw nb::index_error("attempt to access out of bounds region"); 379 } 380 MlirRegion region = mlirOperationGetRegion(operation->get(), index); 381 return PyRegion(operation, region); 382 } 383 384 static void bind(nb::module_ &m) { 385 nb::class_<PyRegionList>(m, "RegionSequence") 386 .def("__len__", &PyRegionList::dunderLen) 387 .def("__iter__", &PyRegionList::dunderIter) 388 .def("__getitem__", &PyRegionList::dunderGetItem); 389 } 390 391 private: 392 PyOperationRef operation; 393 }; 394 395 class PyBlockIterator { 396 public: 397 PyBlockIterator(PyOperationRef operation, MlirBlock next) 398 : operation(std::move(operation)), next(next) {} 399 400 PyBlockIterator &dunderIter() { return *this; } 401 402 PyBlock dunderNext() { 403 operation->checkValid(); 404 if (mlirBlockIsNull(next)) { 405 throw nb::stop_iteration(); 406 } 407 408 PyBlock returnBlock(operation, next); 409 next = mlirBlockGetNextInRegion(next); 410 return returnBlock; 411 } 412 413 static void bind(nb::module_ &m) { 414 nb::class_<PyBlockIterator>(m, "BlockIterator") 415 .def("__iter__", &PyBlockIterator::dunderIter) 416 .def("__next__", &PyBlockIterator::dunderNext); 417 } 418 419 private: 420 PyOperationRef operation; 421 MlirBlock next; 422 }; 423 424 /// Blocks are exposed by the C-API as a forward-only linked list. In Python, 425 /// we present them as a more full-featured list-like container but optimize 426 /// it for forward iteration. Blocks are always owned by a region. 
427 class PyBlockList { 428 public: 429 PyBlockList(PyOperationRef operation, MlirRegion region) 430 : operation(std::move(operation)), region(region) {} 431 432 PyBlockIterator dunderIter() { 433 operation->checkValid(); 434 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); 435 } 436 437 intptr_t dunderLen() { 438 operation->checkValid(); 439 intptr_t count = 0; 440 MlirBlock block = mlirRegionGetFirstBlock(region); 441 while (!mlirBlockIsNull(block)) { 442 count += 1; 443 block = mlirBlockGetNextInRegion(block); 444 } 445 return count; 446 } 447 448 PyBlock dunderGetItem(intptr_t index) { 449 operation->checkValid(); 450 if (index < 0) { 451 throw nb::index_error("attempt to access out of bounds block"); 452 } 453 MlirBlock block = mlirRegionGetFirstBlock(region); 454 while (!mlirBlockIsNull(block)) { 455 if (index == 0) { 456 return PyBlock(operation, block); 457 } 458 block = mlirBlockGetNextInRegion(block); 459 index -= 1; 460 } 461 throw nb::index_error("attempt to access out of bounds block"); 462 } 463 464 PyBlock appendBlock(const nb::args &pyArgTypes, 465 const std::optional<nb::sequence> &pyArgLocs) { 466 operation->checkValid(); 467 MlirBlock block = 468 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 469 mlirRegionAppendOwnedBlock(region, block); 470 return PyBlock(operation, block); 471 } 472 473 static void bind(nb::module_ &m) { 474 nb::class_<PyBlockList>(m, "BlockList") 475 .def("__getitem__", &PyBlockList::dunderGetItem) 476 .def("__iter__", &PyBlockList::dunderIter) 477 .def("__len__", &PyBlockList::dunderLen) 478 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, 479 nb::arg("args"), nb::kw_only(), 480 nb::arg("arg_locs") = std::nullopt); 481 } 482 483 private: 484 PyOperationRef operation; 485 MlirRegion region; 486 }; 487 488 class PyOperationIterator { 489 public: 490 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) 491 : parentOperation(std::move(parentOperation)), 
next(next) {} 492 493 PyOperationIterator &dunderIter() { return *this; } 494 495 nb::object dunderNext() { 496 parentOperation->checkValid(); 497 if (mlirOperationIsNull(next)) { 498 throw nb::stop_iteration(); 499 } 500 501 PyOperationRef returnOperation = 502 PyOperation::forOperation(parentOperation->getContext(), next); 503 next = mlirOperationGetNextInBlock(next); 504 return returnOperation->createOpView(); 505 } 506 507 static void bind(nb::module_ &m) { 508 nb::class_<PyOperationIterator>(m, "OperationIterator") 509 .def("__iter__", &PyOperationIterator::dunderIter) 510 .def("__next__", &PyOperationIterator::dunderNext); 511 } 512 513 private: 514 PyOperationRef parentOperation; 515 MlirOperation next; 516 }; 517 518 /// Operations are exposed by the C-API as a forward-only linked list. In 519 /// Python, we present them as a more full-featured list-like container but 520 /// optimize it for forward iteration. Iterable operations are always owned 521 /// by a block. 522 class PyOperationList { 523 public: 524 PyOperationList(PyOperationRef parentOperation, MlirBlock block) 525 : parentOperation(std::move(parentOperation)), block(block) {} 526 527 PyOperationIterator dunderIter() { 528 parentOperation->checkValid(); 529 return PyOperationIterator(parentOperation, 530 mlirBlockGetFirstOperation(block)); 531 } 532 533 intptr_t dunderLen() { 534 parentOperation->checkValid(); 535 intptr_t count = 0; 536 MlirOperation childOp = mlirBlockGetFirstOperation(block); 537 while (!mlirOperationIsNull(childOp)) { 538 count += 1; 539 childOp = mlirOperationGetNextInBlock(childOp); 540 } 541 return count; 542 } 543 544 nb::object dunderGetItem(intptr_t index) { 545 parentOperation->checkValid(); 546 if (index < 0) { 547 throw nb::index_error("attempt to access out of bounds operation"); 548 } 549 MlirOperation childOp = mlirBlockGetFirstOperation(block); 550 while (!mlirOperationIsNull(childOp)) { 551 if (index == 0) { 552 return 
PyOperation::forOperation(parentOperation->getContext(), childOp) 553 ->createOpView(); 554 } 555 childOp = mlirOperationGetNextInBlock(childOp); 556 index -= 1; 557 } 558 throw nb::index_error("attempt to access out of bounds operation"); 559 } 560 561 static void bind(nb::module_ &m) { 562 nb::class_<PyOperationList>(m, "OperationList") 563 .def("__getitem__", &PyOperationList::dunderGetItem) 564 .def("__iter__", &PyOperationList::dunderIter) 565 .def("__len__", &PyOperationList::dunderLen); 566 } 567 568 private: 569 PyOperationRef parentOperation; 570 MlirBlock block; 571 }; 572 573 class PyOpOperand { 574 public: 575 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} 576 577 nb::object getOwner() { 578 MlirOperation owner = mlirOpOperandGetOwner(opOperand); 579 PyMlirContextRef context = 580 PyMlirContext::forContext(mlirOperationGetContext(owner)); 581 return PyOperation::forOperation(context, owner)->createOpView(); 582 } 583 584 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } 585 586 static void bind(nb::module_ &m) { 587 nb::class_<PyOpOperand>(m, "OpOperand") 588 .def_prop_ro("owner", &PyOpOperand::getOwner) 589 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); 590 } 591 592 private: 593 MlirOpOperand opOperand; 594 }; 595 596 class PyOpOperandIterator { 597 public: 598 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} 599 600 PyOpOperandIterator &dunderIter() { return *this; } 601 602 PyOpOperand dunderNext() { 603 if (mlirOpOperandIsNull(opOperand)) 604 throw nb::stop_iteration(); 605 606 PyOpOperand returnOpOperand(opOperand); 607 opOperand = mlirOpOperandGetNextUse(opOperand); 608 return returnOpOperand; 609 } 610 611 static void bind(nb::module_ &m) { 612 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator") 613 .def("__iter__", &PyOpOperandIterator::dunderIter) 614 .def("__next__", &PyOpOperandIterator::dunderNext); 615 } 616 617 private: 618 MlirOpOperand opOperand; 619 
}; 620 621 } // namespace 622 623 //------------------------------------------------------------------------------ 624 // PyMlirContext 625 //------------------------------------------------------------------------------ 626 627 PyMlirContext::PyMlirContext(MlirContext context) : context(context) { 628 nb::gil_scoped_acquire acquire; 629 nb::ft_lock_guard lock(live_contexts_mutex); 630 auto &liveContexts = getLiveContexts(); 631 liveContexts[context.ptr] = this; 632 } 633 634 PyMlirContext::~PyMlirContext() { 635 // Note that the only public way to construct an instance is via the 636 // forContext method, which always puts the associated handle into 637 // liveContexts. 638 nb::gil_scoped_acquire acquire; 639 { 640 nb::ft_lock_guard lock(live_contexts_mutex); 641 getLiveContexts().erase(context.ptr); 642 } 643 mlirContextDestroy(context); 644 } 645 646 nb::object PyMlirContext::getCapsule() { 647 return nb::steal<nb::object>(mlirPythonContextToCapsule(get())); 648 } 649 650 nb::object PyMlirContext::createFromCapsule(nb::object capsule) { 651 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); 652 if (mlirContextIsNull(rawContext)) 653 throw nb::python_error(); 654 return forContext(rawContext).releaseObject(); 655 } 656 657 PyMlirContextRef PyMlirContext::forContext(MlirContext context) { 658 nb::gil_scoped_acquire acquire; 659 nb::ft_lock_guard lock(live_contexts_mutex); 660 auto &liveContexts = getLiveContexts(); 661 auto it = liveContexts.find(context.ptr); 662 if (it == liveContexts.end()) { 663 // Create. 664 PyMlirContext *unownedContextWrapper = new PyMlirContext(context); 665 nb::object pyRef = nb::cast(unownedContextWrapper); 666 assert(pyRef && "cast to nb::object failed"); 667 liveContexts[context.ptr] = unownedContextWrapper; 668 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); 669 } 670 // Use existing. 
671 nb::object pyRef = nb::cast(it->second); 672 return PyMlirContextRef(it->second, std::move(pyRef)); 673 } 674 675 nb::ft_mutex PyMlirContext::live_contexts_mutex; 676 677 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { 678 static LiveContextMap liveContexts; 679 return liveContexts; 680 } 681 682 size_t PyMlirContext::getLiveCount() { 683 nb::ft_lock_guard lock(live_contexts_mutex); 684 return getLiveContexts().size(); 685 } 686 687 size_t PyMlirContext::getLiveOperationCount() { 688 nb::ft_lock_guard lock(liveOperationsMutex); 689 return liveOperations.size(); 690 } 691 692 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() { 693 std::vector<PyOperation *> liveObjects; 694 nb::ft_lock_guard lock(liveOperationsMutex); 695 for (auto &entry : liveOperations) 696 liveObjects.push_back(entry.second.second); 697 return liveObjects; 698 } 699 700 size_t PyMlirContext::clearLiveOperations() { 701 702 LiveOperationMap operations; 703 { 704 nb::ft_lock_guard lock(liveOperationsMutex); 705 std::swap(operations, liveOperations); 706 } 707 for (auto &op : operations) 708 op.second.second->setInvalid(); 709 size_t numInvalidated = operations.size(); 710 return numInvalidated; 711 } 712 713 void PyMlirContext::clearOperation(MlirOperation op) { 714 PyOperation *py_op; 715 { 716 nb::ft_lock_guard lock(liveOperationsMutex); 717 auto it = liveOperations.find(op.ptr); 718 if (it == liveOperations.end()) { 719 return; 720 } 721 py_op = it->second.second; 722 liveOperations.erase(it); 723 } 724 py_op->setInvalid(); 725 } 726 727 void PyMlirContext::clearOperationsInside(PyOperationBase &op) { 728 typedef struct { 729 PyOperation &rootOp; 730 bool rootSeen; 731 } callBackData; 732 callBackData data{op.getOperation(), false}; 733 // Mark all ops below the op that the passmanager will be rooted 734 // at (but not op itself - note the preorder) as invalid. 
735 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, 736 void *userData) { 737 callBackData *data = static_cast<callBackData *>(userData); 738 if (LLVM_LIKELY(data->rootSeen)) 739 data->rootOp.getOperation().getContext()->clearOperation(op); 740 else 741 data->rootSeen = true; 742 return MlirWalkResult::MlirWalkResultAdvance; 743 }; 744 mlirOperationWalk(op.getOperation(), invalidatingCallback, 745 static_cast<void *>(&data), MlirWalkPreOrder); 746 } 747 void PyMlirContext::clearOperationsInside(MlirOperation op) { 748 PyOperationRef opRef = PyOperation::forOperation(getRef(), op); 749 clearOperationsInside(opRef->getOperation()); 750 } 751 752 void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { 753 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, 754 void *userData) { 755 PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData); 756 contextRef->clearOperation(op); 757 return MlirWalkResult::MlirWalkResultAdvance; 758 }; 759 mlirOperationWalk(op.getOperation(), invalidatingCallback, 760 &op.getOperation().getContext(), MlirWalkPreOrder); 761 } 762 763 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } 764 765 nb::object PyMlirContext::contextEnter(nb::object context) { 766 return PyThreadContextEntry::pushContext(context); 767 } 768 769 void PyMlirContext::contextExit(const nb::object &excType, 770 const nb::object &excVal, 771 const nb::object &excTb) { 772 PyThreadContextEntry::popContext(*this); 773 } 774 775 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { 776 // Note that ownership is transferred to the delete callback below by way of 777 // an explicit inc_ref (borrow). 
778 PyDiagnosticHandler *pyHandler = 779 new PyDiagnosticHandler(get(), std::move(callback)); 780 nb::object pyHandlerObject = 781 nb::cast(pyHandler, nb::rv_policy::take_ownership); 782 pyHandlerObject.inc_ref(); 783 784 // In these C callbacks, the userData is a PyDiagnosticHandler* that is 785 // guaranteed to be known to pybind. 786 auto handlerCallback = 787 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { 788 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); 789 nb::object pyDiagnosticObject = 790 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); 791 792 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 793 bool result = false; 794 { 795 // Since this can be called from arbitrary C++ contexts, always get the 796 // gil. 797 nb::gil_scoped_acquire gil; 798 try { 799 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic)); 800 } catch (std::exception &e) { 801 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", 802 e.what()); 803 pyHandler->hadError = true; 804 } 805 } 806 807 pyDiagnostic->invalidate(); 808 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); 809 }; 810 auto deleteCallback = +[](void *userData) { 811 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData); 812 assert(pyHandler->registeredID && "handler is not registered"); 813 pyHandler->registeredID.reset(); 814 815 // Decrement reference, balancing the inc_ref() above. 
816 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); 817 pyHandlerObject.dec_ref(); 818 }; 819 820 pyHandler->registeredID = mlirContextAttachDiagnosticHandler( 821 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback); 822 return pyHandlerObject; 823 } 824 825 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, 826 void *userData) { 827 auto *self = static_cast<ErrorCapture *>(userData); 828 // Check if the context requested we emit errors instead of capturing them. 829 if (self->ctx->emitErrorDiagnostics) 830 return mlirLogicalResultFailure(); 831 832 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) 833 return mlirLogicalResultFailure(); 834 835 self->errors.emplace_back(PyDiagnostic(diag).getInfo()); 836 return mlirLogicalResultSuccess(); 837 } 838 839 PyMlirContext &DefaultingPyMlirContext::resolve() { 840 PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); 841 if (!context) { 842 throw std::runtime_error( 843 "An MLIR function requires a Context but none was provided in the call " 844 "or from the surrounding environment. 
Either pass to the function with " 845 "a 'context=' argument or establish a default using 'with Context():'"); 846 } 847 return *context; 848 } 849 850 //------------------------------------------------------------------------------ 851 // PyThreadContextEntry management 852 //------------------------------------------------------------------------------ 853 854 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { 855 static thread_local std::vector<PyThreadContextEntry> stack; 856 return stack; 857 } 858 859 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { 860 auto &stack = getStack(); 861 if (stack.empty()) 862 return nullptr; 863 return &stack.back(); 864 } 865 866 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, 867 nb::object insertionPoint, 868 nb::object location) { 869 auto &stack = getStack(); 870 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), 871 std::move(location)); 872 // If the new stack has more than one entry and the context of the new top 873 // entry matches the previous, copy the insertionPoint and location from the 874 // previous entry if missing from the new top entry. 875 if (stack.size() > 1) { 876 auto &prev = *(stack.rbegin() + 1); 877 auto ¤t = stack.back(); 878 if (current.context.is(prev.context)) { 879 // Default non-context objects from the previous entry. 
880 if (!current.insertionPoint) 881 current.insertionPoint = prev.insertionPoint; 882 if (!current.location) 883 current.location = prev.location; 884 } 885 } 886 } 887 888 PyMlirContext *PyThreadContextEntry::getContext() { 889 if (!context) 890 return nullptr; 891 return nb::cast<PyMlirContext *>(context); 892 } 893 894 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { 895 if (!insertionPoint) 896 return nullptr; 897 return nb::cast<PyInsertionPoint *>(insertionPoint); 898 } 899 900 PyLocation *PyThreadContextEntry::getLocation() { 901 if (!location) 902 return nullptr; 903 return nb::cast<PyLocation *>(location); 904 } 905 906 PyMlirContext *PyThreadContextEntry::getDefaultContext() { 907 auto *tos = getTopOfStack(); 908 return tos ? tos->getContext() : nullptr; 909 } 910 911 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { 912 auto *tos = getTopOfStack(); 913 return tos ? tos->getInsertionPoint() : nullptr; 914 } 915 916 PyLocation *PyThreadContextEntry::getDefaultLocation() { 917 auto *tos = getTopOfStack(); 918 return tos ? 
tos->getLocation() : nullptr; 919 } 920 921 nb::object PyThreadContextEntry::pushContext(nb::object context) { 922 push(FrameKind::Context, /*context=*/context, 923 /*insertionPoint=*/nb::object(), 924 /*location=*/nb::object()); 925 return context; 926 } 927 928 void PyThreadContextEntry::popContext(PyMlirContext &context) { 929 auto &stack = getStack(); 930 if (stack.empty()) 931 throw std::runtime_error("Unbalanced Context enter/exit"); 932 auto &tos = stack.back(); 933 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) 934 throw std::runtime_error("Unbalanced Context enter/exit"); 935 stack.pop_back(); 936 } 937 938 nb::object 939 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { 940 PyInsertionPoint &insertionPoint = 941 nb::cast<PyInsertionPoint &>(insertionPointObj); 942 nb::object contextObj = 943 insertionPoint.getBlock().getParentOperation()->getContext().getObject(); 944 push(FrameKind::InsertionPoint, 945 /*context=*/contextObj, 946 /*insertionPoint=*/insertionPointObj, 947 /*location=*/nb::object()); 948 return insertionPointObj; 949 } 950 951 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { 952 auto &stack = getStack(); 953 if (stack.empty()) 954 throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); 955 auto &tos = stack.back(); 956 if (tos.frameKind != FrameKind::InsertionPoint && 957 tos.getInsertionPoint() != &insertionPoint) 958 throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); 959 stack.pop_back(); 960 } 961 962 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { 963 PyLocation &location = nb::cast<PyLocation &>(locationObj); 964 nb::object contextObj = location.getContext().getObject(); 965 push(FrameKind::Location, /*context=*/contextObj, 966 /*insertionPoint=*/nb::object(), 967 /*location=*/locationObj); 968 return locationObj; 969 } 970 971 void PyThreadContextEntry::popLocation(PyLocation &location) { 972 auto &stack = 
getStack(); 973 if (stack.empty()) 974 throw std::runtime_error("Unbalanced Location enter/exit"); 975 auto &tos = stack.back(); 976 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) 977 throw std::runtime_error("Unbalanced Location enter/exit"); 978 stack.pop_back(); 979 } 980 981 //------------------------------------------------------------------------------ 982 // PyDiagnostic* 983 //------------------------------------------------------------------------------ 984 985 void PyDiagnostic::invalidate() { 986 valid = false; 987 if (materializedNotes) { 988 for (nb::handle noteObject : *materializedNotes) { 989 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject); 990 note->invalidate(); 991 } 992 } 993 } 994 995 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, 996 nb::object callback) 997 : context(context), callback(std::move(callback)) {} 998 999 PyDiagnosticHandler::~PyDiagnosticHandler() = default; 1000 1001 void PyDiagnosticHandler::detach() { 1002 if (!registeredID) 1003 return; 1004 MlirDiagnosticHandlerID localID = *registeredID; 1005 mlirContextDetachDiagnosticHandler(context, localID); 1006 assert(!registeredID && "should have unregistered"); 1007 // Not strictly necessary but keeps stale pointers from being around to cause 1008 // issues. 
1009 context = {nullptr}; 1010 } 1011 1012 void PyDiagnostic::checkValid() { 1013 if (!valid) { 1014 throw std::invalid_argument( 1015 "Diagnostic is invalid (used outside of callback)"); 1016 } 1017 } 1018 1019 MlirDiagnosticSeverity PyDiagnostic::getSeverity() { 1020 checkValid(); 1021 return mlirDiagnosticGetSeverity(diagnostic); 1022 } 1023 1024 PyLocation PyDiagnostic::getLocation() { 1025 checkValid(); 1026 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); 1027 MlirContext context = mlirLocationGetContext(loc); 1028 return PyLocation(PyMlirContext::forContext(context), loc); 1029 } 1030 1031 nb::str PyDiagnostic::getMessage() { 1032 checkValid(); 1033 nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); 1034 PyFileAccumulator accum(fileObject, /*binary=*/false); 1035 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); 1036 return nb::cast<nb::str>(fileObject.attr("getvalue")()); 1037 } 1038 1039 nb::tuple PyDiagnostic::getNotes() { 1040 checkValid(); 1041 if (materializedNotes) 1042 return *materializedNotes; 1043 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); 1044 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes)); 1045 for (intptr_t i = 0; i < numNotes; ++i) { 1046 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); 1047 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); 1048 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); 1049 } 1050 materializedNotes = std::move(notes); 1051 1052 return *materializedNotes; 1053 } 1054 1055 PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { 1056 std::vector<DiagnosticInfo> notes; 1057 for (nb::handle n : getNotes()) 1058 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo()); 1059 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()), 1060 std::move(notes)}; 1061 } 1062 1063 //------------------------------------------------------------------------------ 1064 // PyDialect, PyDialectDescriptor, 
PyDialects, PyDialectRegistry 1065 //------------------------------------------------------------------------------ 1066 1067 MlirDialect PyDialects::getDialectForKey(const std::string &key, 1068 bool attrError) { 1069 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), 1070 {key.data(), key.size()}); 1071 if (mlirDialectIsNull(dialect)) { 1072 std::string msg = (Twine("Dialect '") + key + "' not found").str(); 1073 if (attrError) 1074 throw nb::attribute_error(msg.c_str()); 1075 throw nb::index_error(msg.c_str()); 1076 } 1077 return dialect; 1078 } 1079 1080 nb::object PyDialectRegistry::getCapsule() { 1081 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this)); 1082 } 1083 1084 PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { 1085 MlirDialectRegistry rawRegistry = 1086 mlirPythonCapsuleToDialectRegistry(capsule.ptr()); 1087 if (mlirDialectRegistryIsNull(rawRegistry)) 1088 throw nb::python_error(); 1089 return PyDialectRegistry(rawRegistry); 1090 } 1091 1092 //------------------------------------------------------------------------------ 1093 // PyLocation 1094 //------------------------------------------------------------------------------ 1095 1096 nb::object PyLocation::getCapsule() { 1097 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this)); 1098 } 1099 1100 PyLocation PyLocation::createFromCapsule(nb::object capsule) { 1101 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); 1102 if (mlirLocationIsNull(rawLoc)) 1103 throw nb::python_error(); 1104 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), 1105 rawLoc); 1106 } 1107 1108 nb::object PyLocation::contextEnter(nb::object locationObj) { 1109 return PyThreadContextEntry::pushLocation(locationObj); 1110 } 1111 1112 void PyLocation::contextExit(const nb::object &excType, 1113 const nb::object &excVal, 1114 const nb::object &excTb) { 1115 PyThreadContextEntry::popLocation(*this); 1116 } 1117 1118 
PyLocation &DefaultingPyLocation::resolve() { 1119 auto *location = PyThreadContextEntry::getDefaultLocation(); 1120 if (!location) { 1121 throw std::runtime_error( 1122 "An MLIR function requires a Location but none was provided in the " 1123 "call or from the surrounding environment. Either pass to the function " 1124 "with a 'loc=' argument or establish a default using 'with loc:'"); 1125 } 1126 return *location; 1127 } 1128 1129 //------------------------------------------------------------------------------ 1130 // PyModule 1131 //------------------------------------------------------------------------------ 1132 1133 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) 1134 : BaseContextObject(std::move(contextRef)), module(module) {} 1135 1136 PyModule::~PyModule() { 1137 nb::gil_scoped_acquire acquire; 1138 auto &liveModules = getContext()->liveModules; 1139 assert(liveModules.count(module.ptr) == 1 && 1140 "destroying module not in live map"); 1141 liveModules.erase(module.ptr); 1142 mlirModuleDestroy(module); 1143 } 1144 1145 PyModuleRef PyModule::forModule(MlirModule module) { 1146 MlirContext context = mlirModuleGetContext(module); 1147 PyMlirContextRef contextRef = PyMlirContext::forContext(context); 1148 1149 nb::gil_scoped_acquire acquire; 1150 auto &liveModules = contextRef->liveModules; 1151 auto it = liveModules.find(module.ptr); 1152 if (it == liveModules.end()) { 1153 // Create. 1154 PyModule *unownedModule = new PyModule(std::move(contextRef), module); 1155 // Note that the default return value policy on cast is automatic_reference, 1156 // which does not take ownership (delete will not be called). 1157 // Just be explicit. 1158 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); 1159 unownedModule->handle = pyRef; 1160 liveModules[module.ptr] = 1161 std::make_pair(unownedModule->handle, unownedModule); 1162 return PyModuleRef(unownedModule, std::move(pyRef)); 1163 } 1164 // Use existing. 
1165 PyModule *existing = it->second.second; 1166 nb::object pyRef = nb::borrow<nb::object>(it->second.first); 1167 return PyModuleRef(existing, std::move(pyRef)); 1168 } 1169 1170 nb::object PyModule::createFromCapsule(nb::object capsule) { 1171 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); 1172 if (mlirModuleIsNull(rawModule)) 1173 throw nb::python_error(); 1174 return forModule(rawModule).releaseObject(); 1175 } 1176 1177 nb::object PyModule::getCapsule() { 1178 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get())); 1179 } 1180 1181 //------------------------------------------------------------------------------ 1182 // PyOperation 1183 //------------------------------------------------------------------------------ 1184 1185 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) 1186 : BaseContextObject(std::move(contextRef)), operation(operation) {} 1187 1188 PyOperation::~PyOperation() { 1189 // If the operation has already been invalidated there is nothing to do. 1190 if (!valid) 1191 return; 1192 1193 // Otherwise, invalidate the operation and remove it from live map when it is 1194 // attached. 1195 if (isAttached()) { 1196 getContext()->clearOperation(*this); 1197 } else { 1198 // And destroy it when it is detached, i.e. owned by Python, in which case 1199 // all nested operations must be invalidated at removed from the live map as 1200 // well. 1201 erase(); 1202 } 1203 } 1204 1205 namespace { 1206 1207 // Constructs a new object of type T in-place on the Python heap, returning a 1208 // PyObjectRef to it, loosely analogous to std::make_shared<T>(). 1209 template <typename T, class... 
Args> 1210 PyObjectRef<T> makeObjectRef(Args &&...args) { 1211 nb::handle type = nb::type<T>(); 1212 nb::object instance = nb::inst_alloc(type); 1213 T *ptr = nb::inst_ptr<T>(instance); 1214 new (ptr) T(std::forward<Args>(args)...); 1215 nb::inst_mark_ready(instance); 1216 return PyObjectRef<T>(ptr, std::move(instance)); 1217 } 1218 1219 } // namespace 1220 1221 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, 1222 MlirOperation operation, 1223 nb::object parentKeepAlive) { 1224 // Create. 1225 PyOperationRef unownedOperation = 1226 makeObjectRef<PyOperation>(std::move(contextRef), operation); 1227 unownedOperation->handle = unownedOperation.getObject(); 1228 if (parentKeepAlive) { 1229 unownedOperation->parentKeepAlive = std::move(parentKeepAlive); 1230 } 1231 return unownedOperation; 1232 } 1233 1234 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, 1235 MlirOperation operation, 1236 nb::object parentKeepAlive) { 1237 nb::ft_lock_guard lock(contextRef->liveOperationsMutex); 1238 auto &liveOperations = contextRef->liveOperations; 1239 auto it = liveOperations.find(operation.ptr); 1240 if (it == liveOperations.end()) { 1241 // Create. 1242 PyOperationRef result = createInstance(std::move(contextRef), operation, 1243 std::move(parentKeepAlive)); 1244 liveOperations[operation.ptr] = 1245 std::make_pair(result.getObject(), result.get()); 1246 return result; 1247 } 1248 // Use existing. 
1249 PyOperation *existing = it->second.second; 1250 nb::object pyRef = nb::borrow<nb::object>(it->second.first); 1251 return PyOperationRef(existing, std::move(pyRef)); 1252 } 1253 1254 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, 1255 MlirOperation operation, 1256 nb::object parentKeepAlive) { 1257 nb::ft_lock_guard lock(contextRef->liveOperationsMutex); 1258 auto &liveOperations = contextRef->liveOperations; 1259 assert(liveOperations.count(operation.ptr) == 0 && 1260 "cannot create detached operation that already exists"); 1261 (void)liveOperations; 1262 PyOperationRef created = createInstance(std::move(contextRef), operation, 1263 std::move(parentKeepAlive)); 1264 liveOperations[operation.ptr] = 1265 std::make_pair(created.getObject(), created.get()); 1266 created->attached = false; 1267 return created; 1268 } 1269 1270 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, 1271 const std::string &sourceStr, 1272 const std::string &sourceName) { 1273 PyMlirContext::ErrorCapture errors(contextRef); 1274 MlirOperation op = 1275 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), 1276 toMlirStringRef(sourceName)); 1277 if (mlirOperationIsNull(op)) 1278 throw MLIRError("Unable to parse operation assembly", errors.take()); 1279 return PyOperation::createDetached(std::move(contextRef), op); 1280 } 1281 1282 void PyOperation::checkValid() const { 1283 if (!valid) { 1284 throw std::runtime_error("the operation has been invalidated"); 1285 } 1286 } 1287 1288 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit, 1289 bool enableDebugInfo, bool prettyDebugInfo, 1290 bool printGenericOpForm, bool useLocalScope, 1291 bool assumeVerified, nb::object fileObject, 1292 bool binary, bool skipRegions) { 1293 PyOperation &operation = getOperation(); 1294 operation.checkValid(); 1295 if (fileObject.is_none()) 1296 fileObject = nb::module_::import_("sys").attr("stdout"); 1297 1298 MlirOpPrintingFlags flags = 
mlirOpPrintingFlagsCreate(); 1299 if (largeElementsLimit) 1300 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); 1301 if (enableDebugInfo) 1302 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, 1303 /*prettyForm=*/prettyDebugInfo); 1304 if (printGenericOpForm) 1305 mlirOpPrintingFlagsPrintGenericOpForm(flags); 1306 if (useLocalScope) 1307 mlirOpPrintingFlagsUseLocalScope(flags); 1308 if (assumeVerified) 1309 mlirOpPrintingFlagsAssumeVerified(flags); 1310 if (skipRegions) 1311 mlirOpPrintingFlagsSkipRegions(flags); 1312 1313 PyFileAccumulator accum(fileObject, binary); 1314 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), 1315 accum.getUserData()); 1316 mlirOpPrintingFlagsDestroy(flags); 1317 } 1318 1319 void PyOperationBase::print(PyAsmState &state, nb::object fileObject, 1320 bool binary) { 1321 PyOperation &operation = getOperation(); 1322 operation.checkValid(); 1323 if (fileObject.is_none()) 1324 fileObject = nb::module_::import_("sys").attr("stdout"); 1325 PyFileAccumulator accum(fileObject, binary); 1326 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), 1327 accum.getUserData()); 1328 } 1329 1330 void PyOperationBase::writeBytecode(const nb::object &fileObject, 1331 std::optional<int64_t> bytecodeVersion) { 1332 PyOperation &operation = getOperation(); 1333 operation.checkValid(); 1334 PyFileAccumulator accum(fileObject, /*binary=*/true); 1335 1336 if (!bytecodeVersion.has_value()) 1337 return mlirOperationWriteBytecode(operation, accum.getCallback(), 1338 accum.getUserData()); 1339 1340 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); 1341 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); 1342 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( 1343 operation, config, accum.getCallback(), accum.getUserData()); 1344 mlirBytecodeWriterConfigDestroy(config); 1345 if (mlirLogicalResultIsFailure(res)) 1346 throw nb::value_error((Twine("Unable 
to honor desired bytecode version ") + 1347 Twine(*bytecodeVersion)) 1348 .str() 1349 .c_str()); 1350 } 1351 1352 void PyOperationBase::walk( 1353 std::function<MlirWalkResult(MlirOperation)> callback, 1354 MlirWalkOrder walkOrder) { 1355 PyOperation &operation = getOperation(); 1356 operation.checkValid(); 1357 struct UserData { 1358 std::function<MlirWalkResult(MlirOperation)> callback; 1359 bool gotException; 1360 std::string exceptionWhat; 1361 nb::object exceptionType; 1362 }; 1363 UserData userData{callback, false, {}, {}}; 1364 MlirOperationWalkCallback walkCallback = [](MlirOperation op, 1365 void *userData) { 1366 UserData *calleeUserData = static_cast<UserData *>(userData); 1367 try { 1368 return (calleeUserData->callback)(op); 1369 } catch (nb::python_error &e) { 1370 calleeUserData->gotException = true; 1371 calleeUserData->exceptionWhat = std::string(e.what()); 1372 calleeUserData->exceptionType = nb::borrow(e.type()); 1373 return MlirWalkResult::MlirWalkResultInterrupt; 1374 } 1375 }; 1376 mlirOperationWalk(operation, walkCallback, &userData, walkOrder); 1377 if (userData.gotException) { 1378 std::string message("Exception raised in callback: "); 1379 message.append(userData.exceptionWhat); 1380 throw std::runtime_error(message); 1381 } 1382 } 1383 1384 nb::object PyOperationBase::getAsm(bool binary, 1385 std::optional<int64_t> largeElementsLimit, 1386 bool enableDebugInfo, bool prettyDebugInfo, 1387 bool printGenericOpForm, bool useLocalScope, 1388 bool assumeVerified, bool skipRegions) { 1389 nb::object fileObject; 1390 if (binary) { 1391 fileObject = nb::module_::import_("io").attr("BytesIO")(); 1392 } else { 1393 fileObject = nb::module_::import_("io").attr("StringIO")(); 1394 } 1395 print(/*largeElementsLimit=*/largeElementsLimit, 1396 /*enableDebugInfo=*/enableDebugInfo, 1397 /*prettyDebugInfo=*/prettyDebugInfo, 1398 /*printGenericOpForm=*/printGenericOpForm, 1399 /*useLocalScope=*/useLocalScope, 1400 /*assumeVerified=*/assumeVerified, 1401 
/*fileObject=*/fileObject, 1402 /*binary=*/binary, 1403 /*skipRegions=*/skipRegions); 1404 1405 return fileObject.attr("getvalue")(); 1406 } 1407 1408 void PyOperationBase::moveAfter(PyOperationBase &other) { 1409 PyOperation &operation = getOperation(); 1410 PyOperation &otherOp = other.getOperation(); 1411 operation.checkValid(); 1412 otherOp.checkValid(); 1413 mlirOperationMoveAfter(operation, otherOp); 1414 operation.parentKeepAlive = otherOp.parentKeepAlive; 1415 } 1416 1417 void PyOperationBase::moveBefore(PyOperationBase &other) { 1418 PyOperation &operation = getOperation(); 1419 PyOperation &otherOp = other.getOperation(); 1420 operation.checkValid(); 1421 otherOp.checkValid(); 1422 mlirOperationMoveBefore(operation, otherOp); 1423 operation.parentKeepAlive = otherOp.parentKeepAlive; 1424 } 1425 1426 bool PyOperationBase::verify() { 1427 PyOperation &op = getOperation(); 1428 PyMlirContext::ErrorCapture errors(op.getContext()); 1429 if (!mlirOperationVerify(op.get())) 1430 throw MLIRError("Verification failed", errors.take()); 1431 return true; 1432 } 1433 1434 std::optional<PyOperationRef> PyOperation::getParentOperation() { 1435 checkValid(); 1436 if (!isAttached()) 1437 throw nb::value_error("Detached operations have no parent"); 1438 MlirOperation operation = mlirOperationGetParentOperation(get()); 1439 if (mlirOperationIsNull(operation)) 1440 return {}; 1441 return PyOperation::forOperation(getContext(), operation); 1442 } 1443 1444 PyBlock PyOperation::getBlock() { 1445 checkValid(); 1446 std::optional<PyOperationRef> parentOperation = getParentOperation(); 1447 MlirBlock block = mlirOperationGetBlock(get()); 1448 assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); 1449 assert(parentOperation && "Operation has no parent"); 1450 return PyBlock{std::move(*parentOperation), block}; 1451 } 1452 1453 nb::object PyOperation::getCapsule() { 1454 checkValid(); 1455 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get())); 1456 } 
1457 1458 nb::object PyOperation::createFromCapsule(nb::object capsule) { 1459 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); 1460 if (mlirOperationIsNull(rawOperation)) 1461 throw nb::python_error(); 1462 MlirContext rawCtxt = mlirOperationGetContext(rawOperation); 1463 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) 1464 .releaseObject(); 1465 } 1466 1467 static void maybeInsertOperation(PyOperationRef &op, 1468 const nb::object &maybeIp) { 1469 // InsertPoint active? 1470 if (!maybeIp.is(nb::cast(false))) { 1471 PyInsertionPoint *ip; 1472 if (maybeIp.is_none()) { 1473 ip = PyThreadContextEntry::getDefaultInsertionPoint(); 1474 } else { 1475 ip = nb::cast<PyInsertionPoint *>(maybeIp); 1476 } 1477 if (ip) 1478 ip->insert(*op.get()); 1479 } 1480 } 1481 1482 nb::object PyOperation::create(std::string_view name, 1483 std::optional<std::vector<PyType *>> results, 1484 llvm::ArrayRef<MlirValue> operands, 1485 std::optional<nb::dict> attributes, 1486 std::optional<std::vector<PyBlock *>> successors, 1487 int regions, DefaultingPyLocation location, 1488 const nb::object &maybeIp, bool inferType) { 1489 llvm::SmallVector<MlirType, 4> mlirResults; 1490 llvm::SmallVector<MlirBlock, 4> mlirSuccessors; 1491 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; 1492 1493 // General parameter validation. 1494 if (regions < 0) 1495 throw nb::value_error("number of regions must be >= 0"); 1496 1497 // Unpack/validate results. 1498 if (results) { 1499 mlirResults.reserve(results->size()); 1500 for (PyType *result : *results) { 1501 // TODO: Verify result type originate from the same context. 1502 if (!result) 1503 throw nb::value_error("result type cannot be None"); 1504 mlirResults.push_back(*result); 1505 } 1506 } 1507 // Unpack/validate attributes. 
1508 if (attributes) { 1509 mlirAttributes.reserve(attributes->size()); 1510 for (std::pair<nb::handle, nb::handle> it : *attributes) { 1511 std::string key; 1512 try { 1513 key = nb::cast<std::string>(it.first); 1514 } catch (nb::cast_error &err) { 1515 std::string msg = "Invalid attribute key (not a string) when " 1516 "attempting to create the operation \"" + 1517 std::string(name) + "\" (" + err.what() + ")"; 1518 throw nb::type_error(msg.c_str()); 1519 } 1520 try { 1521 auto &attribute = nb::cast<PyAttribute &>(it.second); 1522 // TODO: Verify attribute originates from the same context. 1523 mlirAttributes.emplace_back(std::move(key), attribute); 1524 } catch (nb::cast_error &err) { 1525 std::string msg = "Invalid attribute value for the key \"" + key + 1526 "\" when attempting to create the operation \"" + 1527 std::string(name) + "\" (" + err.what() + ")"; 1528 throw nb::type_error(msg.c_str()); 1529 } catch (std::runtime_error &) { 1530 // This exception seems thrown when the value is "None". 1531 std::string msg = 1532 "Found an invalid (`None`?) attribute value for the key \"" + key + 1533 "\" when attempting to create the operation \"" + 1534 std::string(name) + "\""; 1535 throw std::runtime_error(msg); 1536 } 1537 } 1538 } 1539 // Unpack/validate successors. 1540 if (successors) { 1541 mlirSuccessors.reserve(successors->size()); 1542 for (auto *successor : *successors) { 1543 // TODO: Verify successor originate from the same context. 1544 if (!successor) 1545 throw nb::value_error("successor block cannot be None"); 1546 mlirSuccessors.push_back(successor->get()); 1547 } 1548 } 1549 1550 // Apply unpacked/validated to the operation state. Beyond this 1551 // point, exceptions cannot be thrown or else the state will leak. 
1552 MlirOperationState state = 1553 mlirOperationStateGet(toMlirStringRef(name), location); 1554 if (!operands.empty()) 1555 mlirOperationStateAddOperands(&state, operands.size(), operands.data()); 1556 state.enableResultTypeInference = inferType; 1557 if (!mlirResults.empty()) 1558 mlirOperationStateAddResults(&state, mlirResults.size(), 1559 mlirResults.data()); 1560 if (!mlirAttributes.empty()) { 1561 // Note that the attribute names directly reference bytes in 1562 // mlirAttributes, so that vector must not be changed from here 1563 // on. 1564 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; 1565 mlirNamedAttributes.reserve(mlirAttributes.size()); 1566 for (auto &it : mlirAttributes) 1567 mlirNamedAttributes.push_back(mlirNamedAttributeGet( 1568 mlirIdentifierGet(mlirAttributeGetContext(it.second), 1569 toMlirStringRef(it.first)), 1570 it.second)); 1571 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), 1572 mlirNamedAttributes.data()); 1573 } 1574 if (!mlirSuccessors.empty()) 1575 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), 1576 mlirSuccessors.data()); 1577 if (regions) { 1578 llvm::SmallVector<MlirRegion, 4> mlirRegions; 1579 mlirRegions.resize(regions); 1580 for (int i = 0; i < regions; ++i) 1581 mlirRegions[i] = mlirRegionCreate(); 1582 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), 1583 mlirRegions.data()); 1584 } 1585 1586 // Construct the operation. 
1587 MlirOperation operation = mlirOperationCreate(&state); 1588 if (!operation.ptr) 1589 throw nb::value_error("Operation creation failed"); 1590 PyOperationRef created = 1591 PyOperation::createDetached(location->getContext(), operation); 1592 maybeInsertOperation(created, maybeIp); 1593 1594 return created.getObject(); 1595 } 1596 1597 nb::object PyOperation::clone(const nb::object &maybeIp) { 1598 MlirOperation clonedOperation = mlirOperationClone(operation); 1599 PyOperationRef cloned = 1600 PyOperation::createDetached(getContext(), clonedOperation); 1601 maybeInsertOperation(cloned, maybeIp); 1602 1603 return cloned->createOpView(); 1604 } 1605 1606 nb::object PyOperation::createOpView() { 1607 checkValid(); 1608 MlirIdentifier ident = mlirOperationGetName(get()); 1609 MlirStringRef identStr = mlirIdentifierStr(ident); 1610 auto operationCls = PyGlobals::get().lookupOperationClass( 1611 StringRef(identStr.data, identStr.length)); 1612 if (operationCls) 1613 return PyOpView::constructDerived(*operationCls, getRef().getObject()); 1614 return nb::cast(PyOpView(getRef().getObject())); 1615 } 1616 1617 void PyOperation::erase() { 1618 checkValid(); 1619 getContext()->clearOperationAndInside(*this); 1620 mlirOperationDestroy(operation); 1621 } 1622 1623 namespace { 1624 /// CRTP base class for Python MLIR values that subclass Value and should be 1625 /// castable from it. The value hierarchy is one level deep and is not supposed 1626 /// to accommodate other levels unless core MLIR changes. 1627 template <typename DerivedTy> 1628 class PyConcreteValue : public PyValue { 1629 public: 1630 // Derived classes must define statics for: 1631 // IsAFunctionTy isaFunction 1632 // const char *pyClassName 1633 // and redefine bindDerived. 
1634 using ClassTy = nb::class_<DerivedTy, PyValue>; 1635 using IsAFunctionTy = bool (*)(MlirValue); 1636 1637 PyConcreteValue() = default; 1638 PyConcreteValue(PyOperationRef operationRef, MlirValue value) 1639 : PyValue(operationRef, value) {} 1640 PyConcreteValue(PyValue &orig) 1641 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} 1642 1643 /// Attempts to cast the original value to the derived type and throws on 1644 /// type mismatches. 1645 static MlirValue castFrom(PyValue &orig) { 1646 if (!DerivedTy::isaFunction(orig.get())) { 1647 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig))); 1648 throw nb::value_error((Twine("Cannot cast value to ") + 1649 DerivedTy::pyClassName + " (from " + origRepr + 1650 ")") 1651 .str() 1652 .c_str()); 1653 } 1654 return orig.get(); 1655 } 1656 1657 /// Binds the Python module objects to functions of this class. 1658 static void bind(nb::module_ &m) { 1659 auto cls = ClassTy(m, DerivedTy::pyClassName); 1660 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")); 1661 cls.def_static( 1662 "isinstance", 1663 [](PyValue &otherValue) -> bool { 1664 return DerivedTy::isaFunction(otherValue); 1665 }, 1666 nb::arg("other_value")); 1667 cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 1668 [](DerivedTy &self) { return self.maybeDownCast(); }); 1669 DerivedTy::bindDerived(cls); 1670 } 1671 1672 /// Implemented by derived classes to add methods to the Python subclass. 1673 static void bindDerived(ClassTy &m) {} 1674 }; 1675 1676 } // namespace 1677 1678 /// Python wrapper for MlirOpResult. 
class PyOpResult : public PyConcreteValue<PyOpResult> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
  static constexpr const char *pyClassName = "OpResult";
  using PyConcreteValue::PyConcreteValue;

  /// Adds the read-only `owner` and `result_number` properties.
  static void bindDerived(ClassTy &c) {
    c.def_prop_ro("owner", [](PyOpResult &self) {
      assert(
          mlirOperationEqual(self.getParentOperation()->get(),
                             mlirOpResultGetOwner(self.get())) &&
          "expected the owner of the value in Python to match that in the IR");
      return self.getParentOperation().getObject();
    });
    c.def_prop_ro("result_number", [](PyOpResult &self) {
      return mlirOpResultGetResultNumber(self.get());
    });
  }
};

/// Returns the list of types of the values held by container.
/// NOTE(review): `context` is currently unused.
template <typename Container>
static std::vector<MlirType> getValueTypes(Container &container,
                                           PyMlirContextRef &context) {
  std::vector<MlirType> result;
  result.reserve(container.size());
  for (int i = 0, e = container.size(); i < e; ++i) {
    result.push_back(mlirValueGetType(container.getElement(i).get()));
  }
  return result;
}

/// A list of operation results. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) result list is associated
/// with the operation whose results these are, and thus extends the lifetime of
/// this operation.
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
  static constexpr const char *pyClassName = "OpResultList";
  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;

  // A `length` of -1 means "all results of `operation`".
  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
                 intptr_t length = -1, intptr_t step = 1)
      : Sliceable(startIndex,
                  length == -1 ? mlirOperationGetNumResults(operation->get())
                               : length,
                  step),
        operation(std::move(operation)) {}

  /// Adds the read-only `types` and `owner` properties.
  static void bindDerived(ClassTy &c) {
    c.def_prop_ro("types", [](PyOpResultList &self) {
      return getValueTypes(self, self.operation->getContext());
    });
    c.def_prop_ro("owner", [](PyOpResultList &self) {
      return self.operation->createOpView();
    });
  }

  PyOperationRef &getOperation() { return operation; }

private:
  /// Give the parent CRTP class access to hook implementations below.
  friend class Sliceable<PyOpResultList, PyOpResult>;

  intptr_t getRawNumElements() {
    operation->checkValid();
    return mlirOperationGetNumResults(operation->get());
  }

  PyOpResult getRawElement(intptr_t index) {
    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
    return PyOpResult(value);
  }

  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
    return PyOpResultList(operation, startIndex, length, step);
  }

  PyOperationRef operation;
};

//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------

/// Validates `resultTypeList` and flattens it into `resultTypes`.
///
/// When `resultSegmentSpecObj` is None, every entry must be a non-null Type.
/// Otherwise the spec is cast to a list of ints with one entry per list
/// element:
///   1  -> a single Type is required;
///   0  -> a single optional Type (None records a segment length of 0);
///   -1 -> a sequence of Types (None is treated as an empty sequence).
/// The per-element segment lengths are appended to `resultSegmentLengths`.
/// Invalid input throws ValueError whose message names `name` and the
/// offending result index.
static void populateResultTypes(StringRef name, nb::list resultTypeList,
                                const nb::object &resultSegmentSpecObj,
                                std::vector<int32_t> &resultSegmentLengths,
                                std::vector<PyType *> &resultTypes) {
  resultTypes.reserve(resultTypeList.size());
  if (resultSegmentSpecObj.is_none()) {
    // Non-variadic result unpacking.
    for (const auto &it : llvm::enumerate(resultTypeList)) {
      try {
        resultTypes.push_back(nb::cast<PyType *>(it.value()));
        // A None entry casts to a null pointer; route it to the same error.
        if (!resultTypes.back())
          throw nb::cast_error();
      } catch (nb::cast_error &err) {
        throw nb::value_error((llvm::Twine("Result ") +
                               llvm::Twine(it.index()) + " of operation \"" +
                               name + "\" must be a Type (" + err.what() + ")")
                                  .str()
                                  .c_str());
      }
    }
  } else {
    // Sized result unpacking.
    auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
    if (resultSegmentSpec.size() != resultTypeList.size()) {
      throw nb::value_error((llvm::Twine("Operation \"") + name +
                             "\" requires " +
                             llvm::Twine(resultSegmentSpec.size()) +
                             " result segments but was provided " +
                             llvm::Twine(resultTypeList.size()))
                                .str()
                                .c_str());
    }
    resultSegmentLengths.reserve(resultTypeList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
          if (resultType) {
            resultTypes.push_back(resultType);
            resultSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            resultSegmentLengths.push_back(0);
          } else {
            // Not a cast_error, so this propagates past the handler below.
            throw nb::value_error(
                (llvm::Twine("Result ") + llvm::Twine(it.index()) +
                 " of operation \"" + name +
                 "\" must be a Type (was None and result is not optional)")
                    .str()
                    .c_str());
          }
        } catch (nb::cast_error &err) {
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Type (" + err.what() +
                                 ")")
                                    .str()
                                    .c_str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            resultSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
            for (nb::handle segmentItem : segment) {
              resultTypes.push_back(nb::cast<PyType *>(segmentItem));
              if (!resultTypes.back()) {
                throw nb::type_error("contained a None item");
              }
            }
            resultSegmentLengths.push_back(nb::len(segment));
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Sequence of Types (" +
                                 err.what() + ")")
                                    .str()
                                    .c_str());
        }
      } else {
        throw nb::value_error("Unexpected segment spec");
      }
    }
  }
}

/// Returns the sole result of `operation`. Throws ValueError naming the
/// operation when it does not have exactly one result.
static MlirValue getUniqueResult(MlirOperation operation) {
  auto numResults = mlirOperationGetNumResults(operation);
  if (numResults != 1) {
    auto name = mlirIdentifierStr(mlirOperationGetName(operation));
    throw nb::value_error((Twine("Cannot call .result on operation ") +
                           StringRef(name.data, name.length) + " which has " +
                           Twine(numResults) +
                           " results (it is only valid for operations with a "
                           "single result)")
                              .str()
                              .c_str());
  }
  return mlirOperationGetResult(operation, 0);
}

/// Converts an operand argument to an MlirValue. Any PyOperationBase and any
/// OpResultList contribute their unique result, while a Value is used as-is.
/// None and any other object throw ValueError.
static MlirValue getOpResultOrValue(nb::handle operand) {
  if (operand.is_none()) {
    throw nb::value_error("contained a None item");
  }
  PyOperationBase *op;
  if (nb::try_cast<PyOperationBase *>(operand, op)) {
    return getUniqueResult(op->getOperation());
  }
  PyOpResultList *opResultList;
  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
    return getUniqueResult(opResultList->getOperation()->get());
  }
  PyValue *value;
  if (nb::try_cast<PyValue *>(operand, value)) {
    return value->get();
  }
  throw nb::value_error("is not a Value");
}

nb::object PyOpView::buildGeneric(
    std::string_view name, std::tuple<int, bool> opRegionSpec,
    nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
    std::optional<nb::list> resultTypeList, nb::list operandList,
    std::optional<nb::dict> attributes,
    std::optional<std::vector<PyBlock *>> successors,
    std::optional<int> regions, DefaultingPyLocation location,
    const nb::object &maybeIp) {
  PyMlirContextRef context = location->getContext();

  // Class level operation construction metadata.
  // Operand and result segment specs are either none, which does no
  // variadic unpacking, or a list of ints with segment sizes, where each
  // element is either a positive number (typically 1 for a scalar) or -1 to
  // indicate that it is derived from the length of the same-indexed operand
  // or result (implying that it is a list at that position).
  std::vector<int32_t> operandSegmentLengths;
  std::vector<int32_t> resultSegmentLengths;

  // Validate/determine region count.
1915 int opMinRegionCount = std::get<0>(opRegionSpec); 1916 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); 1917 if (!regions) { 1918 regions = opMinRegionCount; 1919 } 1920 if (*regions < opMinRegionCount) { 1921 throw nb::value_error( 1922 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + 1923 llvm::Twine(opMinRegionCount) + 1924 " regions but was built with regions=" + llvm::Twine(*regions)) 1925 .str() 1926 .c_str()); 1927 } 1928 if (opHasNoVariadicRegions && *regions > opMinRegionCount) { 1929 throw nb::value_error( 1930 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + 1931 llvm::Twine(opMinRegionCount) + 1932 " regions but was built with regions=" + llvm::Twine(*regions)) 1933 .str() 1934 .c_str()); 1935 } 1936 1937 // Unpack results. 1938 std::vector<PyType *> resultTypes; 1939 if (resultTypeList.has_value()) { 1940 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, 1941 resultSegmentLengths, resultTypes); 1942 } 1943 1944 // Unpack operands. 1945 llvm::SmallVector<MlirValue, 4> operands; 1946 operands.reserve(operands.size()); 1947 if (operandSegmentSpecObj.is_none()) { 1948 // Non-sized operand unpacking. 1949 for (const auto &it : llvm::enumerate(operandList)) { 1950 try { 1951 operands.push_back(getOpResultOrValue(it.value())); 1952 } catch (nb::builtin_exception &err) { 1953 throw nb::value_error((llvm::Twine("Operand ") + 1954 llvm::Twine(it.index()) + " of operation \"" + 1955 name + "\" must be a Value (" + err.what() + ")") 1956 .str() 1957 .c_str()); 1958 } 1959 } 1960 } else { 1961 // Sized operand unpacking. 
1962 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj); 1963 if (operandSegmentSpec.size() != operandList.size()) { 1964 throw nb::value_error((llvm::Twine("Operation \"") + name + 1965 "\" requires " + 1966 llvm::Twine(operandSegmentSpec.size()) + 1967 "operand segments but was provided " + 1968 llvm::Twine(operandList.size())) 1969 .str() 1970 .c_str()); 1971 } 1972 operandSegmentLengths.reserve(operandList.size()); 1973 for (const auto &it : 1974 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { 1975 int segmentSpec = std::get<1>(it.value()); 1976 if (segmentSpec == 1 || segmentSpec == 0) { 1977 // Unpack unary element. 1978 auto &operand = std::get<0>(it.value()); 1979 if (!operand.is_none()) { 1980 try { 1981 1982 operands.push_back(getOpResultOrValue(operand)); 1983 } catch (nb::builtin_exception &err) { 1984 throw nb::value_error((llvm::Twine("Operand ") + 1985 llvm::Twine(it.index()) + 1986 " of operation \"" + name + 1987 "\" must be a Value (" + err.what() + ")") 1988 .str() 1989 .c_str()); 1990 } 1991 1992 operandSegmentLengths.push_back(1); 1993 } else if (segmentSpec == 0) { 1994 // Allowed to be optional. 1995 operandSegmentLengths.push_back(0); 1996 } else { 1997 throw nb::value_error( 1998 (llvm::Twine("Operand ") + llvm::Twine(it.index()) + 1999 " of operation \"" + name + 2000 "\" must be a Value (was None and operand is not optional)") 2001 .str() 2002 .c_str()); 2003 } 2004 } else if (segmentSpec == -1) { 2005 // Unpack sequence by appending. 2006 try { 2007 if (std::get<0>(it.value()).is_none()) { 2008 // Treat it as an empty list. 2009 operandSegmentLengths.push_back(0); 2010 } else { 2011 // Unpack the list. 
2012 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value())); 2013 for (nb::handle segmentItem : segment) { 2014 operands.push_back(getOpResultOrValue(segmentItem)); 2015 } 2016 operandSegmentLengths.push_back(nb::len(segment)); 2017 } 2018 } catch (std::exception &err) { 2019 // NOTE: Sloppy to be using a catch-all here, but there are at least 2020 // three different unrelated exceptions that can be thrown in the 2021 // above "casts". Just keep the scope above small and catch them all. 2022 throw nb::value_error((llvm::Twine("Operand ") + 2023 llvm::Twine(it.index()) + " of operation \"" + 2024 name + "\" must be a Sequence of Values (" + 2025 err.what() + ")") 2026 .str() 2027 .c_str()); 2028 } 2029 } else { 2030 throw nb::value_error("Unexpected segment spec"); 2031 } 2032 } 2033 } 2034 2035 // Merge operand/result segment lengths into attributes if needed. 2036 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { 2037 // Dup. 2038 if (attributes) { 2039 attributes = nb::dict(*attributes); 2040 } else { 2041 attributes = nb::dict(); 2042 } 2043 if (attributes->contains("resultSegmentSizes") || 2044 attributes->contains("operandSegmentSizes")) { 2045 throw nb::value_error("Manually setting a 'resultSegmentSizes' or " 2046 "'operandSegmentSizes' attribute is unsupported. " 2047 "Use Operation.create for such low-level access."); 2048 } 2049 2050 // Add resultSegmentSizes attribute. 2051 if (!resultSegmentLengths.empty()) { 2052 MlirAttribute segmentLengthAttr = 2053 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), 2054 resultSegmentLengths.data()); 2055 (*attributes)["resultSegmentSizes"] = 2056 PyAttribute(context, segmentLengthAttr); 2057 } 2058 2059 // Add operandSegmentSizes attribute. 
2060 if (!operandSegmentLengths.empty()) { 2061 MlirAttribute segmentLengthAttr = 2062 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), 2063 operandSegmentLengths.data()); 2064 (*attributes)["operandSegmentSizes"] = 2065 PyAttribute(context, segmentLengthAttr); 2066 } 2067 } 2068 2069 // Delegate to create. 2070 return PyOperation::create(name, 2071 /*results=*/std::move(resultTypes), 2072 /*operands=*/std::move(operands), 2073 /*attributes=*/std::move(attributes), 2074 /*successors=*/std::move(successors), 2075 /*regions=*/*regions, location, maybeIp, 2076 !resultTypeList); 2077 } 2078 2079 nb::object PyOpView::constructDerived(const nb::object &cls, 2080 const nb::object &operation) { 2081 nb::handle opViewType = nb::type<PyOpView>(); 2082 nb::object instance = cls.attr("__new__")(cls); 2083 opViewType.attr("__init__")(instance, operation); 2084 return instance; 2085 } 2086 2087 PyOpView::PyOpView(const nb::object &operationObject) 2088 // Casting through the PyOperationBase base-class and then back to the 2089 // Operation lets us accept any PyOperationBase subclass. 2090 : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()), 2091 operationObject(operation.getRef().getObject()) {} 2092 2093 //------------------------------------------------------------------------------ 2094 // PyInsertionPoint. 
2095 //------------------------------------------------------------------------------ 2096 2097 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} 2098 2099 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) 2100 : refOperation(beforeOperationBase.getOperation().getRef()), 2101 block((*refOperation)->getBlock()) {} 2102 2103 void PyInsertionPoint::insert(PyOperationBase &operationBase) { 2104 PyOperation &operation = operationBase.getOperation(); 2105 if (operation.isAttached()) 2106 throw nb::value_error( 2107 "Attempt to insert operation that is already attached"); 2108 block.getParentOperation()->checkValid(); 2109 MlirOperation beforeOp = {nullptr}; 2110 if (refOperation) { 2111 // Insert before operation. 2112 (*refOperation)->checkValid(); 2113 beforeOp = (*refOperation)->get(); 2114 } else { 2115 // Insert at end (before null) is only valid if the block does not 2116 // already end in a known terminator (violating this will cause assertion 2117 // failures later). 2118 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { 2119 throw nb::index_error("Cannot insert operation at the end of a block " 2120 "that already has a terminator. Did you mean to " 2121 "use 'InsertionPoint.at_block_terminator(block)' " 2122 "versus 'InsertionPoint(block)'?"); 2123 } 2124 } 2125 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); 2126 operation.setAttached(); 2127 } 2128 2129 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { 2130 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); 2131 if (mlirOperationIsNull(firstOp)) { 2132 // Just insert at end. 2133 return PyInsertionPoint(block); 2134 } 2135 2136 // Insert before first op. 
2137 PyOperationRef firstOpRef = PyOperation::forOperation( 2138 block.getParentOperation()->getContext(), firstOp); 2139 return PyInsertionPoint{block, std::move(firstOpRef)}; 2140 } 2141 2142 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { 2143 MlirOperation terminator = mlirBlockGetTerminator(block.get()); 2144 if (mlirOperationIsNull(terminator)) 2145 throw nb::value_error("Block has no terminator"); 2146 PyOperationRef terminatorOpRef = PyOperation::forOperation( 2147 block.getParentOperation()->getContext(), terminator); 2148 return PyInsertionPoint{block, std::move(terminatorOpRef)}; 2149 } 2150 2151 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { 2152 return PyThreadContextEntry::pushInsertionPoint(insertPoint); 2153 } 2154 2155 void PyInsertionPoint::contextExit(const nb::object &excType, 2156 const nb::object &excVal, 2157 const nb::object &excTb) { 2158 PyThreadContextEntry::popInsertionPoint(*this); 2159 } 2160 2161 //------------------------------------------------------------------------------ 2162 // PyAttribute. 2163 //------------------------------------------------------------------------------ 2164 2165 bool PyAttribute::operator==(const PyAttribute &other) const { 2166 return mlirAttributeEqual(attr, other.attr); 2167 } 2168 2169 nb::object PyAttribute::getCapsule() { 2170 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this)); 2171 } 2172 2173 PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { 2174 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); 2175 if (mlirAttributeIsNull(rawAttr)) 2176 throw nb::python_error(); 2177 return PyAttribute( 2178 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); 2179 } 2180 2181 //------------------------------------------------------------------------------ 2182 // PyNamedAttribute. 
2183 //------------------------------------------------------------------------------ 2184 2185 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) 2186 : ownedName(new std::string(std::move(ownedName))) { 2187 namedAttr = mlirNamedAttributeGet( 2188 mlirIdentifierGet(mlirAttributeGetContext(attr), 2189 toMlirStringRef(*this->ownedName)), 2190 attr); 2191 } 2192 2193 //------------------------------------------------------------------------------ 2194 // PyType. 2195 //------------------------------------------------------------------------------ 2196 2197 bool PyType::operator==(const PyType &other) const { 2198 return mlirTypeEqual(type, other.type); 2199 } 2200 2201 nb::object PyType::getCapsule() { 2202 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this)); 2203 } 2204 2205 PyType PyType::createFromCapsule(nb::object capsule) { 2206 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); 2207 if (mlirTypeIsNull(rawType)) 2208 throw nb::python_error(); 2209 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), 2210 rawType); 2211 } 2212 2213 //------------------------------------------------------------------------------ 2214 // PyTypeID. 2215 //------------------------------------------------------------------------------ 2216 2217 nb::object PyTypeID::getCapsule() { 2218 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this)); 2219 } 2220 2221 PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { 2222 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); 2223 if (mlirTypeIDIsNull(mlirTypeID)) 2224 throw nb::python_error(); 2225 return PyTypeID(mlirTypeID); 2226 } 2227 bool PyTypeID::operator==(const PyTypeID &other) const { 2228 return mlirTypeIDEqual(typeID, other.typeID); 2229 } 2230 2231 //------------------------------------------------------------------------------ 2232 // PyValue and subclasses. 
2233 //------------------------------------------------------------------------------ 2234 2235 nb::object PyValue::getCapsule() { 2236 return nb::steal<nb::object>(mlirPythonValueToCapsule(get())); 2237 } 2238 2239 nb::object PyValue::maybeDownCast() { 2240 MlirType type = mlirValueGetType(get()); 2241 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); 2242 assert(!mlirTypeIDIsNull(mlirTypeID) && 2243 "mlirTypeID was expected to be non-null."); 2244 std::optional<nb::callable> valueCaster = 2245 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); 2246 // nb::rv_policy::move means use std::move to move the return value 2247 // contents into a new instance that will be owned by Python. 2248 nb::object thisObj = nb::cast(this, nb::rv_policy::move); 2249 if (!valueCaster) 2250 return thisObj; 2251 return valueCaster.value()(thisObj); 2252 } 2253 2254 PyValue PyValue::createFromCapsule(nb::object capsule) { 2255 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); 2256 if (mlirValueIsNull(value)) 2257 throw nb::python_error(); 2258 MlirOperation owner; 2259 if (mlirValueIsAOpResult(value)) 2260 owner = mlirOpResultGetOwner(value); 2261 if (mlirValueIsABlockArgument(value)) 2262 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); 2263 if (mlirOperationIsNull(owner)) 2264 throw nb::python_error(); 2265 MlirContext ctx = mlirOperationGetContext(owner); 2266 PyOperationRef ownerRef = 2267 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); 2268 return PyValue(ownerRef, value); 2269 } 2270 2271 //------------------------------------------------------------------------------ 2272 // PySymbolTable. 
2273 //------------------------------------------------------------------------------ 2274 2275 PySymbolTable::PySymbolTable(PyOperationBase &operation) 2276 : operation(operation.getOperation().getRef()) { 2277 symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); 2278 if (mlirSymbolTableIsNull(symbolTable)) { 2279 throw nb::type_error("Operation is not a Symbol Table."); 2280 } 2281 } 2282 2283 nb::object PySymbolTable::dunderGetItem(const std::string &name) { 2284 operation->checkValid(); 2285 MlirOperation symbol = mlirSymbolTableLookup( 2286 symbolTable, mlirStringRefCreate(name.data(), name.length())); 2287 if (mlirOperationIsNull(symbol)) 2288 throw nb::key_error( 2289 ("Symbol '" + name + "' not in the symbol table.").c_str()); 2290 2291 return PyOperation::forOperation(operation->getContext(), symbol, 2292 operation.getObject()) 2293 ->createOpView(); 2294 } 2295 2296 void PySymbolTable::erase(PyOperationBase &symbol) { 2297 operation->checkValid(); 2298 symbol.getOperation().checkValid(); 2299 mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); 2300 // The operation is also erased, so we must invalidate it. There may be Python 2301 // references to this operation so we don't want to delete it from the list of 2302 // live operations here. 
2303 symbol.getOperation().valid = false; 2304 } 2305 2306 void PySymbolTable::dunderDel(const std::string &name) { 2307 nb::object operation = dunderGetItem(name); 2308 erase(nb::cast<PyOperationBase &>(operation)); 2309 } 2310 2311 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { 2312 operation->checkValid(); 2313 symbol.getOperation().checkValid(); 2314 MlirAttribute symbolAttr = mlirOperationGetAttributeByName( 2315 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); 2316 if (mlirAttributeIsNull(symbolAttr)) 2317 throw nb::value_error("Expected operation to have a symbol name."); 2318 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); 2319 } 2320 2321 MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { 2322 // Op must already be a symbol. 2323 PyOperation &operation = symbol.getOperation(); 2324 operation.checkValid(); 2325 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 2326 MlirAttribute existingNameAttr = 2327 mlirOperationGetAttributeByName(operation.get(), attrName); 2328 if (mlirAttributeIsNull(existingNameAttr)) 2329 throw nb::value_error("Expected operation to have a symbol name."); 2330 return existingNameAttr; 2331 } 2332 2333 void PySymbolTable::setSymbolName(PyOperationBase &symbol, 2334 const std::string &name) { 2335 // Op must already be a symbol. 
2336 PyOperation &operation = symbol.getOperation(); 2337 operation.checkValid(); 2338 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); 2339 MlirAttribute existingNameAttr = 2340 mlirOperationGetAttributeByName(operation.get(), attrName); 2341 if (mlirAttributeIsNull(existingNameAttr)) 2342 throw nb::value_error("Expected operation to have a symbol name."); 2343 MlirAttribute newNameAttr = 2344 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); 2345 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); 2346 } 2347 2348 MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { 2349 PyOperation &operation = symbol.getOperation(); 2350 operation.checkValid(); 2351 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 2352 MlirAttribute existingVisAttr = 2353 mlirOperationGetAttributeByName(operation.get(), attrName); 2354 if (mlirAttributeIsNull(existingVisAttr)) 2355 throw nb::value_error("Expected operation to have a symbol visibility."); 2356 return existingVisAttr; 2357 } 2358 2359 void PySymbolTable::setVisibility(PyOperationBase &symbol, 2360 const std::string &visibility) { 2361 if (visibility != "public" && visibility != "private" && 2362 visibility != "nested") 2363 throw nb::value_error( 2364 "Expected visibility to be 'public', 'private' or 'nested'"); 2365 PyOperation &operation = symbol.getOperation(); 2366 operation.checkValid(); 2367 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); 2368 MlirAttribute existingVisAttr = 2369 mlirOperationGetAttributeByName(operation.get(), attrName); 2370 if (mlirAttributeIsNull(existingVisAttr)) 2371 throw nb::value_error("Expected operation to have a symbol visibility."); 2372 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), 2373 toMlirStringRef(visibility)); 2374 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); 2375 } 2376 2377 void 
PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, 2378 const std::string &newSymbol, 2379 PyOperationBase &from) { 2380 PyOperation &fromOperation = from.getOperation(); 2381 fromOperation.checkValid(); 2382 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( 2383 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), 2384 from.getOperation()))) 2385 2386 throw nb::value_error("Symbol rename failed"); 2387 } 2388 2389 void PySymbolTable::walkSymbolTables(PyOperationBase &from, 2390 bool allSymUsesVisible, 2391 nb::object callback) { 2392 PyOperation &fromOperation = from.getOperation(); 2393 fromOperation.checkValid(); 2394 struct UserData { 2395 PyMlirContextRef context; 2396 nb::object callback; 2397 bool gotException; 2398 std::string exceptionWhat; 2399 nb::object exceptionType; 2400 }; 2401 UserData userData{ 2402 fromOperation.getContext(), std::move(callback), false, {}, {}}; 2403 mlirSymbolTableWalkSymbolTables( 2404 fromOperation.get(), allSymUsesVisible, 2405 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { 2406 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); 2407 auto pyFoundOp = 2408 PyOperation::forOperation(calleeUserData->context, foundOp); 2409 if (calleeUserData->gotException) 2410 return; 2411 try { 2412 calleeUserData->callback(pyFoundOp.getObject(), isVisible); 2413 } catch (nb::python_error &e) { 2414 calleeUserData->gotException = true; 2415 calleeUserData->exceptionWhat = e.what(); 2416 calleeUserData->exceptionType = nb::borrow(e.type()); 2417 } 2418 }, 2419 static_cast<void *>(&userData)); 2420 if (userData.gotException) { 2421 std::string message("Exception raised in callback: "); 2422 message.append(userData.exceptionWhat); 2423 throw std::runtime_error(message); 2424 } 2425 } 2426 2427 namespace { 2428 2429 /// Python wrapper for MlirBlockArgument. 
2430 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { 2431 public: 2432 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; 2433 static constexpr const char *pyClassName = "BlockArgument"; 2434 using PyConcreteValue::PyConcreteValue; 2435 2436 static void bindDerived(ClassTy &c) { 2437 c.def_prop_ro("owner", [](PyBlockArgument &self) { 2438 return PyBlock(self.getParentOperation(), 2439 mlirBlockArgumentGetOwner(self.get())); 2440 }); 2441 c.def_prop_ro("arg_number", [](PyBlockArgument &self) { 2442 return mlirBlockArgumentGetArgNumber(self.get()); 2443 }); 2444 c.def( 2445 "set_type", 2446 [](PyBlockArgument &self, PyType type) { 2447 return mlirBlockArgumentSetType(self.get(), type); 2448 }, 2449 nb::arg("type")); 2450 } 2451 }; 2452 2453 /// A list of block arguments. Internally, these are stored as consecutive 2454 /// elements, random access is cheap. The argument list is associated with the 2455 /// operation that contains the block (detached blocks are not allowed in 2456 /// Python bindings) and extends its lifetime. 2457 class PyBlockArgumentList 2458 : public Sliceable<PyBlockArgumentList, PyBlockArgument> { 2459 public: 2460 static constexpr const char *pyClassName = "BlockArgumentList"; 2461 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>; 2462 2463 PyBlockArgumentList(PyOperationRef operation, MlirBlock block, 2464 intptr_t startIndex = 0, intptr_t length = -1, 2465 intptr_t step = 1) 2466 : Sliceable(startIndex, 2467 length == -1 ? mlirBlockGetNumArguments(block) : length, 2468 step), 2469 operation(std::move(operation)), block(block) {} 2470 2471 static void bindDerived(ClassTy &c) { 2472 c.def_prop_ro("types", [](PyBlockArgumentList &self) { 2473 return getValueTypes(self, self.operation->getContext()); 2474 }); 2475 } 2476 2477 private: 2478 /// Give the parent CRTP class access to hook implementations below. 
2479 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>; 2480 2481 /// Returns the number of arguments in the list. 2482 intptr_t getRawNumElements() { 2483 operation->checkValid(); 2484 return mlirBlockGetNumArguments(block); 2485 } 2486 2487 /// Returns `pos`-the element in the list. 2488 PyBlockArgument getRawElement(intptr_t pos) { 2489 MlirValue argument = mlirBlockGetArgument(block, pos); 2490 return PyBlockArgument(operation, argument); 2491 } 2492 2493 /// Returns a sublist of this list. 2494 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, 2495 intptr_t step) { 2496 return PyBlockArgumentList(operation, block, startIndex, length, step); 2497 } 2498 2499 PyOperationRef operation; 2500 MlirBlock block; 2501 }; 2502 2503 /// A list of operation operands. Internally, these are stored as consecutive 2504 /// elements, random access is cheap. The (returned) operand list is associated 2505 /// with the operation whose operands these are, and thus extends the lifetime 2506 /// of this operation. 2507 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { 2508 public: 2509 static constexpr const char *pyClassName = "OpOperandList"; 2510 using SliceableT = Sliceable<PyOpOperandList, PyValue>; 2511 2512 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, 2513 intptr_t length = -1, intptr_t step = 1) 2514 : Sliceable(startIndex, 2515 length == -1 ? mlirOperationGetNumOperands(operation->get()) 2516 : length, 2517 step), 2518 operation(operation) {} 2519 2520 void dunderSetItem(intptr_t index, PyValue value) { 2521 index = wrapIndex(index); 2522 mlirOperationSetOperand(operation->get(), index, value.get()); 2523 } 2524 2525 static void bindDerived(ClassTy &c) { 2526 c.def("__setitem__", &PyOpOperandList::dunderSetItem); 2527 } 2528 2529 private: 2530 /// Give the parent CRTP class access to hook implementations below. 
2531 friend class Sliceable<PyOpOperandList, PyValue>; 2532 2533 intptr_t getRawNumElements() { 2534 operation->checkValid(); 2535 return mlirOperationGetNumOperands(operation->get()); 2536 } 2537 2538 PyValue getRawElement(intptr_t pos) { 2539 MlirValue operand = mlirOperationGetOperand(operation->get(), pos); 2540 MlirOperation owner; 2541 if (mlirValueIsAOpResult(operand)) 2542 owner = mlirOpResultGetOwner(operand); 2543 else if (mlirValueIsABlockArgument(operand)) 2544 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); 2545 else 2546 assert(false && "Value must be an block arg or op result."); 2547 PyOperationRef pyOwner = 2548 PyOperation::forOperation(operation->getContext(), owner); 2549 return PyValue(pyOwner, operand); 2550 } 2551 2552 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2553 return PyOpOperandList(operation, startIndex, length, step); 2554 } 2555 2556 PyOperationRef operation; 2557 }; 2558 2559 /// A list of operation successors. Internally, these are stored as consecutive 2560 /// elements, random access is cheap. The (returned) successor list is 2561 /// associated with the operation whose successors these are, and thus extends 2562 /// the lifetime of this operation. 2563 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> { 2564 public: 2565 static constexpr const char *pyClassName = "OpSuccessors"; 2566 2567 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, 2568 intptr_t length = -1, intptr_t step = 1) 2569 : Sliceable(startIndex, 2570 length == -1 ? 
mlirOperationGetNumSuccessors(operation->get()) 2571 : length, 2572 step), 2573 operation(operation) {} 2574 2575 void dunderSetItem(intptr_t index, PyBlock block) { 2576 index = wrapIndex(index); 2577 mlirOperationSetSuccessor(operation->get(), index, block.get()); 2578 } 2579 2580 static void bindDerived(ClassTy &c) { 2581 c.def("__setitem__", &PyOpSuccessors::dunderSetItem); 2582 } 2583 2584 private: 2585 /// Give the parent CRTP class access to hook implementations below. 2586 friend class Sliceable<PyOpSuccessors, PyBlock>; 2587 2588 intptr_t getRawNumElements() { 2589 operation->checkValid(); 2590 return mlirOperationGetNumSuccessors(operation->get()); 2591 } 2592 2593 PyBlock getRawElement(intptr_t pos) { 2594 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); 2595 return PyBlock(operation, block); 2596 } 2597 2598 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { 2599 return PyOpSuccessors(operation, startIndex, length, step); 2600 } 2601 2602 PyOperationRef operation; 2603 }; 2604 2605 /// A list of operation attributes. Can be indexed by name, producing 2606 /// attributes, or by index, producing named attributes. 
2607 class PyOpAttributeMap { 2608 public: 2609 PyOpAttributeMap(PyOperationRef operation) 2610 : operation(std::move(operation)) {} 2611 2612 MlirAttribute dunderGetItemNamed(const std::string &name) { 2613 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), 2614 toMlirStringRef(name)); 2615 if (mlirAttributeIsNull(attr)) { 2616 throw nb::key_error("attempt to access a non-existent attribute"); 2617 } 2618 return attr; 2619 } 2620 2621 PyNamedAttribute dunderGetItemIndexed(intptr_t index) { 2622 if (index < 0 || index >= dunderLen()) { 2623 throw nb::index_error("attempt to access out of bounds attribute"); 2624 } 2625 MlirNamedAttribute namedAttr = 2626 mlirOperationGetAttribute(operation->get(), index); 2627 return PyNamedAttribute( 2628 namedAttr.attribute, 2629 std::string(mlirIdentifierStr(namedAttr.name).data, 2630 mlirIdentifierStr(namedAttr.name).length)); 2631 } 2632 2633 void dunderSetItem(const std::string &name, const PyAttribute &attr) { 2634 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), 2635 attr); 2636 } 2637 2638 void dunderDelItem(const std::string &name) { 2639 int removed = mlirOperationRemoveAttributeByName(operation->get(), 2640 toMlirStringRef(name)); 2641 if (!removed) 2642 throw nb::key_error("attempt to delete a non-existent attribute"); 2643 } 2644 2645 intptr_t dunderLen() { 2646 return mlirOperationGetNumAttributes(operation->get()); 2647 } 2648 2649 bool dunderContains(const std::string &name) { 2650 return !mlirAttributeIsNull(mlirOperationGetAttributeByName( 2651 operation->get(), toMlirStringRef(name))); 2652 } 2653 2654 static void bind(nb::module_ &m) { 2655 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap") 2656 .def("__contains__", &PyOpAttributeMap::dunderContains) 2657 .def("__len__", &PyOpAttributeMap::dunderLen) 2658 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) 2659 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) 2660 .def("__setitem__", 
&PyOpAttributeMap::dunderSetItem) 2661 .def("__delitem__", &PyOpAttributeMap::dunderDelItem); 2662 } 2663 2664 private: 2665 PyOperationRef operation; 2666 }; 2667 2668 } // namespace 2669 2670 //------------------------------------------------------------------------------ 2671 // Populates the core exports of the 'ir' submodule. 2672 //------------------------------------------------------------------------------ 2673 2674 void mlir::python::populateIRCore(nb::module_ &m) { 2675 // disable leak warnings which tend to be false positives. 2676 nb::set_leak_warnings(false); 2677 //---------------------------------------------------------------------------- 2678 // Enums. 2679 //---------------------------------------------------------------------------- 2680 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity") 2681 .value("ERROR", MlirDiagnosticError) 2682 .value("WARNING", MlirDiagnosticWarning) 2683 .value("NOTE", MlirDiagnosticNote) 2684 .value("REMARK", MlirDiagnosticRemark); 2685 2686 nb::enum_<MlirWalkOrder>(m, "WalkOrder") 2687 .value("PRE_ORDER", MlirWalkPreOrder) 2688 .value("POST_ORDER", MlirWalkPostOrder); 2689 2690 nb::enum_<MlirWalkResult>(m, "WalkResult") 2691 .value("ADVANCE", MlirWalkResultAdvance) 2692 .value("INTERRUPT", MlirWalkResultInterrupt) 2693 .value("SKIP", MlirWalkResultSkip); 2694 2695 //---------------------------------------------------------------------------- 2696 // Mapping of Diagnostics. 
2697 //---------------------------------------------------------------------------- 2698 nb::class_<PyDiagnostic>(m, "Diagnostic") 2699 .def_prop_ro("severity", &PyDiagnostic::getSeverity) 2700 .def_prop_ro("location", &PyDiagnostic::getLocation) 2701 .def_prop_ro("message", &PyDiagnostic::getMessage) 2702 .def_prop_ro("notes", &PyDiagnostic::getNotes) 2703 .def("__str__", [](PyDiagnostic &self) -> nb::str { 2704 if (!self.isValid()) 2705 return nb::str("<Invalid Diagnostic>"); 2706 return self.getMessage(); 2707 }); 2708 2709 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo") 2710 .def("__init__", 2711 [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { 2712 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); 2713 }) 2714 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) 2715 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) 2716 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) 2717 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) 2718 .def("__str__", 2719 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); 2720 2721 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler") 2722 .def("detach", &PyDiagnosticHandler::detach) 2723 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) 2724 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) 2725 .def("__enter__", &PyDiagnosticHandler::contextEnter) 2726 .def("__exit__", &PyDiagnosticHandler::contextExit, 2727 nb::arg("exc_type").none(), nb::arg("exc_value").none(), 2728 nb::arg("traceback").none()); 2729 2730 //---------------------------------------------------------------------------- 2731 // Mapping of MlirContext. 2732 // Note that this is exported as _BaseContext. The containing, Python level 2733 // __init__.py will subclass it with site-specific functionality and set a 2734 // "Context" attribute on this module. 
2735 //---------------------------------------------------------------------------- 2736 nb::class_<PyMlirContext>(m, "_BaseContext") 2737 .def("__init__", 2738 [](PyMlirContext &self) { 2739 MlirContext context = mlirContextCreateWithThreading(false); 2740 new (&self) PyMlirContext(context); 2741 }) 2742 .def_static("_get_live_count", &PyMlirContext::getLiveCount) 2743 .def("_get_context_again", 2744 [](PyMlirContext &self) { 2745 PyMlirContextRef ref = PyMlirContext::forContext(self.get()); 2746 return ref.releaseObject(); 2747 }) 2748 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) 2749 .def("_get_live_operation_objects", 2750 &PyMlirContext::getLiveOperationObjects) 2751 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) 2752 .def("_clear_live_operations_inside", 2753 nb::overload_cast<MlirOperation>( 2754 &PyMlirContext::clearOperationsInside)) 2755 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) 2756 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) 2757 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) 2758 .def("__enter__", &PyMlirContext::contextEnter) 2759 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), 2760 nb::arg("exc_value").none(), nb::arg("traceback").none()) 2761 .def_prop_ro_static( 2762 "current", 2763 [](nb::object & /*class*/) { 2764 auto *context = PyThreadContextEntry::getDefaultContext(); 2765 if (!context) 2766 return nb::none(); 2767 return nb::cast(context); 2768 }, 2769 "Gets the Context bound to the current thread or raises ValueError") 2770 .def_prop_ro( 2771 "dialects", 2772 [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2773 "Gets a container for accessing dialects by name") 2774 .def_prop_ro( 2775 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, 2776 "Alias for 'dialect'") 2777 .def( 2778 "get_dialect_descriptor", 2779 [=](PyMlirContext &self, std::string &name) { 2780 
MlirDialect dialect = mlirContextGetOrLoadDialect( 2781 self.get(), {name.data(), name.size()}); 2782 if (mlirDialectIsNull(dialect)) { 2783 throw nb::value_error( 2784 (Twine("Dialect '") + name + "' not found").str().c_str()); 2785 } 2786 return PyDialectDescriptor(self.getRef(), dialect); 2787 }, 2788 nb::arg("dialect_name"), 2789 "Gets or loads a dialect by name, returning its descriptor object") 2790 .def_prop_rw( 2791 "allow_unregistered_dialects", 2792 [](PyMlirContext &self) -> bool { 2793 return mlirContextGetAllowUnregisteredDialects(self.get()); 2794 }, 2795 [](PyMlirContext &self, bool value) { 2796 mlirContextSetAllowUnregisteredDialects(self.get(), value); 2797 }) 2798 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, 2799 nb::arg("callback"), 2800 "Attaches a diagnostic handler that will receive callbacks") 2801 .def( 2802 "enable_multithreading", 2803 [](PyMlirContext &self, bool enable) { 2804 mlirContextEnableMultithreading(self.get(), enable); 2805 }, 2806 nb::arg("enable")) 2807 .def( 2808 "is_registered_operation", 2809 [](PyMlirContext &self, std::string &name) { 2810 return mlirContextIsRegisteredOperation( 2811 self.get(), MlirStringRef{name.data(), name.size()}); 2812 }, 2813 nb::arg("operation_name")) 2814 .def( 2815 "append_dialect_registry", 2816 [](PyMlirContext &self, PyDialectRegistry ®istry) { 2817 mlirContextAppendDialectRegistry(self.get(), registry); 2818 }, 2819 nb::arg("registry")) 2820 .def_prop_rw("emit_error_diagnostics", nullptr, 2821 &PyMlirContext::setEmitErrorDiagnostics, 2822 "Emit error diagnostics to diagnostic handlers. 
By default " 2823 "error diagnostics are captured and reported through " 2824 "MLIRError exceptions.") 2825 .def("load_all_available_dialects", [](PyMlirContext &self) { 2826 mlirContextLoadAllAvailableDialects(self.get()); 2827 }); 2828 2829 //---------------------------------------------------------------------------- 2830 // Mapping of PyDialectDescriptor 2831 //---------------------------------------------------------------------------- 2832 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor") 2833 .def_prop_ro("namespace", 2834 [](PyDialectDescriptor &self) { 2835 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2836 return nb::str(ns.data, ns.length); 2837 }) 2838 .def("__repr__", [](PyDialectDescriptor &self) { 2839 MlirStringRef ns = mlirDialectGetNamespace(self.get()); 2840 std::string repr("<DialectDescriptor "); 2841 repr.append(ns.data, ns.length); 2842 repr.append(">"); 2843 return repr; 2844 }); 2845 2846 //---------------------------------------------------------------------------- 2847 // Mapping of PyDialects 2848 //---------------------------------------------------------------------------- 2849 nb::class_<PyDialects>(m, "Dialects") 2850 .def("__getitem__", 2851 [=](PyDialects &self, std::string keyName) { 2852 MlirDialect dialect = 2853 self.getDialectForKey(keyName, /*attrError=*/false); 2854 nb::object descriptor = 2855 nb::cast(PyDialectDescriptor{self.getContext(), dialect}); 2856 return createCustomDialectWrapper(keyName, std::move(descriptor)); 2857 }) 2858 .def("__getattr__", [=](PyDialects &self, std::string attrName) { 2859 MlirDialect dialect = 2860 self.getDialectForKey(attrName, /*attrError=*/true); 2861 nb::object descriptor = 2862 nb::cast(PyDialectDescriptor{self.getContext(), dialect}); 2863 return createCustomDialectWrapper(attrName, std::move(descriptor)); 2864 }); 2865 2866 //---------------------------------------------------------------------------- 2867 // Mapping of PyDialect 2868 
//---------------------------------------------------------------------------- 2869 nb::class_<PyDialect>(m, "Dialect") 2870 .def(nb::init<nb::object>(), nb::arg("descriptor")) 2871 .def_prop_ro("descriptor", 2872 [](PyDialect &self) { return self.getDescriptor(); }) 2873 .def("__repr__", [](nb::object self) { 2874 auto clazz = self.attr("__class__"); 2875 return nb::str("<Dialect ") + 2876 self.attr("descriptor").attr("namespace") + nb::str(" (class ") + 2877 clazz.attr("__module__") + nb::str(".") + 2878 clazz.attr("__name__") + nb::str(")>"); 2879 }); 2880 2881 //---------------------------------------------------------------------------- 2882 // Mapping of PyDialectRegistry 2883 //---------------------------------------------------------------------------- 2884 nb::class_<PyDialectRegistry>(m, "DialectRegistry") 2885 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) 2886 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) 2887 .def(nb::init<>()); 2888 2889 //---------------------------------------------------------------------------- 2890 // Mapping of Location 2891 //---------------------------------------------------------------------------- 2892 nb::class_<PyLocation>(m, "Location") 2893 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) 2894 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) 2895 .def("__enter__", &PyLocation::contextEnter) 2896 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), 2897 nb::arg("exc_value").none(), nb::arg("traceback").none()) 2898 .def("__eq__", 2899 [](PyLocation &self, PyLocation &other) -> bool { 2900 return mlirLocationEqual(self, other); 2901 }) 2902 .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) 2903 .def_prop_ro_static( 2904 "current", 2905 [](nb::object & /*class*/) { 2906 auto *loc = PyThreadContextEntry::getDefaultLocation(); 2907 if (!loc) 2908 throw nb::value_error("No current Location"); 2909 
return loc; 2910 }, 2911 "Gets the Location bound to the current thread or raises ValueError") 2912 .def_static( 2913 "unknown", 2914 [](DefaultingPyMlirContext context) { 2915 return PyLocation(context->getRef(), 2916 mlirLocationUnknownGet(context->get())); 2917 }, 2918 nb::arg("context").none() = nb::none(), 2919 "Gets a Location representing an unknown location") 2920 .def_static( 2921 "callsite", 2922 [](PyLocation callee, const std::vector<PyLocation> &frames, 2923 DefaultingPyMlirContext context) { 2924 if (frames.empty()) 2925 throw nb::value_error("No caller frames provided"); 2926 MlirLocation caller = frames.back().get(); 2927 for (const PyLocation &frame : 2928 llvm::reverse(llvm::ArrayRef(frames).drop_back())) 2929 caller = mlirLocationCallSiteGet(frame.get(), caller); 2930 return PyLocation(context->getRef(), 2931 mlirLocationCallSiteGet(callee.get(), caller)); 2932 }, 2933 nb::arg("callee"), nb::arg("frames"), 2934 nb::arg("context").none() = nb::none(), 2935 kContextGetCallSiteLocationDocstring) 2936 .def_static( 2937 "file", 2938 [](std::string filename, int line, int col, 2939 DefaultingPyMlirContext context) { 2940 return PyLocation( 2941 context->getRef(), 2942 mlirLocationFileLineColGet( 2943 context->get(), toMlirStringRef(filename), line, col)); 2944 }, 2945 nb::arg("filename"), nb::arg("line"), nb::arg("col"), 2946 nb::arg("context").none() = nb::none(), 2947 kContextGetFileLocationDocstring) 2948 .def_static( 2949 "file", 2950 [](std::string filename, int startLine, int startCol, int endLine, 2951 int endCol, DefaultingPyMlirContext context) { 2952 return PyLocation(context->getRef(), 2953 mlirLocationFileLineColRangeGet( 2954 context->get(), toMlirStringRef(filename), 2955 startLine, startCol, endLine, endCol)); 2956 }, 2957 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), 2958 nb::arg("end_line"), nb::arg("end_col"), 2959 nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) 2960 .def_static( 2961 
"fused", 2962 [](const std::vector<PyLocation> &pyLocations, 2963 std::optional<PyAttribute> metadata, 2964 DefaultingPyMlirContext context) { 2965 llvm::SmallVector<MlirLocation, 4> locations; 2966 locations.reserve(pyLocations.size()); 2967 for (auto &pyLocation : pyLocations) 2968 locations.push_back(pyLocation.get()); 2969 MlirLocation location = mlirLocationFusedGet( 2970 context->get(), locations.size(), locations.data(), 2971 metadata ? metadata->get() : MlirAttribute{0}); 2972 return PyLocation(context->getRef(), location); 2973 }, 2974 nb::arg("locations"), nb::arg("metadata").none() = nb::none(), 2975 nb::arg("context").none() = nb::none(), 2976 kContextGetFusedLocationDocstring) 2977 .def_static( 2978 "name", 2979 [](std::string name, std::optional<PyLocation> childLoc, 2980 DefaultingPyMlirContext context) { 2981 return PyLocation( 2982 context->getRef(), 2983 mlirLocationNameGet( 2984 context->get(), toMlirStringRef(name), 2985 childLoc ? childLoc->get() 2986 : mlirLocationUnknownGet(context->get()))); 2987 }, 2988 nb::arg("name"), nb::arg("childLoc").none() = nb::none(), 2989 nb::arg("context").none() = nb::none(), 2990 kContextGetNameLocationDocString) 2991 .def_static( 2992 "from_attr", 2993 [](PyAttribute &attribute, DefaultingPyMlirContext context) { 2994 return PyLocation(context->getRef(), 2995 mlirLocationFromAttribute(attribute)); 2996 }, 2997 nb::arg("attribute"), nb::arg("context").none() = nb::none(), 2998 "Gets a Location from a LocationAttr") 2999 .def_prop_ro( 3000 "context", 3001 [](PyLocation &self) { return self.getContext().getObject(); }, 3002 "Context that owns the Location") 3003 .def_prop_ro( 3004 "attr", 3005 [](PyLocation &self) { return mlirLocationGetAttribute(self); }, 3006 "Get the underlying LocationAttr") 3007 .def( 3008 "emit_error", 3009 [](PyLocation &self, std::string message) { 3010 mlirEmitError(self, message.c_str()); 3011 }, 3012 nb::arg("message"), "Emits an error at this location") 3013 .def("__repr__", 
[](PyLocation &self) { 3014 PyPrintAccumulator printAccum; 3015 mlirLocationPrint(self, printAccum.getCallback(), 3016 printAccum.getUserData()); 3017 return printAccum.join(); 3018 }); 3019 3020 //---------------------------------------------------------------------------- 3021 // Mapping of Module 3022 //---------------------------------------------------------------------------- 3023 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) 3024 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) 3025 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) 3026 .def_static( 3027 "parse", 3028 [](const std::string &moduleAsm, DefaultingPyMlirContext context) { 3029 PyMlirContext::ErrorCapture errors(context->getRef()); 3030 MlirModule module = mlirModuleCreateParse( 3031 context->get(), toMlirStringRef(moduleAsm)); 3032 if (mlirModuleIsNull(module)) 3033 throw MLIRError("Unable to parse module assembly", errors.take()); 3034 return PyModule::forModule(module).releaseObject(); 3035 }, 3036 nb::arg("asm"), nb::arg("context").none() = nb::none(), 3037 kModuleParseDocstring) 3038 .def_static( 3039 "parse", 3040 [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { 3041 PyMlirContext::ErrorCapture errors(context->getRef()); 3042 MlirModule module = mlirModuleCreateParse( 3043 context->get(), toMlirStringRef(moduleAsm)); 3044 if (mlirModuleIsNull(module)) 3045 throw MLIRError("Unable to parse module assembly", errors.take()); 3046 return PyModule::forModule(module).releaseObject(); 3047 }, 3048 nb::arg("asm"), nb::arg("context").none() = nb::none(), 3049 kModuleParseDocstring) 3050 .def_static( 3051 "create", 3052 [](DefaultingPyLocation loc) { 3053 MlirModule module = mlirModuleCreateEmpty(loc); 3054 return PyModule::forModule(module).releaseObject(); 3055 }, 3056 nb::arg("loc").none() = nb::none(), "Creates an empty module") 3057 .def_prop_ro( 3058 "context", 3059 [](PyModule &self) { return self.getContext().getObject(); }, 3060 
"Context that created the Module") 3061 .def_prop_ro( 3062 "operation", 3063 [](PyModule &self) { 3064 return PyOperation::forOperation(self.getContext(), 3065 mlirModuleGetOperation(self.get()), 3066 self.getRef().releaseObject()) 3067 .releaseObject(); 3068 }, 3069 "Accesses the module as an operation") 3070 .def_prop_ro( 3071 "body", 3072 [](PyModule &self) { 3073 PyOperationRef moduleOp = PyOperation::forOperation( 3074 self.getContext(), mlirModuleGetOperation(self.get()), 3075 self.getRef().releaseObject()); 3076 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); 3077 return returnBlock; 3078 }, 3079 "Return the block for this module") 3080 .def( 3081 "dump", 3082 [](PyModule &self) { 3083 mlirOperationDump(mlirModuleGetOperation(self.get())); 3084 }, 3085 kDumpDocstring) 3086 .def( 3087 "__str__", 3088 [](nb::object self) { 3089 // Defer to the operation's __str__. 3090 return self.attr("operation").attr("__str__")(); 3091 }, 3092 kOperationStrDunderDocstring); 3093 3094 //---------------------------------------------------------------------------- 3095 // Mapping of Operation. 
3096 //---------------------------------------------------------------------------- 3097 nb::class_<PyOperationBase>(m, "_OperationBase") 3098 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, 3099 [](PyOperationBase &self) { 3100 return self.getOperation().getCapsule(); 3101 }) 3102 .def("__eq__", 3103 [](PyOperationBase &self, PyOperationBase &other) { 3104 return &self.getOperation() == &other.getOperation(); 3105 }) 3106 .def("__eq__", 3107 [](PyOperationBase &self, nb::object other) { return false; }) 3108 .def("__hash__", 3109 [](PyOperationBase &self) { 3110 return static_cast<size_t>(llvm::hash_value(&self.getOperation())); 3111 }) 3112 .def_prop_ro("attributes", 3113 [](PyOperationBase &self) { 3114 return PyOpAttributeMap(self.getOperation().getRef()); 3115 }) 3116 .def_prop_ro( 3117 "context", 3118 [](PyOperationBase &self) { 3119 PyOperation &concreteOperation = self.getOperation(); 3120 concreteOperation.checkValid(); 3121 return concreteOperation.getContext().getObject(); 3122 }, 3123 "Context that owns the Operation") 3124 .def_prop_ro("name", 3125 [](PyOperationBase &self) { 3126 auto &concreteOperation = self.getOperation(); 3127 concreteOperation.checkValid(); 3128 MlirOperation operation = concreteOperation.get(); 3129 MlirStringRef name = 3130 mlirIdentifierStr(mlirOperationGetName(operation)); 3131 return nb::str(name.data, name.length); 3132 }) 3133 .def_prop_ro("operands", 3134 [](PyOperationBase &self) { 3135 return PyOpOperandList(self.getOperation().getRef()); 3136 }) 3137 .def_prop_ro("regions", 3138 [](PyOperationBase &self) { 3139 return PyRegionList(self.getOperation().getRef()); 3140 }) 3141 .def_prop_ro( 3142 "results", 3143 [](PyOperationBase &self) { 3144 return PyOpResultList(self.getOperation().getRef()); 3145 }, 3146 "Returns the list of Operation results.") 3147 .def_prop_ro( 3148 "result", 3149 [](PyOperationBase &self) { 3150 auto &operation = self.getOperation(); 3151 return PyOpResult(operation.getRef(), getUniqueResult(operation)) 
3152 .maybeDownCast(); 3153 }, 3154 "Shortcut to get an op result if it has only one (throws an error " 3155 "otherwise).") 3156 .def_prop_ro( 3157 "location", 3158 [](PyOperationBase &self) { 3159 PyOperation &operation = self.getOperation(); 3160 return PyLocation(operation.getContext(), 3161 mlirOperationGetLocation(operation.get())); 3162 }, 3163 "Returns the source location the operation was defined or derived " 3164 "from.") 3165 .def_prop_ro("parent", 3166 [](PyOperationBase &self) -> nb::object { 3167 auto parent = self.getOperation().getParentOperation(); 3168 if (parent) 3169 return parent->getObject(); 3170 return nb::none(); 3171 }) 3172 .def( 3173 "__str__", 3174 [](PyOperationBase &self) { 3175 return self.getAsm(/*binary=*/false, 3176 /*largeElementsLimit=*/std::nullopt, 3177 /*enableDebugInfo=*/false, 3178 /*prettyDebugInfo=*/false, 3179 /*printGenericOpForm=*/false, 3180 /*useLocalScope=*/false, 3181 /*assumeVerified=*/false, 3182 /*skipRegions=*/false); 3183 }, 3184 "Returns the assembly form of the operation.") 3185 .def("print", 3186 nb::overload_cast<PyAsmState &, nb::object, bool>( 3187 &PyOperationBase::print), 3188 nb::arg("state"), nb::arg("file").none() = nb::none(), 3189 nb::arg("binary") = false, kOperationPrintStateDocstring) 3190 .def("print", 3191 nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool, 3192 bool, nb::object, bool, bool>( 3193 &PyOperationBase::print), 3194 // Careful: Lots of arguments must match up with print method. 
3195 nb::arg("large_elements_limit").none() = nb::none(), 3196 nb::arg("enable_debug_info") = false, 3197 nb::arg("pretty_debug_info") = false, 3198 nb::arg("print_generic_op_form") = false, 3199 nb::arg("use_local_scope") = false, 3200 nb::arg("assume_verified") = false, 3201 nb::arg("file").none() = nb::none(), nb::arg("binary") = false, 3202 nb::arg("skip_regions") = false, kOperationPrintDocstring) 3203 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), 3204 nb::arg("desired_version").none() = nb::none(), 3205 kOperationPrintBytecodeDocstring) 3206 .def("get_asm", &PyOperationBase::getAsm, 3207 // Careful: Lots of arguments must match up with get_asm method. 3208 nb::arg("binary") = false, 3209 nb::arg("large_elements_limit").none() = nb::none(), 3210 nb::arg("enable_debug_info") = false, 3211 nb::arg("pretty_debug_info") = false, 3212 nb::arg("print_generic_op_form") = false, 3213 nb::arg("use_local_scope") = false, 3214 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, 3215 kOperationGetAsmDocstring) 3216 .def("verify", &PyOperationBase::verify, 3217 "Verify the operation. 
Raises MLIRError if verification fails, and " 3218 "returns true otherwise.") 3219 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), 3220 "Puts self immediately after the other operation in its parent " 3221 "block.") 3222 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), 3223 "Puts self immediately before the other operation in its parent " 3224 "block.") 3225 .def( 3226 "clone", 3227 [](PyOperationBase &self, nb::object ip) { 3228 return self.getOperation().clone(ip); 3229 }, 3230 nb::arg("ip").none() = nb::none()) 3231 .def( 3232 "detach_from_parent", 3233 [](PyOperationBase &self) { 3234 PyOperation &operation = self.getOperation(); 3235 operation.checkValid(); 3236 if (!operation.isAttached()) 3237 throw nb::value_error("Detached operation has no parent."); 3238 3239 operation.detachFromParent(); 3240 return operation.createOpView(); 3241 }, 3242 "Detaches the operation from its parent block.") 3243 .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) 3244 .def("walk", &PyOperationBase::walk, nb::arg("callback"), 3245 nb::arg("walk_order") = MlirWalkPostOrder); 3246 3247 nb::class_<PyOperation, PyOperationBase>(m, "Operation") 3248 .def_static( 3249 "create", 3250 [](std::string_view name, 3251 std::optional<std::vector<PyType *>> results, 3252 std::optional<std::vector<PyValue *>> operands, 3253 std::optional<nb::dict> attributes, 3254 std::optional<std::vector<PyBlock *>> successors, int regions, 3255 DefaultingPyLocation location, const nb::object &maybeIp, 3256 bool inferType) { 3257 // Unpack/validate operands. 
3258 llvm::SmallVector<MlirValue, 4> mlirOperands; 3259 if (operands) { 3260 mlirOperands.reserve(operands->size()); 3261 for (PyValue *operand : *operands) { 3262 if (!operand) 3263 throw nb::value_error("operand value cannot be None"); 3264 mlirOperands.push_back(operand->get()); 3265 } 3266 } 3267 3268 return PyOperation::create(name, results, mlirOperands, attributes, 3269 successors, regions, location, maybeIp, 3270 inferType); 3271 }, 3272 nb::arg("name"), nb::arg("results").none() = nb::none(), 3273 nb::arg("operands").none() = nb::none(), 3274 nb::arg("attributes").none() = nb::none(), 3275 nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, 3276 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), 3277 nb::arg("infer_type") = false, kOperationCreateDocstring) 3278 .def_static( 3279 "parse", 3280 [](const std::string &sourceStr, const std::string &sourceName, 3281 DefaultingPyMlirContext context) { 3282 return PyOperation::parse(context->getRef(), sourceStr, sourceName) 3283 ->createOpView(); 3284 }, 3285 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", 3286 nb::arg("context").none() = nb::none(), 3287 "Parses an operation. 
Supports both text assembly format and binary " 3288 "bytecode format.") 3289 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) 3290 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) 3291 .def_prop_ro("operation", [](nb::object self) { return self; }) 3292 .def_prop_ro("opview", &PyOperation::createOpView) 3293 .def_prop_ro( 3294 "successors", 3295 [](PyOperationBase &self) { 3296 return PyOpSuccessors(self.getOperation().getRef()); 3297 }, 3298 "Returns the list of Operation successors."); 3299 3300 auto opViewClass = 3301 nb::class_<PyOpView, PyOperationBase>(m, "OpView") 3302 .def(nb::init<nb::object>(), nb::arg("operation")) 3303 .def( 3304 "__init__", 3305 [](PyOpView *self, std::string_view name, 3306 std::tuple<int, bool> opRegionSpec, 3307 nb::object operandSegmentSpecObj, 3308 nb::object resultSegmentSpecObj, 3309 std::optional<nb::list> resultTypeList, nb::list operandList, 3310 std::optional<nb::dict> attributes, 3311 std::optional<std::vector<PyBlock *>> successors, 3312 std::optional<int> regions, DefaultingPyLocation location, 3313 const nb::object &maybeIp) { 3314 new (self) PyOpView(PyOpView::buildGeneric( 3315 name, opRegionSpec, operandSegmentSpecObj, 3316 resultSegmentSpecObj, resultTypeList, operandList, 3317 attributes, successors, regions, location, maybeIp)); 3318 }, 3319 nb::arg("name"), nb::arg("opRegionSpec"), 3320 nb::arg("operandSegmentSpecObj").none() = nb::none(), 3321 nb::arg("resultSegmentSpecObj").none() = nb::none(), 3322 nb::arg("results").none() = nb::none(), 3323 nb::arg("operands").none() = nb::none(), 3324 nb::arg("attributes").none() = nb::none(), 3325 nb::arg("successors").none() = nb::none(), 3326 nb::arg("regions").none() = nb::none(), 3327 nb::arg("loc").none() = nb::none(), 3328 nb::arg("ip").none() = nb::none()) 3329 3330 .def_prop_ro("operation", &PyOpView::getOperationObject) 3331 .def_prop_ro("opview", [](nb::object self) { return self; }) 3332 .def( 3333 "__str__", 3334 [](PyOpView 
&self) { return nb::str(self.getOperationObject()); }) 3335 .def_prop_ro( 3336 "successors", 3337 [](PyOperationBase &self) { 3338 return PyOpSuccessors(self.getOperation().getRef()); 3339 }, 3340 "Returns the list of Operation successors."); 3341 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); 3342 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); 3343 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); 3344 // It is faster to pass the operation_name, ods_regions, and 3345 // ods_operand_segments/ods_result_segments as arguments to the constructor, 3346 // rather than to access them as attributes. 3347 opViewClass.attr("build_generic") = classmethod( 3348 [](nb::handle cls, std::optional<nb::list> resultTypeList, 3349 nb::list operandList, std::optional<nb::dict> attributes, 3350 std::optional<std::vector<PyBlock *>> successors, 3351 std::optional<int> regions, DefaultingPyLocation location, 3352 const nb::object &maybeIp) { 3353 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME")); 3354 std::tuple<int, bool> opRegionSpec = 3355 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); 3356 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); 3357 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); 3358 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, 3359 resultSegmentSpec, resultTypeList, 3360 operandList, attributes, successors, 3361 regions, location, maybeIp); 3362 }, 3363 nb::arg("cls"), nb::arg("results").none() = nb::none(), 3364 nb::arg("operands").none() = nb::none(), 3365 nb::arg("attributes").none() = nb::none(), 3366 nb::arg("successors").none() = nb::none(), 3367 nb::arg("regions").none() = nb::none(), 3368 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), 3369 "Builds a specific, generated OpView based on class level attributes."); 3370 opViewClass.attr("parse") = classmethod( 3371 [](const nb::object &cls, const std::string &sourceStr, 3372 
const std::string &sourceName, DefaultingPyMlirContext context) { 3373 PyOperationRef parsed = 3374 PyOperation::parse(context->getRef(), sourceStr, sourceName); 3375 3376 // Check if the expected operation was parsed, and cast to to the 3377 // appropriate `OpView` subclass if successful. 3378 // NOTE: This accesses attributes that have been automatically added to 3379 // `OpView` subclasses, and is not intended to be used on `OpView` 3380 // directly. 3381 std::string clsOpName = 3382 nb::cast<std::string>(cls.attr("OPERATION_NAME")); 3383 MlirStringRef identifier = 3384 mlirIdentifierStr(mlirOperationGetName(*parsed.get())); 3385 std::string_view parsedOpName(identifier.data, identifier.length); 3386 if (clsOpName != parsedOpName) 3387 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + 3388 parsedOpName + "'"); 3389 return PyOpView::constructDerived(cls, parsed.getObject()); 3390 }, 3391 nb::arg("cls"), nb::arg("source"), nb::kw_only(), 3392 nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), 3393 "Parses a specific, generated OpView based on class level attributes"); 3394 3395 //---------------------------------------------------------------------------- 3396 // Mapping of PyRegion. 
3397 //---------------------------------------------------------------------------- 3398 nb::class_<PyRegion>(m, "Region") 3399 .def_prop_ro( 3400 "blocks", 3401 [](PyRegion &self) { 3402 return PyBlockList(self.getParentOperation(), self.get()); 3403 }, 3404 "Returns a forward-optimized sequence of blocks.") 3405 .def_prop_ro( 3406 "owner", 3407 [](PyRegion &self) { 3408 return self.getParentOperation()->createOpView(); 3409 }, 3410 "Returns the operation owning this region.") 3411 .def( 3412 "__iter__", 3413 [](PyRegion &self) { 3414 self.checkValid(); 3415 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); 3416 return PyBlockIterator(self.getParentOperation(), firstBlock); 3417 }, 3418 "Iterates over blocks in the region.") 3419 .def("__eq__", 3420 [](PyRegion &self, PyRegion &other) { 3421 return self.get().ptr == other.get().ptr; 3422 }) 3423 .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); 3424 3425 //---------------------------------------------------------------------------- 3426 // Mapping of PyBlock. 
3427 //---------------------------------------------------------------------------- 3428 nb::class_<PyBlock>(m, "Block") 3429 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) 3430 .def_prop_ro( 3431 "owner", 3432 [](PyBlock &self) { 3433 return self.getParentOperation()->createOpView(); 3434 }, 3435 "Returns the owning operation of this block.") 3436 .def_prop_ro( 3437 "region", 3438 [](PyBlock &self) { 3439 MlirRegion region = mlirBlockGetParentRegion(self.get()); 3440 return PyRegion(self.getParentOperation(), region); 3441 }, 3442 "Returns the owning region of this block.") 3443 .def_prop_ro( 3444 "arguments", 3445 [](PyBlock &self) { 3446 return PyBlockArgumentList(self.getParentOperation(), self.get()); 3447 }, 3448 "Returns a list of block arguments.") 3449 .def( 3450 "add_argument", 3451 [](PyBlock &self, const PyType &type, const PyLocation &loc) { 3452 return mlirBlockAddArgument(self.get(), type, loc); 3453 }, 3454 "Append an argument of the specified type to the block and returns " 3455 "the newly added argument.") 3456 .def( 3457 "erase_argument", 3458 [](PyBlock &self, unsigned index) { 3459 return mlirBlockEraseArgument(self.get(), index); 3460 }, 3461 "Erase the argument at 'index' and remove it from the argument list.") 3462 .def_prop_ro( 3463 "operations", 3464 [](PyBlock &self) { 3465 return PyOperationList(self.getParentOperation(), self.get()); 3466 }, 3467 "Returns a forward-optimized sequence of operations.") 3468 .def_static( 3469 "create_at_start", 3470 [](PyRegion &parent, const nb::sequence &pyArgTypes, 3471 const std::optional<nb::sequence> &pyArgLocs) { 3472 parent.checkValid(); 3473 MlirBlock block = createBlock(pyArgTypes, pyArgLocs); 3474 mlirRegionInsertOwnedBlock(parent, 0, block); 3475 return PyBlock(parent.getParentOperation(), block); 3476 }, 3477 nb::arg("parent"), nb::arg("arg_types") = nb::list(), 3478 nb::arg("arg_locs") = std::nullopt, 3479 "Creates and returns a new Block at the beginning of the given " 3480 
"region (with given argument types and locations).") 3481 .def( 3482 "append_to", 3483 [](PyBlock &self, PyRegion ®ion) { 3484 MlirBlock b = self.get(); 3485 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) 3486 mlirBlockDetach(b); 3487 mlirRegionAppendOwnedBlock(region.get(), b); 3488 }, 3489 "Append this block to a region, transferring ownership if necessary") 3490 .def( 3491 "create_before", 3492 [](PyBlock &self, const nb::args &pyArgTypes, 3493 const std::optional<nb::sequence> &pyArgLocs) { 3494 self.checkValid(); 3495 MlirBlock block = 3496 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 3497 MlirRegion region = mlirBlockGetParentRegion(self.get()); 3498 mlirRegionInsertOwnedBlockBefore(region, self.get(), block); 3499 return PyBlock(self.getParentOperation(), block); 3500 }, 3501 nb::arg("arg_types"), nb::kw_only(), 3502 nb::arg("arg_locs") = std::nullopt, 3503 "Creates and returns a new Block before this block " 3504 "(with given argument types and locations).") 3505 .def( 3506 "create_after", 3507 [](PyBlock &self, const nb::args &pyArgTypes, 3508 const std::optional<nb::sequence> &pyArgLocs) { 3509 self.checkValid(); 3510 MlirBlock block = 3511 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs); 3512 MlirRegion region = mlirBlockGetParentRegion(self.get()); 3513 mlirRegionInsertOwnedBlockAfter(region, self.get(), block); 3514 return PyBlock(self.getParentOperation(), block); 3515 }, 3516 nb::arg("arg_types"), nb::kw_only(), 3517 nb::arg("arg_locs") = std::nullopt, 3518 "Creates and returns a new Block after this block " 3519 "(with given argument types and locations).") 3520 .def( 3521 "__iter__", 3522 [](PyBlock &self) { 3523 self.checkValid(); 3524 MlirOperation firstOperation = 3525 mlirBlockGetFirstOperation(self.get()); 3526 return PyOperationIterator(self.getParentOperation(), 3527 firstOperation); 3528 }, 3529 "Iterates over operations in the block.") 3530 .def("__eq__", 3531 [](PyBlock &self, PyBlock &other) { 3532 return 
self.get().ptr == other.get().ptr; 3533 }) 3534 .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) 3535 .def("__hash__", 3536 [](PyBlock &self) { 3537 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3538 }) 3539 .def( 3540 "__str__", 3541 [](PyBlock &self) { 3542 self.checkValid(); 3543 PyPrintAccumulator printAccum; 3544 mlirBlockPrint(self.get(), printAccum.getCallback(), 3545 printAccum.getUserData()); 3546 return printAccum.join(); 3547 }, 3548 "Returns the assembly form of the block.") 3549 .def( 3550 "append", 3551 [](PyBlock &self, PyOperationBase &operation) { 3552 if (operation.getOperation().isAttached()) 3553 operation.getOperation().detachFromParent(); 3554 3555 MlirOperation mlirOperation = operation.getOperation().get(); 3556 mlirBlockAppendOwnedOperation(self.get(), mlirOperation); 3557 operation.getOperation().setAttached( 3558 self.getParentOperation().getObject()); 3559 }, 3560 nb::arg("operation"), 3561 "Appends an operation to this block. If the operation is currently " 3562 "in another block, it will be moved."); 3563 3564 //---------------------------------------------------------------------------- 3565 // Mapping of PyInsertionPoint. 
3566 //---------------------------------------------------------------------------- 3567 3568 nb::class_<PyInsertionPoint>(m, "InsertionPoint") 3569 .def(nb::init<PyBlock &>(), nb::arg("block"), 3570 "Inserts after the last operation but still inside the block.") 3571 .def("__enter__", &PyInsertionPoint::contextEnter) 3572 .def("__exit__", &PyInsertionPoint::contextExit, 3573 nb::arg("exc_type").none(), nb::arg("exc_value").none(), 3574 nb::arg("traceback").none()) 3575 .def_prop_ro_static( 3576 "current", 3577 [](nb::object & /*class*/) { 3578 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); 3579 if (!ip) 3580 throw nb::value_error("No current InsertionPoint"); 3581 return ip; 3582 }, 3583 "Gets the InsertionPoint bound to the current thread or raises " 3584 "ValueError if none has been set") 3585 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"), 3586 "Inserts before a referenced operation.") 3587 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, 3588 nb::arg("block"), "Inserts at the beginning of the block.") 3589 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, 3590 nb::arg("block"), "Inserts before the block terminator.") 3591 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), 3592 "Inserts an operation.") 3593 .def_prop_ro( 3594 "block", [](PyInsertionPoint &self) { return self.getBlock(); }, 3595 "Returns the block that this InsertionPoint points to.") 3596 .def_prop_ro( 3597 "ref_operation", 3598 [](PyInsertionPoint &self) -> nb::object { 3599 auto refOperation = self.getRefOperation(); 3600 if (refOperation) 3601 return refOperation->getObject(); 3602 return nb::none(); 3603 }, 3604 "The reference operation before which new operations are " 3605 "inserted, or None if the insertion point is at the end of " 3606 "the block"); 3607 3608 //---------------------------------------------------------------------------- 3609 // Mapping of PyAttribute. 
3610 //---------------------------------------------------------------------------- 3611 nb::class_<PyAttribute>(m, "Attribute") 3612 // Delegate to the PyAttribute copy constructor, which will also lifetime 3613 // extend the backing context which owns the MlirAttribute. 3614 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"), 3615 "Casts the passed attribute to the generic Attribute") 3616 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) 3617 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) 3618 .def_static( 3619 "parse", 3620 [](const std::string &attrSpec, DefaultingPyMlirContext context) { 3621 PyMlirContext::ErrorCapture errors(context->getRef()); 3622 MlirAttribute attr = mlirAttributeParseGet( 3623 context->get(), toMlirStringRef(attrSpec)); 3624 if (mlirAttributeIsNull(attr)) 3625 throw MLIRError("Unable to parse attribute", errors.take()); 3626 return attr; 3627 }, 3628 nb::arg("asm"), nb::arg("context").none() = nb::none(), 3629 "Parses an attribute from an assembly form. 
Raises an MLIRError on " 3630 "failure.") 3631 .def_prop_ro( 3632 "context", 3633 [](PyAttribute &self) { return self.getContext().getObject(); }, 3634 "Context that owns the Attribute") 3635 .def_prop_ro("type", 3636 [](PyAttribute &self) { return mlirAttributeGetType(self); }) 3637 .def( 3638 "get_named", 3639 [](PyAttribute &self, std::string name) { 3640 return PyNamedAttribute(self, std::move(name)); 3641 }, 3642 nb::keep_alive<0, 1>(), "Binds a name to the attribute") 3643 .def("__eq__", 3644 [](PyAttribute &self, PyAttribute &other) { return self == other; }) 3645 .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) 3646 .def("__hash__", 3647 [](PyAttribute &self) { 3648 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3649 }) 3650 .def( 3651 "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, 3652 kDumpDocstring) 3653 .def( 3654 "__str__", 3655 [](PyAttribute &self) { 3656 PyPrintAccumulator printAccum; 3657 mlirAttributePrint(self, printAccum.getCallback(), 3658 printAccum.getUserData()); 3659 return printAccum.join(); 3660 }, 3661 "Returns the assembly form of the Attribute.") 3662 .def("__repr__", 3663 [](PyAttribute &self) { 3664 // Generally, assembly formats are not printed for __repr__ because 3665 // this can cause exceptionally long debug output and exceptions. 3666 // However, attribute values are generally considered useful and 3667 // are printed. This may need to be re-evaluated if debug dumps end 3668 // up being excessive. 
3669 PyPrintAccumulator printAccum; 3670 printAccum.parts.append("Attribute("); 3671 mlirAttributePrint(self, printAccum.getCallback(), 3672 printAccum.getUserData()); 3673 printAccum.parts.append(")"); 3674 return printAccum.join(); 3675 }) 3676 .def_prop_ro("typeid", 3677 [](PyAttribute &self) -> MlirTypeID { 3678 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); 3679 assert(!mlirTypeIDIsNull(mlirTypeID) && 3680 "mlirTypeID was expected to be non-null."); 3681 return mlirTypeID; 3682 }) 3683 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { 3684 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); 3685 assert(!mlirTypeIDIsNull(mlirTypeID) && 3686 "mlirTypeID was expected to be non-null."); 3687 std::optional<nb::callable> typeCaster = 3688 PyGlobals::get().lookupTypeCaster(mlirTypeID, 3689 mlirAttributeGetDialect(self)); 3690 if (!typeCaster) 3691 return nb::cast(self); 3692 return typeCaster.value()(self); 3693 }); 3694 3695 //---------------------------------------------------------------------------- 3696 // Mapping of PyNamedAttribute 3697 //---------------------------------------------------------------------------- 3698 nb::class_<PyNamedAttribute>(m, "NamedAttribute") 3699 .def("__repr__", 3700 [](PyNamedAttribute &self) { 3701 PyPrintAccumulator printAccum; 3702 printAccum.parts.append("NamedAttribute("); 3703 printAccum.parts.append( 3704 nb::str(mlirIdentifierStr(self.namedAttr.name).data, 3705 mlirIdentifierStr(self.namedAttr.name).length)); 3706 printAccum.parts.append("="); 3707 mlirAttributePrint(self.namedAttr.attribute, 3708 printAccum.getCallback(), 3709 printAccum.getUserData()); 3710 printAccum.parts.append(")"); 3711 return printAccum.join(); 3712 }) 3713 .def_prop_ro( 3714 "name", 3715 [](PyNamedAttribute &self) { 3716 return nb::str(mlirIdentifierStr(self.namedAttr.name).data, 3717 mlirIdentifierStr(self.namedAttr.name).length); 3718 }, 3719 "The name of the NamedAttribute binding") 3720 .def_prop_ro( 3721 "attr", 3722 
[](PyNamedAttribute &self) { return self.namedAttr.attribute; }, 3723 nb::keep_alive<0, 1>(), 3724 "The underlying generic attribute of the NamedAttribute binding"); 3725 3726 //---------------------------------------------------------------------------- 3727 // Mapping of PyType. 3728 //---------------------------------------------------------------------------- 3729 nb::class_<PyType>(m, "Type") 3730 // Delegate to the PyType copy constructor, which will also lifetime 3731 // extend the backing context which owns the MlirType. 3732 .def(nb::init<PyType &>(), nb::arg("cast_from_type"), 3733 "Casts the passed type to the generic Type") 3734 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) 3735 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) 3736 .def_static( 3737 "parse", 3738 [](std::string typeSpec, DefaultingPyMlirContext context) { 3739 PyMlirContext::ErrorCapture errors(context->getRef()); 3740 MlirType type = 3741 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); 3742 if (mlirTypeIsNull(type)) 3743 throw MLIRError("Unable to parse type", errors.take()); 3744 return type; 3745 }, 3746 nb::arg("asm"), nb::arg("context").none() = nb::none(), 3747 kContextParseTypeDocstring) 3748 .def_prop_ro( 3749 "context", [](PyType &self) { return self.getContext().getObject(); }, 3750 "Context that owns the Type") 3751 .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) 3752 .def( 3753 "__eq__", [](PyType &self, nb::object &other) { return false; }, 3754 nb::arg("other").none()) 3755 .def("__hash__", 3756 [](PyType &self) { 3757 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3758 }) 3759 .def( 3760 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) 3761 .def( 3762 "__str__", 3763 [](PyType &self) { 3764 PyPrintAccumulator printAccum; 3765 mlirTypePrint(self, printAccum.getCallback(), 3766 printAccum.getUserData()); 3767 return printAccum.join(); 3768 }, 3769 "Returns the assembly form 
of the type.") 3770 .def("__repr__", 3771 [](PyType &self) { 3772 // Generally, assembly formats are not printed for __repr__ because 3773 // this can cause exceptionally long debug output and exceptions. 3774 // However, types are an exception as they typically have compact 3775 // assembly forms and printing them is useful. 3776 PyPrintAccumulator printAccum; 3777 printAccum.parts.append("Type("); 3778 mlirTypePrint(self, printAccum.getCallback(), 3779 printAccum.getUserData()); 3780 printAccum.parts.append(")"); 3781 return printAccum.join(); 3782 }) 3783 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 3784 [](PyType &self) { 3785 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); 3786 assert(!mlirTypeIDIsNull(mlirTypeID) && 3787 "mlirTypeID was expected to be non-null."); 3788 std::optional<nb::callable> typeCaster = 3789 PyGlobals::get().lookupTypeCaster(mlirTypeID, 3790 mlirTypeGetDialect(self)); 3791 if (!typeCaster) 3792 return nb::cast(self); 3793 return typeCaster.value()(self); 3794 }) 3795 .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { 3796 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); 3797 if (!mlirTypeIDIsNull(mlirTypeID)) 3798 return mlirTypeID; 3799 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self))); 3800 throw nb::value_error( 3801 (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); 3802 }); 3803 3804 //---------------------------------------------------------------------------- 3805 // Mapping of PyTypeID. 3806 //---------------------------------------------------------------------------- 3807 nb::class_<PyTypeID>(m, "TypeID") 3808 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) 3809 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) 3810 // Note, this tests whether the underlying TypeIDs are the same, 3811 // not whether the wrapper MlirTypeIDs are the same, nor whether 3812 // the Python objects are the same (i.e., PyTypeID is a value type). 
3813 .def("__eq__", 3814 [](PyTypeID &self, PyTypeID &other) { return self == other; }) 3815 .def("__eq__", 3816 [](PyTypeID &self, const nb::object &other) { return false; }) 3817 // Note, this gives the hash value of the underlying TypeID, not the 3818 // hash value of the Python object, nor the hash value of the 3819 // MlirTypeID wrapper. 3820 .def("__hash__", [](PyTypeID &self) { 3821 return static_cast<size_t>(mlirTypeIDHashValue(self)); 3822 }); 3823 3824 //---------------------------------------------------------------------------- 3825 // Mapping of Value. 3826 //---------------------------------------------------------------------------- 3827 nb::class_<PyValue>(m, "Value") 3828 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value")) 3829 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) 3830 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) 3831 .def_prop_ro( 3832 "context", 3833 [](PyValue &self) { return self.getParentOperation()->getContext(); }, 3834 "Context in which the value lives.") 3835 .def( 3836 "dump", [](PyValue &self) { mlirValueDump(self.get()); }, 3837 kDumpDocstring) 3838 .def_prop_ro( 3839 "owner", 3840 [](PyValue &self) -> nb::object { 3841 MlirValue v = self.get(); 3842 if (mlirValueIsAOpResult(v)) { 3843 assert( 3844 mlirOperationEqual(self.getParentOperation()->get(), 3845 mlirOpResultGetOwner(self.get())) && 3846 "expected the owner of the value in Python to match that in " 3847 "the IR"); 3848 return self.getParentOperation().getObject(); 3849 } 3850 3851 if (mlirValueIsABlockArgument(v)) { 3852 MlirBlock block = mlirBlockArgumentGetOwner(self.get()); 3853 return nb::cast(PyBlock(self.getParentOperation(), block)); 3854 } 3855 3856 assert(false && "Value must be a block argument or an op result"); 3857 return nb::none(); 3858 }) 3859 .def_prop_ro("uses", 3860 [](PyValue &self) { 3861 return PyOpOperandIterator( 3862 mlirValueGetFirstUse(self.get())); 3863 }) 3864 .def("__eq__", 3865 
[](PyValue &self, PyValue &other) { 3866 return self.get().ptr == other.get().ptr; 3867 }) 3868 .def("__eq__", [](PyValue &self, nb::object other) { return false; }) 3869 .def("__hash__", 3870 [](PyValue &self) { 3871 return static_cast<size_t>(llvm::hash_value(self.get().ptr)); 3872 }) 3873 .def( 3874 "__str__", 3875 [](PyValue &self) { 3876 PyPrintAccumulator printAccum; 3877 printAccum.parts.append("Value("); 3878 mlirValuePrint(self.get(), printAccum.getCallback(), 3879 printAccum.getUserData()); 3880 printAccum.parts.append(")"); 3881 return printAccum.join(); 3882 }, 3883 kValueDunderStrDocstring) 3884 .def( 3885 "get_name", 3886 [](PyValue &self, bool useLocalScope) { 3887 PyPrintAccumulator printAccum; 3888 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); 3889 if (useLocalScope) 3890 mlirOpPrintingFlagsUseLocalScope(flags); 3891 MlirAsmState valueState = 3892 mlirAsmStateCreateForValue(self.get(), flags); 3893 mlirValuePrintAsOperand(self.get(), valueState, 3894 printAccum.getCallback(), 3895 printAccum.getUserData()); 3896 mlirOpPrintingFlagsDestroy(flags); 3897 mlirAsmStateDestroy(valueState); 3898 return printAccum.join(); 3899 }, 3900 nb::arg("use_local_scope") = false) 3901 .def( 3902 "get_name", 3903 [](PyValue &self, PyAsmState &state) { 3904 PyPrintAccumulator printAccum; 3905 MlirAsmState valueState = state.get(); 3906 mlirValuePrintAsOperand(self.get(), valueState, 3907 printAccum.getCallback(), 3908 printAccum.getUserData()); 3909 return printAccum.join(); 3910 }, 3911 nb::arg("state"), kGetNameAsOperand) 3912 .def_prop_ro("type", 3913 [](PyValue &self) { return mlirValueGetType(self.get()); }) 3914 .def( 3915 "set_type", 3916 [](PyValue &self, const PyType &type) { 3917 return mlirValueSetType(self.get(), type); 3918 }, 3919 nb::arg("type")) 3920 .def( 3921 "replace_all_uses_with", 3922 [](PyValue &self, PyValue &with) { 3923 mlirValueReplaceAllUsesOfWith(self.get(), with.get()); 3924 }, 3925 kValueReplaceAllUsesWithDocstring) 3926 .def( 
3927 "replace_all_uses_except", 3928 [](MlirValue self, MlirValue with, PyOperation &exception) { 3929 MlirOperation exceptedUser = exception.get(); 3930 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); 3931 }, 3932 nb::arg("with"), nb::arg("exceptions"), 3933 kValueReplaceAllUsesExceptDocstring) 3934 .def( 3935 "replace_all_uses_except", 3936 [](MlirValue self, MlirValue with, nb::list exceptions) { 3937 // Convert Python list to a SmallVector of MlirOperations 3938 llvm::SmallVector<MlirOperation> exceptionOps; 3939 for (nb::handle exception : exceptions) { 3940 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get()); 3941 } 3942 3943 mlirValueReplaceAllUsesExcept( 3944 self, with, static_cast<intptr_t>(exceptionOps.size()), 3945 exceptionOps.data()); 3946 }, 3947 nb::arg("with"), nb::arg("exceptions"), 3948 kValueReplaceAllUsesExceptDocstring) 3949 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, 3950 [](PyValue &self) { return self.maybeDownCast(); }); 3951 PyBlockArgument::bind(m); 3952 PyOpResult::bind(m); 3953 PyOpOperand::bind(m); 3954 3955 nb::class_<PyAsmState>(m, "AsmState") 3956 .def(nb::init<PyValue &, bool>(), nb::arg("value"), 3957 nb::arg("use_local_scope") = false) 3958 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"), 3959 nb::arg("use_local_scope") = false); 3960 3961 //---------------------------------------------------------------------------- 3962 // Mapping of SymbolTable. 
3963 //---------------------------------------------------------------------------- 3964 nb::class_<PySymbolTable>(m, "SymbolTable") 3965 .def(nb::init<PyOperationBase &>()) 3966 .def("__getitem__", &PySymbolTable::dunderGetItem) 3967 .def("insert", &PySymbolTable::insert, nb::arg("operation")) 3968 .def("erase", &PySymbolTable::erase, nb::arg("operation")) 3969 .def("__delitem__", &PySymbolTable::dunderDel) 3970 .def("__contains__", 3971 [](PySymbolTable &table, const std::string &name) { 3972 return !mlirOperationIsNull(mlirSymbolTableLookup( 3973 table, mlirStringRefCreate(name.data(), name.length()))); 3974 }) 3975 // Static helpers. 3976 .def_static("set_symbol_name", &PySymbolTable::setSymbolName, 3977 nb::arg("symbol"), nb::arg("name")) 3978 .def_static("get_symbol_name", &PySymbolTable::getSymbolName, 3979 nb::arg("symbol")) 3980 .def_static("get_visibility", &PySymbolTable::getVisibility, 3981 nb::arg("symbol")) 3982 .def_static("set_visibility", &PySymbolTable::setVisibility, 3983 nb::arg("symbol"), nb::arg("visibility")) 3984 .def_static("replace_all_symbol_uses", 3985 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), 3986 nb::arg("new_symbol"), nb::arg("from_op")) 3987 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, 3988 nb::arg("from_op"), nb::arg("all_sym_uses_visible"), 3989 nb::arg("callback")); 3990 3991 // Container bindings. 3992 PyBlockArgumentList::bind(m); 3993 PyBlockIterator::bind(m); 3994 PyBlockList::bind(m); 3995 PyOperationIterator::bind(m); 3996 PyOperationList::bind(m); 3997 PyOpAttributeMap::bind(m); 3998 PyOpOperandIterator::bind(m); 3999 PyOpOperandList::bind(m); 4000 PyOpResultList::bind(m); 4001 PyOpSuccessors::bind(m); 4002 PyRegionIterator::bind(m); 4003 PyRegionList::bind(m); 4004 4005 // Debug bindings. 4006 PyGlobalDebugFlag::bind(m); 4007 4008 // Attribute builder getter. 
4009 PyAttrBuilderMap::bind(m); 4010 4011 nb::register_exception_translator([](const std::exception_ptr &p, 4012 void *payload) { 4013 // We can't define exceptions with custom fields through pybind, so instead 4014 // the exception class is defined in python and imported here. 4015 try { 4016 if (p) 4017 std::rethrow_exception(p); 4018 } catch (const MLIRError &e) { 4019 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) 4020 .attr("MLIRError")(e.message, e.errorDiagnostics); 4021 PyErr_SetObject(PyExc_Exception, obj.ptr()); 4022 } 4023 }); 4024 } 4025